Antoine-caubriere commited on
Commit
ecaafd2
1 Parent(s): 131ff81

Upload SB_ASR_FLEURS_finetuning.ipynb

Browse files
Files changed (1) hide show
  1. SB_ASR_FLEURS_finetuning.ipynb +689 -0
SB_ASR_FLEURS_finetuning.ipynb ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "49b85514-0fb6-49c6-be76-259bfeb638c6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Introduction\n",
9
+ "N'hésitez pas à nous contacter en cas de questions : [email protected] & [email protected]\n",
10
+ "\n",
11
+ "Pensez à modifier l'ensemble des PATH dans le fichier de configuration ASR_FLEURSswahili_hf.yaml et dans le code python ci-dessous (PATH_TO_YOUR_FOLDER).\n",
12
+ "\n",
13
+ "Dans le cas d'un changement de corpus (autre sous partie de FLEURS / vos propres jeux de données), pensez à modifier la taille de la couche de sortie du modèle : ASR_swahili_hf.yaml/output_neurons\n"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "id": "e62faa86-911a-48ce-82bc-8a34e13ffbc4",
19
+ "metadata": {},
20
+ "source": [
21
+ "# Préparation des données FLEURS"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "id": "c6ccf4a5-cad1-4632-8954-f4e454ff3540",
27
+ "metadata": {},
28
+ "source": [
29
+ "### 1. Installation des dépendances"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "7bb8b44e-826f-4f13-b128-eebbd18dedc5",
36
+ "metadata": {
37
+ "jupyter": {
38
+ "source_hidden": true
39
+ }
40
+ },
41
+ "outputs": [],
42
+ "source": [
43
+ "pip install datasets librosa soundfile"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "016d7646-bcca-4422-8b28-9d12d4b86c8f",
49
+ "metadata": {},
50
+ "source": [
51
+ "### 2. Téléchargement et formatage du dataset"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "da273973-05ee-4de5-830e-34d7f2220353",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "from datasets import load_dataset\n",
62
+ "from pathlib import Path\n",
63
+ "from collections import OrderedDict\n",
64
+ "from tqdm import tqdm\n",
65
+ "import shutil\n",
66
+ "import os\n",
67
+ "\n",
68
+ "dataset_write_base = \"PATH_TO_YOUR_FOLDER/data_speechbrain/\"\n",
69
+ "cache_dir = \"PATH_TO_YOUR_FOLDER/data_huggingface/\"\n",
70
+ "\n",
71
+ "if os.path.isdir(cache_dir):\n",
72
+ " print(\"rm -rf \"+cache_dir)\n",
73
+ " os.system(\"rm -rf \"+cache_dir)\n",
74
+ "\n",
75
+ "if os.path.isdir(dataset_write_base):\n",
76
+ " print(\"rm -rf \"+dataset_write_base)\n",
77
+ " os.system(\"rm -rf \"+dataset_write_base)\n",
78
+ "\n",
79
+ "# **************************************\n",
80
+ "# choix des langues à extraire de FLEURS\n",
81
+ "# **************************************\n",
82
+ "lang_dict = OrderedDict([\n",
83
+ " #(\"Afrikaans\",\"af_za\"),\n",
84
+ " #(\"Amharic\", \"am_et\"),\n",
85
+ " #(\"Fula\", \"ff_sn\"),\n",
86
+ " #(\"Ganda\", \"lg_ug\"),\n",
87
+ " #(\"Hausa\", \"ha_ng\"),\n",
88
+ " #(\"Igbo\", \"ig_ng\"),\n",
89
+ " #(\"Kamba\", \"kam_ke\"),\n",
90
+ " #(\"Lingala\", \"ln_cd\"),\n",
91
+ " #(\"Luo\", \"luo_ke\"),\n",
92
+ " #(\"Northern-Sotho\", \"nso_za\"),\n",
93
+ " #(\"Nyanja\", \"ny_mw\"),\n",
94
+ " #(\"Oromo\", \"om_et\"),\n",
95
+ " #(\"Shona\", \"sn_zw\"),\n",
96
+ " #(\"Somali\", \"so_so\"),\n",
97
+ " (\"Swahili\", \"sw_ke\"),\n",
98
+ " #(\"Umbundu\", \"umb_ao\"),\n",
99
+ " #(\"Wolof\", \"wo_sn\"), \n",
100
+ " #(\"Xhosa\", \"xh_za\"), \n",
101
+ " #(\"Yoruba\", \"yo_ng\"), \n",
102
+ " #(\"Zulu\", \"zu_za\")\n",
103
+ " ])\n",
104
+ "\n",
105
+ "# ********************************\n",
106
+ "# choix des sous-parties à traiter\n",
107
+ "# ********************************\n",
108
+ "datasets = [\"train\",\"test\",\"validation\"]\n",
109
+ "\n",
110
+ "for lang in lang_dict:\n",
111
+ " print(\"Prepare --->\", lang)\n",
112
+ " \n",
113
+ " # ********************************\n",
114
+ " # Download FLEURS from huggingface\n",
115
+ " # ********************************\n",
116
+ " fleurs_asr = load_dataset(\"google/fleurs\", lang_dict[lang],cache_dir=cache_dir, trust_remote_code=True)\n",
117
+ "\n",
118
+ " for subparts in datasets:\n",
119
+ " \n",
120
+ " used_ID = []\n",
121
+ " Path(dataset_write_base+\"/\"+lang+\"/wavs/\"+subparts).mkdir(parents=True, exist_ok=True)\n",
122
+ " \n",
123
+ " # csv header\n",
124
+ " f = open(dataset_write_base+\"/\"+lang+\"/\"+subparts+\".csv\", \"w\")\n",
125
+ " f.write(\"ID,duration,wav,spk_id,wrd\\n\")\n",
126
+ "\n",
127
+ " for uid in tqdm(range(len(fleurs_asr[subparts]))):\n",
128
+ "\n",
129
+ " # ***************\n",
130
+ " # format CSV line\n",
131
+ " # ***************\n",
132
+ " text_id = lang+\"_\"+str(fleurs_asr[subparts][uid][\"id\"])\n",
133
+ " \n",
134
+ " # some ID are duplicated (same speaker, same transcription BUT different recording)\n",
135
+ " while(text_id in used_ID):\n",
136
+ " text_id += \"_bis\"\n",
137
+ " used_ID.append(text_id)\n",
138
+ "\n",
139
+ " duration = \"{:.3f}\".format(round(float(fleurs_asr[subparts][uid][\"num_samples\"])/float(fleurs_asr[subparts][uid][\"audio\"][\"sampling_rate\"]),3))\n",
140
+ " wav_path = \"/\".join([dataset_write_base, lang, \"wavs\",subparts, fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
141
+ " spk_id = \"spk_\" + text_id\n",
142
+ " # AC : \"pseudo-normalisation\" de cas marginaux -- TODO mieux\n",
143
+ " wrd = fleurs_asr[subparts][uid][\"transcription\"].replace(',','').replace('$',' $ ').replace('\"','').replace('”','').replace(' ',' ')\n",
144
+ "\n",
145
+ " # **************\n",
146
+ " # write CSV line\n",
147
+ " # **************\n",
148
+ " f.write(text_id+\",\"+duration+\",\"+wav_path+\",\"+spk_id+\",\"+wrd+\"\\n\") \n",
149
+ "\n",
150
+ " # *******************\n",
151
+ " # Move wav from cache\n",
152
+ " # *******************\n",
153
+ " previous_path = \"/\".join(fleurs_asr[subparts][uid][\"path\"].split('/')[:-1]) + \"/\" + fleurs_asr[subparts][uid][\"audio\"][\"path\"]\n",
154
+ " new_path = \"/\".join([dataset_write_base,lang,\"wavs\",subparts,fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
155
+ " shutil.move(previous_path,new_path)\n",
156
+ " \n",
157
+ " f.close()\n",
158
+ " print(\"--->\", lang, \"done\")"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "id": "4c32e369-f0f9-4695-8c9a-aa3a9de7bf7b",
164
+ "metadata": {},
165
+ "source": [
166
+ "# Recette ASR"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "id": "77fb2c55-3f8c-4f34-81f0-ad48a632e010",
172
+ "metadata": {
173
+ "jp-MarkdownHeadingCollapsed": true
174
+ },
175
+ "source": [
176
+ "## 1. Installation des dépendances"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "fbe25635-e765-480c-8416-c48a31ee6140",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "id": "6acf1f8c-2cf3-4c9c-8a45-e2580ecbee27",
192
+ "metadata": {},
193
+ "source": [
194
+ "## 2. Mise en place de la recette Speechbrain -- class Brain"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "d5e8884d-3542-40ff-a454-597078fcf97c",
200
+ "metadata": {},
201
+ "source": [
202
+ "### 2.1 Imports & logger"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "6c677f9f-6abe-423f-b4dd-fdf5ded357cd",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "import logging\n",
213
+ "import os\n",
214
+ "import sys\n",
215
+ "from pathlib import Path\n",
216
+ "\n",
217
+ "import torch\n",
218
+ "from hyperpyyaml import load_hyperpyyaml\n",
219
+ "\n",
220
+ "import speechbrain as sb\n",
221
+ "from speechbrain.utils.distributed import if_main_process, run_on_main\n",
222
+ "\n",
223
+ "import jdc\n",
224
+ "\n",
225
+ "logger = logging.getLogger(__name__)"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "id": "9698bb92-16ad-4b61-8938-c74b62ee93b2",
231
+ "metadata": {},
232
+ "source": [
233
+ "### 2.2 Création de notre classe héritant de la classe brain"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "id": "7c7cd624-6249-449b-8ee9-d4a73b7b3301",
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "# Define training procedure\n",
244
+ "class MY_SSA_ASR(sb.Brain):\n",
245
+ " print(\"\")\n",
246
+ " # define here"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "id": "ecf31c9c-15dd-4428-aa10-b3cc5e127f0d",
252
+ "metadata": {},
253
+ "source": [
254
+ "### 2.3 Définition de la fonction forward "
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "4368b488-b9d8-49ff-8ce3-78a12d46be83",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "%%add_to MY_SSA_ASR\n",
265
+ "def compute_forward(self, batch, stage):\n",
266
+ " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
267
+ " batch = batch.to(self.device)\n",
268
+ " wavs, wav_lens = batch.sig\n",
269
+ " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
270
+ "\n",
271
+ " # Downsample the inputs if specified\n",
272
+ " if hasattr(self.modules, \"downsampler\"):\n",
273
+ " wavs = self.modules.downsampler(wavs)\n",
274
+ "\n",
275
+ " # Add waveform augmentation if specified.\n",
276
+ " if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
277
+ " wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)\n",
278
+ "\n",
279
+ " # Forward pass\n",
280
+ " feats = self.modules.hubert(wavs, wav_lens)\n",
281
+ " x = self.modules.top_lin(feats)\n",
282
+ "\n",
283
+ " # Compute outputs\n",
284
+ " logits = self.modules.ctc_lin(x)\n",
285
+ " p_ctc = self.hparams.log_softmax(logits)\n",
286
+ "\n",
287
+ "\n",
288
+ " p_tokens = None\n",
289
+ " if stage == sb.Stage.VALID:\n",
290
+ " p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)\n",
291
+ "\n",
292
+ " elif stage == sb.Stage.TEST:\n",
293
+ " p_tokens = test_searcher(p_ctc, wav_lens)\n",
294
+ "\n",
295
+ " candidates = []\n",
296
+ " scores = []\n",
297
+ "\n",
298
+ " for batch in p_tokens:\n",
299
+ " candidates.append([hyp.text for hyp in batch])\n",
300
+ " scores.append([hyp.score for hyp in batch])\n",
301
+ "\n",
302
+ " if hasattr(self.hparams, \"rescorer\"):\n",
303
+ " p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)\n",
304
+ "\n",
305
+ " return p_ctc, wav_lens, p_tokens\n"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "id": "f0052b79-5a27-4c4c-8601-7ab064e8c951",
311
+ "metadata": {},
312
+ "source": [
313
+ "### 2.4 Définition de la fonction objectives"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": null,
319
+ "id": "3608aee8-c9c3-4e34-98bc-667513fa7f7b",
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "%%add_to MY_SSA_ASR\n",
324
+ "def compute_objectives(self, predictions, batch, stage):\n",
325
+ " \"\"\"Computes the loss (CTC+NLL) given predictions and targets.\"\"\"\n",
326
+ "\n",
327
+ " p_ctc, wav_lens, predicted_tokens = predictions\n",
328
+ "\n",
329
+ " ids = batch.id\n",
330
+ " tokens, tokens_lens = batch.tokens\n",
331
+ "\n",
332
+ " # Labels must be extended if parallel augmentation or concatenated\n",
333
+ " # augmentation was performed on the input (increasing the time dimension)\n",
334
+ " if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
335
+ " (tokens, tokens_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)\n",
336
+ "\n",
337
+ "\n",
338
+ "\n",
339
+ " # Compute loss\n",
340
+ " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
341
+ "\n",
342
+ " if stage == sb.Stage.VALID:\n",
343
+ " # Decode token terms to words\n",
344
+ " predicted_words = [\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \") for utt_seq in predicted_tokens]\n",
345
+ " \n",
346
+ " elif stage == sb.Stage.TEST:\n",
347
+ " predicted_words = [hyp[0].text.split(\" \") for hyp in predicted_tokens]\n",
348
+ "\n",
349
+ " if stage != sb.Stage.TRAIN:\n",
350
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
351
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
352
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
353
+ "\n",
354
+ " return loss\n"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "markdown",
359
+ "id": "9a514c50-89ad-41cb-882a-23daf829a538",
360
+ "metadata": {},
361
+ "source": [
362
+ "### 2.5 définition du comportement au début d'un \"stage\""
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "id": "609814ce-3ef0-4818-a70f-cadc293c9dd2",
369
+ "metadata": {},
370
+ "outputs": [],
371
+ "source": [
372
+ "%%add_to MY_SSA_ASR\n",
373
+ "# stage gestion\n",
374
+ "def on_stage_start(self, stage, epoch):\n",
375
+ " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
376
+ " if stage != sb.Stage.TRAIN:\n",
377
+ " self.cer_metric = self.hparams.cer_computer()\n",
378
+ " self.wer_metric = self.hparams.error_rate_computer()\n",
379
+ "\n",
380
+ " if stage == sb.Stage.TEST:\n",
381
+ " if hasattr(self.hparams, \"rescorer\"):\n",
382
+ " self.hparams.rescorer.move_rescorers_to_device()\n",
383
+ "\n"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "id": "55929209-c94a-4f8b-8f2e-9dd5d9de8be9",
389
+ "metadata": {},
390
+ "source": [
391
+ "### 2.6 définition du comportement à la fin d'un \"stage\""
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "id": "8f297542-10d5-47bf-9938-c141f5a99ab8",
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "%%add_to MY_SSA_ASR\n",
402
+ "def on_stage_end(self, stage, stage_loss, epoch):\n",
403
+ " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
404
+ " # Compute/store important stats\n",
405
+ " stage_stats = {\"loss\": stage_loss}\n",
406
+ " if stage == sb.Stage.TRAIN:\n",
407
+ " self.train_stats = stage_stats\n",
408
+ " else:\n",
409
+ " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
410
+ " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
411
+ "\n",
412
+ " # Perform end-of-iteration things, like annealing, logging, etc.\n",
413
+ " if stage == sb.Stage.VALID:\n",
414
+ " # *******************************\n",
415
+ " # Anneal and update Learning Rate\n",
416
+ " # *******************************\n",
417
+ " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(stage_stats[\"loss\"])\n",
418
+ " old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(stage_stats[\"loss\"])\n",
419
+ " sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)\n",
420
+ " sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)\n",
421
+ "\n",
422
+ " # *****************\n",
423
+ " # Logs informations\n",
424
+ " # *****************\n",
425
+ " self.hparams.train_logger.log_stats(stats_meta={\"epoch\": epoch, \"lr_model\": old_lr_model, \"lr_hubert\": old_lr_hubert}, train_stats=self.train_stats, valid_stats=stage_stats)\n",
426
+ "\n",
427
+ " # ***************\n",
428
+ " # Save checkpoint\n",
429
+ " # ***************\n",
430
+ " self.checkpointer.save_and_keep_only(meta={\"WER\": stage_stats[\"WER\"]},min_keys=[\"WER\"])\n",
431
+ "\n",
432
+ " elif stage == sb.Stage.TEST:\n",
433
+ " self.hparams.train_logger.log_stats(stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},test_stats=stage_stats)\n",
434
+ " if if_main_process():\n",
435
+ " with open(self.hparams.test_wer_file, \"w\") as w:\n",
436
+ " self.wer_metric.write_stats(w)\n"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "markdown",
441
+ "id": "0c656457-6b61-4316-8199-70021f92babf",
442
+ "metadata": {},
443
+ "source": [
444
+ "### 2.7 définition de l'initialisation des optimizers"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "id": "da8d9cb5-c5ad-4e78-83d3-e129e138a741",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "%%add_to MY_SSA_ASR\n",
455
+ "def init_optimizers(self):\n",
456
+ " \"Initializes the hubert optimizer and model optimizer\"\n",
457
+ " self.hubert_optimizer = self.hparams.hubert_opt_class(self.modules.hubert.parameters())\n",
458
+ " self.model_optimizer = self.hparams.model_opt_class(self.hparams.model.parameters())\n",
459
+ "\n",
460
+ " # save the optimizers in a dictionary\n",
461
+ " # the key will be used in `freeze_optimizers()`\n",
462
+ " self.optimizers_dict = {\"model_optimizer\": self.model_optimizer}\n",
463
+ " if not self.hparams.freeze_hubert:\n",
464
+ " self.optimizers_dict[\"hubert_optimizer\"] = self.hubert_optimizer\n",
465
+ "\n",
466
+ " if self.checkpointer is not None:\n",
467
+ " self.checkpointer.add_recoverable(\"hubert_opt\", self.hubert_optimizer)\n",
468
+ " self.checkpointer.add_recoverable(\"model_opt\", self.model_optimizer)\n"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "id": "cf2e730c-2faa-41f2-b98d-e5fbb2305cc2",
474
+ "metadata": {},
475
+ "source": [
476
+ "## 3 Définition de la lecture des datasets"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "id": "c5e667f7-6269-4b49-88bb-5e431762c8fe",
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": [
486
+ "def dataio_prepare(hparams):\n",
487
+ " \"\"\"This function prepares the datasets to be used in the brain class.\n",
488
+ " It also defines the data processing pipeline through user-defined functions.\n",
489
+ " \"\"\"\n",
490
+ "\n",
491
+ " # **************\n",
492
+ " # Load CSV files\n",
493
+ " # **************\n",
494
+ " data_folder = hparams[\"data_folder\"]\n",
495
+ "\n",
496
+ " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"train_csv\"],replacements={\"data_root\": data_folder})\n",
497
+ " # we sort training data to speed up training and get better results.\n",
498
+ " train_data = train_data.filtered_sorted(sort_key=\"duration\")\n",
499
+ " hparams[\"train_dataloader_opts\"][\"shuffle\"] = False # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
500
+ "\n",
501
+ " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"valid_csv\"],replacements={\"data_root\": data_folder})\n",
502
+ " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
503
+ "\n",
504
+ " # test is separate\n",
505
+ " test_datasets = {}\n",
506
+ " for csv_file in hparams[\"test_csv\"]:\n",
507
+ " name = Path(csv_file).stem\n",
508
+ " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=csv_file, replacements={\"data_root\": data_folder})\n",
509
+ " test_datasets[name] = test_datasets[name].filtered_sorted(sort_key=\"duration\")\n",
510
+ "\n",
511
+ " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
512
+ "\n",
513
+ " # *************************\n",
514
+ " # 2. Define audio pipeline:\n",
515
+ " # *************************\n",
516
+ " @sb.utils.data_pipeline.takes(\"wav\")\n",
517
+ " @sb.utils.data_pipeline.provides(\"sig\")\n",
518
+ " def audio_pipeline(wav):\n",
519
+ " sig = sb.dataio.dataio.read_audio(wav)\n",
520
+ " return sig\n",
521
+ "\n",
522
+ " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
523
+ "\n",
524
+ " # ************************\n",
525
+ " # 3. Define text pipeline:\n",
526
+ " # ************************\n",
527
+ " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
528
+ " \n",
529
+ " @sb.utils.data_pipeline.takes(\"wrd\")\n",
530
+ " @sb.utils.data_pipeline.provides(\"wrd\", \"char_list\", \"tokens_list\", \"tokens\")\n",
531
+ " def text_pipeline(wrd):\n",
532
+ " yield wrd\n",
533
+ " char_list = list(wrd)\n",
534
+ " yield char_list\n",
535
+ " tokens_list = label_encoder.encode_sequence(char_list)\n",
536
+ " yield tokens_list\n",
537
+ " tokens = torch.LongTensor(tokens_list)\n",
538
+ " yield tokens\n",
539
+ "\n",
540
+ " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
541
+ "\n",
542
+ "\n",
543
+ " # *******************************\n",
544
+ " # 4. Create or load label encoder\n",
545
+ " # *******************************\n",
546
+ " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
547
+ " special_labels = {\"blank_label\": hparams[\"blank_index\"]}\n",
548
+ " label_encoder.add_unk()\n",
549
+ " label_encoder.load_or_create(path=lab_enc_file, from_didatasets=[train_data], output_key=\"char_list\", special_labels=special_labels, sequence_input=True)\n",
550
+ "\n",
551
+ " # **************\n",
552
+ " # 5. Set output:\n",
553
+ " # **************\n",
554
+ " sb.dataio.dataset.set_output_keys(datasets,[\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],)\n",
555
+ "\n",
556
+ " return train_data, valid_data, test_datasets, label_encoder\n"
557
+ ]
558
+ },
559
+ {
560
+ "cell_type": "markdown",
561
+ "id": "e97c4f20-6951-4d12-8e17-9eb818a52bb1",
562
+ "metadata": {},
563
+ "source": [
564
+ "## 4. Utilisation de la recette Créée"
565
+ ]
566
+ },
567
+ {
568
+ "cell_type": "markdown",
569
+ "id": "76b72148-6bd0-48bd-ad40-cb6f8bfd34c0",
570
+ "metadata": {},
571
+ "source": [
572
+ "### 4.1 Préparation au lancement"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": null,
578
+ "id": "d47ec39a-5562-4a63-8243-656c9235b7a2",
579
+ "metadata": {},
580
+ "outputs": [],
581
+ "source": [
582
+ "hparams_file, run_opts, overrides = sb.parse_arguments([\"PATH_TO_YOUR_FOLDER/ASR_FLEURS-swahili_hf.yaml\"])\n",
583
+ "# create ddp_group with the right communication protocol\n",
584
+ "sb.utils.distributed.ddp_init_group(run_opts)\n",
585
+ "\n",
586
+ "# ***********************************\n",
587
+ "# Chargement du fichier de paramètres\n",
588
+ "# ***********************************\n",
589
+ "with open(hparams_file) as fin:\n",
590
+ " hparams = load_hyperpyyaml(fin, overrides)\n",
591
+ "\n",
592
+ "# ***************************\n",
593
+ "# Create experiment directory\n",
594
+ "# ***************************\n",
595
+ "sb.create_experiment_directory(experiment_directory=hparams[\"output_folder\"], hyperparams_to_save=hparams_file, overrides=overrides)\n",
596
+ "\n",
597
+ "# ***************************\n",
598
+ "# Create the datasets objects\n",
599
+ "# ***************************\n",
600
+ "train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)\n",
601
+ "\n",
602
+ "# **********************\n",
603
+ "# Trainer initialization\n",
604
+ "# **********************\n",
605
+ "asr_brain = MY_SSA_ASR(modules=hparams[\"modules\"], hparams=hparams, run_opts=run_opts, checkpointer=hparams[\"checkpointer\"])\n",
606
+ "asr_brain.tokenizer = label_encoder"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "markdown",
611
+ "id": "62ae72eb-416c-4ef0-9348-d02bbc268fbd",
612
+ "metadata": {},
613
+ "source": [
614
+ "### 4.2 Apprentissage du modèle"
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": null,
620
+ "id": "d3dd30ee-89c0-40ea-a9d2-0e2b9d8c8686",
621
+ "metadata": {},
622
+ "outputs": [],
623
+ "source": [
624
+ "# ********\n",
625
+ "# Training\n",
626
+ "# ********\n",
627
+ "asr_brain.fit(asr_brain.hparams.epoch_counter, \n",
628
+ " train_data, valid_data, \n",
629
+ " train_loader_kwargs=hparams[\"train_dataloader_opts\"], \n",
630
+ " valid_loader_kwargs=hparams[\"valid_dataloader_opts\"],\n",
631
+ " )\n",
632
+ "\n"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "markdown",
637
+ "id": "1b55af4c-c544-45ff-8435-58226218328f",
638
+ "metadata": {},
639
+ "source": [
640
+ "### 4.3 Test du Modèle"
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "execution_count": null,
646
+ "id": "9cef9011-1a3e-43a4-ab16-8cfb2b57dbd9",
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "# *******\n",
651
+ "# Testing\n",
652
+ "# *******\n",
653
+ "if not os.path.exists(hparams[\"output_wer_folder\"]):\n",
654
+ " os.makedirs(hparams[\"output_wer_folder\"])\n",
655
+ "\n",
656
+ "from speechbrain.decoders.ctc import CTCBeamSearcher\n",
657
+ "\n",
658
+ "ind2lab = label_encoder.ind2lab\n",
659
+ "vocab_list = [ind2lab[x] for x in range(len(ind2lab))]\n",
660
+ "test_searcher = CTCBeamSearcher(**hparams[\"test_beam_search\"], vocab_list=vocab_list)\n",
661
+ "\n",
662
+ "for k in test_datasets.keys(): # Allow multiple evaluation throught list of test sets\n",
663
+ " asr_brain.hparams.test_wer_file = os.path.join(hparams[\"output_wer_folder\"], f\"wer_{k}.txt\")\n",
664
+ " asr_brain.evaluate(test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"], min_key=\"WER\")\n"
665
+ ]
666
+ }
667
+ ],
668
+ "metadata": {
669
+ "kernelspec": {
670
+ "display_name": "Python 3 (ipykernel)",
671
+ "language": "python",
672
+ "name": "python3"
673
+ },
674
+ "language_info": {
675
+ "codemirror_mode": {
676
+ "name": "ipython",
677
+ "version": 3
678
+ },
679
+ "file_extension": ".py",
680
+ "mimetype": "text/x-python",
681
+ "name": "python",
682
+ "nbconvert_exporter": "python",
683
+ "pygments_lexer": "ipython3",
684
+ "version": "3.10.14"
685
+ }
686
+ },
687
+ "nbformat": 4,
688
+ "nbformat_minor": 5
689
+ }