Vinsingh committed on
Commit 2fd1376 · verified · 1 parent: b8743a0

Upload 5 files

Files changed (4):
  1. APE_tr1.csv +0 -0
  2. APE_tr2.ipynb +813 -0
  3. APR_tr2_2.ipynb +0 -0
  4. digital_green_process_data.py +62 -0
APE_tr1.csv ADDED
The diff for this file is too large to render. See raw diff
 
APE_tr2.ipynb ADDED
@@ -0,0 +1,813 @@
!pip install torch

import datetime
import json
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping"""

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices
        """
        if token_to_idx is None:
            token_to_idx = {}
        self._token_to_idx = token_to_idx

        self._idx_to_token = {idx: token
                              for token, idx in self._token_to_idx.items()}

    def to_serializable(self):
        """Returns a dictionary that can be serialized"""
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        """Instantiates the Vocabulary from a serialized dictionary"""
        return cls(**contents)

    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary

        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """Retrieve the index associated with the token

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """Return the token associated with the index

        Args:
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)

class SequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, unk_token="<UNK>",
                 mask_token="<MASK>", begin_seq_token="<BEGIN>",
                 end_seq_token="<END>"):

        super(SequenceVocabulary, self).__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)

    def to_serializable(self):
        contents = super(SequenceVocabulary, self).to_serializable()
        contents.update({'unk_token': self._unk_token,
                         'mask_token': self._mask_token,
                         'begin_seq_token': self._begin_seq_token,
                         'end_seq_token': self._end_seq_token})
        return contents

    def lookup_token(self, token):
        """Retrieve the index associated with the token,
        or the UNK index if the token isn't present.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >= 0 (having been added into the
            Vocabulary) for the UNK functionality
        """
        if self.unk_index >= 0:
            return self._token_to_idx.get(token, self.unk_index)
        else:
            return self._token_to_idx[token]

class NMTVectorizer(object):
    """The Vectorizer which coordinates the Vocabularies and puts them to use"""

    def __init__(self, source_vocab, target_vocab, max_source_length, max_target_length):
        """
        Args:
            source_vocab (SequenceVocabulary): maps source words to integers
            target_vocab (SequenceVocabulary): maps target words to integers
            max_source_length (int): the longest sequence in the source dataset
            max_target_length (int): the longest sequence in the target dataset
        """
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def _vectorize(self, indices, vector_length=-1, mask_index=0):
        """Vectorize the provided indices

        Args:
            indices (list): a list of integers that represent a sequence
            vector_length (int): forces the length of the index vector
            mask_index (int): the mask_index to use; almost always 0
        """
        if vector_length < 0:
            vector_length = len(indices)

        vector = np.zeros(vector_length, dtype=np.int64)
        vector[:len(indices)] = indices
        vector[len(indices):] = mask_index

        return vector

    def _get_source_indices(self, text):
        """Return the vectorized source text

        Args:
            text (str): the source text; tokens should be separated by spaces
        Returns:
            indices (list): list of integers representing the text
        """
        indices = [self.source_vocab.begin_seq_index]
        indices.extend(self.source_vocab.lookup_token(token) for token in text.split(" "))
        indices.append(self.source_vocab.end_seq_index)
        return indices

    def _get_target_indices(self, text):
        """Return the vectorized target text

        Args:
            text (str): the target text; tokens should be separated by spaces
        Returns:
            a tuple: (x_indices, y_indices)
                x_indices (list): integers representing the observations fed to the target decoder
                y_indices (list): integers representing the predictions expected from the target decoder
        """
        indices = [self.target_vocab.lookup_token(token) for token in text.split(" ")]
        x_indices = [self.target_vocab.begin_seq_index] + indices
        y_indices = indices + [self.target_vocab.end_seq_index]
        return x_indices, y_indices

    def vectorize(self, source_text, target_text, use_dataset_max_lengths=True):
        """Return the vectorized source and target text

        The vectorized source text is a single vector. The vectorized target
        text is split into two vectors: at each timestep, the first vector is
        the observation and the second vector is the prediction target.

        Args:
            source_text (str): text from the source language
            target_text (str): text from the target language
            use_dataset_max_lengths (bool): whether to use the global max vector lengths
        Returns:
            The vectorized data point as a dictionary with the keys:
                source_vector, target_x_vector, target_y_vector, source_length
        """
        source_vector_length = -1
        target_vector_length = -1

        if use_dataset_max_lengths:
            source_vector_length = self.max_source_length + 2
            target_vector_length = self.max_target_length + 1

        source_indices = self._get_source_indices(source_text)
        source_vector = self._vectorize(source_indices,
                                        vector_length=source_vector_length,
                                        mask_index=self.source_vocab.mask_index)

        target_x_indices, target_y_indices = self._get_target_indices(target_text)
        target_x_vector = self._vectorize(target_x_indices,
                                          vector_length=target_vector_length,
                                          mask_index=self.target_vocab.mask_index)
        target_y_vector = self._vectorize(target_y_indices,
                                          vector_length=target_vector_length,
                                          mask_index=self.target_vocab.mask_index)
        return {"source_vector": source_vector,
                "target_x_vector": target_x_vector,
                "target_y_vector": target_y_vector,
                "source_length": len(source_indices)}

    @classmethod
    def from_dataframe(cls, bitext_df):
        """Instantiate the vectorizer from the dataset dataframe

        Args:
            bitext_df (pandas.DataFrame): the parallel text dataset
        Returns:
            an instance of the NMTVectorizer
        """
        source_vocab = SequenceVocabulary()
        target_vocab = SequenceVocabulary()

        max_source_length = 50
        max_target_length = 25

        for _, row in bitext_df.iterrows():
            source_tokens = row["source_language"].split(" ")
            if len(source_tokens) > max_source_length:
                max_source_length = len(source_tokens)
            for token in source_tokens:
                source_vocab.add_token(token)

            target_tokens = row["target_language"].split(" ")
            if len(target_tokens) > max_target_length:
                max_target_length = len(target_tokens)
            for token in target_tokens:
                target_vocab.add_token(token)

        return cls(source_vocab, target_vocab, max_source_length, max_target_length)

    @classmethod
    def from_serializable(cls, contents):
        source_vocab = SequenceVocabulary.from_serializable(contents["source_vocab"])
        target_vocab = SequenceVocabulary.from_serializable(contents["target_vocab"])

        return cls(source_vocab=source_vocab,
                   target_vocab=target_vocab,
                   max_source_length=contents["max_source_length"],
                   max_target_length=contents["max_target_length"])

    def to_serializable(self):
        return {"source_vocab": self.source_vocab.to_serializable(),
                "target_vocab": self.target_vocab.to_serializable(),
                "max_source_length": self.max_source_length,
                "max_target_length": self.max_target_length}

class NMTDataset(Dataset):
    def __init__(self, text_df, vectorizer):
        """
        Args:
            text_df (pandas.DataFrame): the dataset
            vectorizer (NMTVectorizer): vectorizer instantiated from the dataset
        """
        self.text_df = text_df
        self._vectorizer = vectorizer

        self.train_df = self.text_df[self.text_df.split == 'train']
        self.train_size = len(self.train_df)

        self.val_df = self.text_df[self.text_df.split == 'val']
        self.validation_size = len(self.val_df)

        self.test_df = self.text_df[self.text_df.split == 'test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')

    @classmethod
    def load_dataset_and_make_vectorizer(cls, dataset_csv):
        """Load the dataset and make a new vectorizer from scratch

        Args:
            dataset_csv (str): location of the dataset
        Returns:
            an instance of NMTDataset
        """
        text_df = pd.read_csv(dataset_csv).fillna(' ')
        train_subset = text_df[text_df.split == 'train']
        return cls(text_df, NMTVectorizer.from_dataframe(train_subset))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, dataset_csv, vectorizer_filepath):
        """Load the dataset and the corresponding vectorizer.
        Used when the vectorizer has been cached for re-use.

        Args:
            dataset_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved vectorizer
        Returns:
            an instance of NMTDataset
        """
        text_df = pd.read_csv(dataset_csv).fillna(' ')
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        return cls(text_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """A static method for loading the vectorizer from file

        Args:
            vectorizer_filepath (str): the location of the serialized vectorizer
        Returns:
            an instance of NMTVectorizer
        """
        with open(vectorizer_filepath) as fp:
            return NMTVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """Saves the vectorizer to disk using json

        Args:
            vectorizer_filepath (str): the location to save the vectorizer
        """
        with open(vectorizer_filepath, "w") as fp:
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """Returns the vectorizer"""
        return self._vectorizer

    def set_split(self, split="train"):
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """The primary entry point method for PyTorch datasets

        Args:
            index (int): the index to the data point
        Returns:
            a dictionary holding the data point's source, target inputs,
            target predictions, and source length
        """
        row = self._target_df.iloc[index]

        vector_dict = self._vectorizer.vectorize(row.source_language, row.target_language)

        return {"x_source": vector_dict["source_vector"],
                "x_target": vector_dict["target_x_vector"],
                "y_target": vector_dict["target_y_vector"],
                "x_source_length": vector_dict["source_length"]}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset

        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return len(self) // batch_size

def generate_nmt_batches(dataset, batch_size, shuffle=True,
                         drop_last=True, device="cpu"):
    """A generator function which wraps the PyTorch DataLoader. The NMT version."""
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        lengths = data_dict['x_source_length'].numpy()
        # Get the indices according to sorted length (descending)
        sorted_length_indices = lengths.argsort()[::-1].tolist()

        # Sort the minibatch by source length
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name][sorted_length_indices].to(device)
        yield out_data_dict

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, drop_out, max_len: int = 200):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, max_len).reshape(max_len, 1)
        pos_embedding = torch.zeros((max_len, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(drop_out)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead,
                 src_vocab_size, tgt_vocab_size, dim_feedforward=512, dropout=0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size, nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout, norm_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, drop_out=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor,
                src_padding_mask: Tensor, tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask,
                                memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src, src_mask):
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

def set_seed_everywhere(seed, cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    print(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def handle_dirs(save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

def create_mask(src, tgt, PAD_IDX):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

545
+ "\n",
546
+ "\n",
547
+ "def train_epoch(batch_size, device, model, dataset, split_value, optimizer, PAD_IDX, loss_fn):\n",
548
+ " BATCH_SIZE = batch_size\n",
549
+ " model.train()\n",
550
+ " losses = 0\n",
551
+ " print(dataset.__len__())\n",
552
+ " train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)\n",
553
+ " #print(BATCH_SIZE,len(list(train_dataloader)))\n",
554
+ " dataset.set_split(split_value)\n",
555
+ " batch_generator = generate_nmt_batches(dataset, batch_size=BATCH_SIZE, device = device)\n",
556
+ " print(\"printing batch generator\",batch_generator)\n",
557
+ " ctr = 0\n",
558
+ " for batch_index, batch_dict in enumerate(batch_generator):\n",
559
+ " ctr = ctr+1\n",
560
+ " #optimizer.zero_grad()\n",
561
+ " #print(torch.cat((torch.transpose(batch_dict['x_source'],0,1),torch.transpose(batch_dict['x_target'],0,1),torch.transpose(batch_dict['y_target'],0,1)),1).numpy().shape)\n",
562
+ " #print(torch.transpose(batch_dict['x_target'],0,1))\n",
563
+ " #print(torch.transpose(batch_dict['y_target'],0,1))\n",
564
+ " src=torch.transpose(batch_dict['x_source'],0,1)\n",
565
+ " tgt=torch.transpose(batch_dict['y_target'],0,1)\n",
566
+ " tgt_input = tgt[:-1,:]\n",
567
+ " src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src,tgt_input, PAD_IDX)\n",
568
+ " logits = model(src,tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)\n",
569
+ " optimizer.zero_grad()\n",
570
+ " tgt_out = tgt[1:,:]\n",
571
+ " loss = loss_fn(logits.reshape(-1, logits.shape[-1]),tgt_out.reshape(-1))\n",
572
+ " loss.backward()\n",
573
+ " optimizer.step()\n",
574
+ " losses += loss.item()\n",
575
+ " if ctr%50==0:\n",
576
+ " #print('source_shape',src.shape, 'target_shape',tgt.shape)\n",
577
+ " print(\"ctr: \",ctr,\" losses: \",losses/ctr,'time',datetime.datetime.now())#,\" len_train_dataloader: \",len(list(train_dataloader)))\n",
578
+ " return losses/len(list(train_dataloader))\n",
579
+ "\n",
def evaluate(batch_size, device, model, dataset, split_value, PAD_IDX, loss_fn):
    model.eval()
    losses = 0
    dataset.set_split(split_value)
    num_batches = dataset.get_num_batches(batch_size)
    batch_generator = generate_nmt_batches(dataset, batch_size=batch_size, device=device)
    ctr = 0
    with torch.no_grad():  # no gradients needed for validation
        for batch_index, batch_dict in enumerate(batch_generator):
            src = torch.transpose(batch_dict['x_source'], 0, 1)
            tgt = torch.transpose(batch_dict['y_target'], 0, 1)
            tgt_input = tgt[:-1, :]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, PAD_IDX)
            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
            ctr = ctr + 1
            print(ctr, "validation", losses / ctr)
    return losses / num_batches

def greedy_decode(DEVICE, model, src, src_mask, max_len, start_symbol, EOS_IDX):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        # Project only the last position's hidden state onto the vocabulary
        prob = model.generator(out)[:, -1]
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


def translate(device, model: torch.nn.Module, src_sentence: Tensor, BOS_IDX, EOS_IDX):
    """Greedy-decode a vectorized source sentence of shape (seq_len, 1)."""
    model.eval()
    src = src_sentence
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(device, model, src, src_mask, max_len=num_tokens,
                               start_symbol=BOS_IDX, EOS_IDX=EOS_IDX).flatten()
    return tgt_tokens

input_df = 'dataset_for_APE_hinglish_to_english2.csv'
fpath = "nmt_IITB_APE2"

# dataset = NMTDataset.load_dataset_and_make_vectorizer('IITB_dataset_1.csv')
# dataset.save_vectorizer("vectorizer_transformer_3layer_IITB1mill.json")

dataset_csv = 'dataset_for_APE_hinglish_to_english2.csv'
vectorizer_file = 'vectorizer_APE_2.json'
print(vectorizer_file)
model_state_file = 'APE_2.pth'
save_dir = "nmt_DG2_FFNN8192"
print(save_dir)
reload_from_files = True
cuda = False
seed = 13
learning_rate = 8e-3
batch_size = 1024
batch_size_val = 1
num_epochs = 40
source_embedding_size = 256
target_embedding_size = 256
encoding_size = 256
use_glove = False
expand_filepaths_to_save_dir = True
early_stopping_criteria = 10
dataset_to_evaluate = 'dataset_for_APE_hinglish_to_english2.csv'
path_to_save = 'APE_1_new.csv'
saved_model_path = 'APE_1_new.pt'
file_exist = 0
existing_file_name = 'dataset_for_APE_hinglish_to_english2.csv'

dataset_path = fpath
existing_file_name = input_df
fname = existing_file_name
dataset_csv = fname

catch_keyboard_interrupt = True

if expand_filepaths_to_save_dir:
    vectorizer_file = os.path.join(save_dir, vectorizer_file)
    model_state_file = os.path.join(save_dir, model_state_file)

if not torch.cuda.is_available():
    cuda = False
device = torch.device("cuda" if cuda else "cpu")

set_seed_everywhere(seed, cuda)
handle_dirs(save_dir)

742
+ " dataset = NMTDataset.load_dataset_and_load_vectorizer(dataset_csv, vectorizer_file)\n",
743
+ " print('load_dataset_and_load_vectorizer______')\n",
744
+ "else:\n",
745
+ " dataset = NMTDataset.load_dataset_and_make_vectorizer(dataset_csv)\n",
746
+ " dataset.save_vectorizer(vectorizer_file)\n",
747
+ " print('_________load_dataset_and_make_vectorizer______')\n",
748
+ "vectorizer = dataset.get_vectorizer()\n",
749
+ "PAD_IDX = vectorizer.to_serializable()['target_vocab']['token_to_idx']['<MASK>']\n",
750
+ "BOS_IDX = vectorizer.to_serializable()['target_vocab']['token_to_idx']['<BEGIN>']\n",
751
+ "EOS_IDX = vectorizer.to_serializable()['target_vocab']['token_to_idx']['<END>']\n",
752
+ "SRC_VOCAB_SIZE = len(vectorizer.to_serializable()['source_vocab']['token_to_idx'])\n",
753
+ "TGT_VOCAB_SiZE = len(vectorizer.to_serializable()['target_vocab']['token_to_idx'])\n",
754
+ "print('target vocab size',TGT_VOCAB_SiZE)\n",
755
+ "print('dataset_size 1: ', dataset.__len__(), dataset_path, dataset_csv)\n",
756
+ "print(' dataset csv length',len(pd.read_csv(dataset_csv)))\n",
757
+ "EMB_SIZE = 256\n",
758
+ "NHEAD = 16\n",
759
+ "FFN_HID_DIM =8192\n",
760
+ "BATCH_SIZE = 128\n",
761
+ "NUM_ENCODER_LAYERS = 3\n",
762
+ "NUM_DECODER_LAYERS = 3\n",
763
+ "batch_size = BATCH_SIZE\n",
764
+ "transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SiZE, FFN_HID_DIM)\n",
765
+ "transformer = transformer.to(DEVICE)\n",
766
+ "loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n",
767
+ "optimizer = torch.optim.Adam(transformer.parameters(), lr=0.004, betas = (0.99, 0.99), eps = 1e-9)\n",
768
+ "from timeit import default_timer as timer\n",
769
+ "NUM_EPOCHS = num_epochs\n",
770
+ "for epoch in range(1, NUM_EPOCHS+1):\n",
771
+ " print(\"==================Training started==================\",epoch)\n",
772
+ " start_time = timer()\n",
773
+ " split_value_train = 'train'\n",
774
+ " split_value_validate = 'val'\n",
775
+ " train_loss = train_epoch(batch_size,device,transformer, dataset, split_value_train, optimizer, PAD_IDX, loss_fn)\n",
776
+ " end_time = timer()\n",
777
+ " torch.save(transformer,'epoch'+str(epoch)+'_APE_2_new.pt')\n",
778
+ "#torch.save(transformer, save_dir+\"/\"+saved_model_path+\"_epoch\")\n",
779
+ " #val_loss = evaluate(batch_size,device,transformer, dataset, split_value_validate, PAD_IDX, loss_fn)\n",
780
+ "torch.save(transformer, save_dir+\"/\"+saved_model_path)\n"
# Notebook metadata: kernel "Python 3 (ipykernel)", Python 3.11.9; one trailing empty code cell.
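
The cell defines greedy_decode and translate but never calls them. A minimal inference sketch, assuming the objects built above (vectorizer, transformer, DEVICE, BOS_IDX, EOS_IDX) and a purely hypothetical input sentence, might look like this:

# Hypothetical usage sketch, not part of the uploaded notebook.
sentence = "example hinglish input"  # placeholder input
indices = vectorizer._get_source_indices(sentence)  # <BEGIN> ... <END>
src = torch.tensor(indices, dtype=torch.long).unsqueeze(1)  # (seq_len, 1): batch dim second
tgt_tokens = translate(DEVICE, transformer, src, BOS_IDX, EOS_IDX)
print(" ".join(vectorizer.target_vocab.lookup_index(int(i)) for i in tgt_tokens))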
APR_tr2_2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
digital_green_process_data.py ADDED
@@ -0,0 +1,62 @@
import os

import numpy as np
import pandas as pd
import soundfile as sf
from datasets import Dataset, DatasetDict, Audio
from sklearn.model_selection import train_test_split

# Paths
audio_folder = '/home/azureuser/data2/dg_16/'  # where the audio files are stored
csv_file = 'digital_green_recordings.csv'      # CSV with audio paths and transcripts

# Read the CSV file (assumes it has columns 'path' and 'transcript')
df = pd.read_csv(csv_file, sep="$")

# Create a client_id column (synthetic, since there is no speaker info)
df['client_id'] = ['speaker_' + str(i) for i in range(len(df))]

# If the CSV has relative paths, make them absolute
df['path'] = df['path'].apply(lambda x: os.path.join(audio_folder, x))

# Additional columns expected by the Common Voice format (optional)
df['up_votes'] = 0
df['down_votes'] = 0
df['age'] = None
df['gender'] = None
df['accent'] = None

def load_audio(file_path):
    """Load an audio file, downmixing stereo to mono."""
    audio, sr = sf.read(file_path)
    if len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)
    # Return the dict layout the datasets Audio feature expects
    return {'array': audio, 'sampling_rate': sr}

# Apply the audio-loading function to the DataFrame
df['audio'] = df['path'].apply(load_audio)

# Adjust test_size as needed
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Convert the DataFrames to Hugging Face Datasets
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)

# Cast the 'audio' column to the Audio feature type
train_dataset = train_dataset.cast_column('audio', Audio())
test_dataset = test_dataset.cast_column('audio', Audio())

# Create a DatasetDict (add a 'validation' split here if one exists)
dataset_dict = DatasetDict({
    'train': train_dataset,
    'test': test_dataset
})

# Save the dataset for future use
dataset_dict.save_to_disk('data2/digital_green_data')

# Print a sample from each split
print(dataset_dict['train'][0])
print(dataset_dict['test'][0])
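
Once saved, the DatasetDict can be reloaded in a later session without re-reading the audio files; a minimal sketch using the path from the script above:

# Hypothetical reload sketch, not part of the uploaded script.
from datasets import load_from_disk

dataset_dict = load_from_disk('data2/digital_green_data')
print(dataset_dict)  # DatasetDict with 'train' and 'test' splits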