lighteternal commited on
Commit
9a23d52
·
1 Parent(s): 0d2ad8d

fixed config.json

Browse files
.ipynb_checkpoints/ASR_Inference-checkpoint.ipynb DELETED
@@ -1,960 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {
7
- "ExecuteTime": {
8
- "end_time": "2021-03-17T11:10:25.794375Z",
9
- "start_time": "2021-03-17T11:10:24.301013Z"
10
- }
11
- },
12
- "outputs": [
13
- {
14
- "name": "stderr",
15
- "output_type": "stream",
16
- "text": [
17
- "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/torchaudio/backend/utils.py:53: UserWarning: \"sox\" backend is being deprecated. The default backend will be changed to \"sox_io\" backend in 0.8.0 and \"sox\" backend will be removed in 0.9.0. Please migrate to \"sox_io\" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.\n",
18
- " warnings.warn(\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "from transformers import Wav2Vec2ForCTC\n",
24
- "from transformers import Wav2Vec2Processor\n",
25
- "from datasets import load_dataset, load_metric\n",
26
- "import re\n",
27
- "import torchaudio\n",
28
- "import librosa\n",
29
- "import numpy as np\n",
30
- "from datasets import load_dataset, load_metric\n",
31
- "import torch"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": 2,
37
- "metadata": {
38
- "ExecuteTime": {
39
- "end_time": "2021-03-17T11:10:29.608803Z",
40
- "start_time": "2021-03-17T11:10:29.599700Z"
41
- }
42
- },
43
- "outputs": [],
44
- "source": [
45
- "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
46
- "\n",
47
- "def remove_special_characters(batch):\n",
48
- " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
49
- " return batch\n",
50
- "\n",
51
- "def speech_file_to_array_fn(batch):\n",
52
- " speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
53
- " batch[\"speech\"] = speech_array[0].numpy()\n",
54
- " batch[\"sampling_rate\"] = sampling_rate\n",
55
- " batch[\"target_text\"] = batch[\"text\"]\n",
56
- " return batch\n",
57
- "\n",
58
- "def resample(batch):\n",
59
- " batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n",
60
- " batch[\"sampling_rate\"] = 16_000\n",
61
- " return batch\n",
62
- "\n",
63
- "def prepare_dataset(batch):\n",
64
- " # check that all files have the correct sampling rate\n",
65
- " assert (\n",
66
- " len(set(batch[\"sampling_rate\"])) == 1\n",
67
- " ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
68
- "\n",
69
- " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
70
- " \n",
71
- " with processor.as_target_processor():\n",
72
- " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
73
- " return batch"
74
- ]
75
- },
76
- {
77
- "cell_type": "code",
78
- "execution_count": 4,
79
- "metadata": {
80
- "ExecuteTime": {
81
- "end_time": "2021-03-17T11:11:02.120225Z",
82
- "start_time": "2021-03-17T11:10:56.182488Z"
83
- }
84
- },
85
- "outputs": [
86
- {
87
- "name": "stderr",
88
- "output_type": "stream",
89
- "text": [
90
- "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n"
91
- ]
92
- }
93
- ],
94
- "source": [
95
- "model = Wav2Vec2ForCTC.from_pretrained(\".\").to(\"cuda\")\n",
96
- "processor = Wav2Vec2Processor.from_pretrained(\".\")"
97
- ]
98
- },
99
- {
100
- "cell_type": "code",
101
- "execution_count": 6,
102
- "metadata": {
103
- "ExecuteTime": {
104
- "end_time": "2021-03-17T11:12:18.847005Z",
105
- "start_time": "2021-03-17T11:12:14.919077Z"
106
- }
107
- },
108
- "outputs": [
109
- {
110
- "name": "stderr",
111
- "output_type": "stream",
112
- "text": [
113
- "Using custom data configuration el-afd0a157f05ee080\n"
114
- ]
115
- },
116
- {
117
- "name": "stdout",
118
- "output_type": "stream",
119
- "text": [
120
- "Downloading and preparing dataset common_voice/el (download: 363.89 MiB, generated: 4.75 MiB, post-processed: Unknown size, total: 368.64 MiB) to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...\n"
121
- ]
122
- },
123
- {
124
- "data": {
125
- "application/vnd.jupyter.widget-view+json": {
126
- "model_id": "",
127
- "version_major": 2,
128
- "version_minor": 0
129
- },
130
- "text/plain": [
131
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
132
- ]
133
- },
134
- "metadata": {},
135
- "output_type": "display_data"
136
- },
137
- {
138
- "name": "stdout",
139
- "output_type": "stream",
140
- "text": [
141
- "\r"
142
- ]
143
- },
144
- {
145
- "data": {
146
- "application/vnd.jupyter.widget-view+json": {
147
- "model_id": "",
148
- "version_major": 2,
149
- "version_minor": 0
150
- },
151
- "text/plain": [
152
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
153
- ]
154
- },
155
- "metadata": {},
156
- "output_type": "display_data"
157
- },
158
- {
159
- "name": "stdout",
160
- "output_type": "stream",
161
- "text": [
162
- "\r"
163
- ]
164
- },
165
- {
166
- "data": {
167
- "application/vnd.jupyter.widget-view+json": {
168
- "model_id": "",
169
- "version_major": 2,
170
- "version_minor": 0
171
- },
172
- "text/plain": [
173
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
174
- ]
175
- },
176
- "metadata": {},
177
- "output_type": "display_data"
178
- },
179
- {
180
- "name": "stdout",
181
- "output_type": "stream",
182
- "text": [
183
- "\r"
184
- ]
185
- },
186
- {
187
- "data": {
188
- "application/vnd.jupyter.widget-view+json": {
189
- "model_id": "",
190
- "version_major": 2,
191
- "version_minor": 0
192
- },
193
- "text/plain": [
194
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
195
- ]
196
- },
197
- "metadata": {},
198
- "output_type": "display_data"
199
- },
200
- {
201
- "name": "stdout",
202
- "output_type": "stream",
203
- "text": [
204
- "\r"
205
- ]
206
- },
207
- {
208
- "data": {
209
- "application/vnd.jupyter.widget-view+json": {
210
- "model_id": "",
211
- "version_major": 2,
212
- "version_minor": 0
213
- },
214
- "text/plain": [
215
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
216
- ]
217
- },
218
- "metadata": {},
219
- "output_type": "display_data"
220
- },
221
- {
222
- "name": "stdout",
223
- "output_type": "stream",
224
- "text": [
225
- "\r",
226
- "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-afd0a157f05ee080/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. Subsequent calls will reuse this data.\n"
227
- ]
228
- }
229
- ],
230
- "source": [
231
- "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")"
232
- ]
233
- },
234
- {
235
- "cell_type": "code",
236
- "execution_count": 7,
237
- "metadata": {
238
- "ExecuteTime": {
239
- "end_time": "2021-03-17T11:12:18.860240Z",
240
- "start_time": "2021-03-17T11:12:18.857252Z"
241
- }
242
- },
243
- "outputs": [],
244
- "source": [
245
- "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])"
246
- ]
247
- },
248
- {
249
- "cell_type": "code",
250
- "execution_count": 8,
251
- "metadata": {
252
- "ExecuteTime": {
253
- "end_time": "2021-03-17T11:12:18.928497Z",
254
- "start_time": "2021-03-17T11:12:18.869198Z"
255
- }
256
- },
257
- "outputs": [
258
- {
259
- "data": {
260
- "application/vnd.jupyter.widget-view+json": {
261
- "model_id": "9869698af86e44bca75c4252996ff1a3",
262
- "version_major": 2,
263
- "version_minor": 0
264
- },
265
- "text/plain": [
266
- "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))"
267
- ]
268
- },
269
- "metadata": {},
270
- "output_type": "display_data"
271
- },
272
- {
273
- "name": "stdout",
274
- "output_type": "stream",
275
- "text": [
276
- "\n"
277
- ]
278
- }
279
- ],
280
- "source": [
281
- "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])"
282
- ]
283
- },
284
- {
285
- "cell_type": "code",
286
- "execution_count": 9,
287
- "metadata": {
288
- "ExecuteTime": {
289
- "end_time": "2021-03-17T11:12:40.824595Z",
290
- "start_time": "2021-03-17T11:12:18.937930Z"
291
- }
292
- },
293
- "outputs": [
294
- {
295
- "data": {
296
- "application/vnd.jupyter.widget-view+json": {
297
- "model_id": "d232b2bb009543e0bb2542bce273c554",
298
- "version_major": 2,
299
- "version_minor": 0
300
- },
301
- "text/plain": [
302
- "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))"
303
- ]
304
- },
305
- "metadata": {},
306
- "output_type": "display_data"
307
- },
308
- {
309
- "name": "stdout",
310
- "output_type": "stream",
311
- "text": [
312
- "\n"
313
- ]
314
- }
315
- ],
316
- "source": [
317
- "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": 10,
323
- "metadata": {
324
- "ExecuteTime": {
325
- "end_time": "2021-03-17T11:13:18.078738Z",
326
- "start_time": "2021-03-17T11:12:40.834398Z"
327
- }
328
- },
329
- "outputs": [
330
- {
331
- "name": "stdout",
332
- "output_type": "stream",
333
- "text": [
334
- " "
335
- ]
336
- },
337
- {
338
- "data": {
339
- "application/vnd.jupyter.widget-view+json": {
340
- "model_id": "ffd787bc4ed048ae8f4977f2c539bedb",
341
- "version_major": 2,
342
- "version_minor": 0
343
- },
344
- "text/plain": [
345
- "HBox(children=(IntProgress(value=0, description='#0', max=191, style=ProgressStyle(description_width='initial'…"
346
- ]
347
- },
348
- "metadata": {},
349
- "output_type": "display_data"
350
- },
351
- {
352
- "data": {
353
- "application/vnd.jupyter.widget-view+json": {
354
- "model_id": "79c51995d4f84ad8812230480d14b8cd",
355
- "version_major": 2,
356
- "version_minor": 0
357
- },
358
- "text/plain": [
359
- "HBox(children=(IntProgress(value=0, description='#2', max=190, style=ProgressStyle(description_width='initial'…"
360
- ]
361
- },
362
- "metadata": {},
363
- "output_type": "display_data"
364
- },
365
- {
366
- "data": {
367
- "application/vnd.jupyter.widget-view+json": {
368
- "model_id": "52963d9cfd814346af070b2cc4e105cf",
369
- "version_major": 2,
370
- "version_minor": 0
371
- },
372
- "text/plain": [
373
- "HBox(children=(IntProgress(value=0, description='#5', max=190, style=ProgressStyle(description_width='initial'…"
374
- ]
375
- },
376
- "metadata": {},
377
- "output_type": "display_data"
378
- },
379
- {
380
- "data": {
381
- "application/vnd.jupyter.widget-view+json": {
382
- "model_id": "3b940160575143c7acfa142564e9f7d2",
383
- "version_major": 2,
384
- "version_minor": 0
385
- },
386
- "text/plain": [
387
- "HBox(children=(IntProgress(value=0, description='#3', max=190, style=ProgressStyle(description_width='initial'…"
388
- ]
389
- },
390
- "metadata": {},
391
- "output_type": "display_data"
392
- },
393
- {
394
- "data": {
395
- "application/vnd.jupyter.widget-view+json": {
396
- "model_id": "aa540f67ba894d7aa64e12fcdfab5ce0",
397
- "version_major": 2,
398
- "version_minor": 0
399
- },
400
- "text/plain": [
401
- "HBox(children=(IntProgress(value=0, description='#1', max=191, style=ProgressStyle(description_width='initial'…"
402
- ]
403
- },
404
- "metadata": {},
405
- "output_type": "display_data"
406
- },
407
- {
408
- "data": {
409
- "application/vnd.jupyter.widget-view+json": {
410
- "model_id": "4962bdefdbbc44a7a44591480d8d6406",
411
- "version_major": 2,
412
- "version_minor": 0
413
- },
414
- "text/plain": [
415
- "HBox(children=(IntProgress(value=0, description='#4', max=190, style=ProgressStyle(description_width='initial'…"
416
- ]
417
- },
418
- "metadata": {},
419
- "output_type": "display_data"
420
- },
421
- {
422
- "data": {
423
- "application/vnd.jupyter.widget-view+json": {
424
- "model_id": "e77f088bfe5644548fe2c4277d0c86da",
425
- "version_major": 2,
426
- "version_minor": 0
427
- },
428
- "text/plain": [
429
- "HBox(children=(IntProgress(value=0, description='#7', max=190, style=ProgressStyle(description_width='initial'…"
430
- ]
431
- },
432
- "metadata": {},
433
- "output_type": "display_data"
434
- },
435
- {
436
- "data": {
437
- "application/vnd.jupyter.widget-view+json": {
438
- "model_id": "5827f93e99994fe9919aac53f0fb9444",
439
- "version_major": 2,
440
- "version_minor": 0
441
- },
442
- "text/plain": [
443
- "HBox(children=(IntProgress(value=0, description='#6', max=190, style=ProgressStyle(description_width='initial'…"
444
- ]
445
- },
446
- "metadata": {},
447
- "output_type": "display_data"
448
- },
449
- {
450
- "name": "stdout",
451
- "output_type": "stream",
452
- "text": [
453
- "\n",
454
- "\n",
455
- "\n",
456
- "\n",
457
- "\n",
458
- "\n",
459
- "\n",
460
- "\n"
461
- ]
462
- }
463
- ],
464
- "source": [
465
- "common_voice_test = common_voice_test.map(resample, num_proc=8)"
466
- ]
467
- },
468
- {
469
- "cell_type": "code",
470
- "execution_count": 11,
471
- "metadata": {
472
- "ExecuteTime": {
473
- "end_time": "2021-03-17T11:13:25.145155Z",
474
- "start_time": "2021-03-17T11:13:18.091929Z"
475
- }
476
- },
477
- "outputs": [
478
- {
479
- "name": "stderr",
480
- "output_type": "stream",
481
- "text": [
482
- "/home/earendil/anaconda3/envs/cuda110/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
483
- " return array(a, dtype, copy=False, order=order)\n"
484
- ]
485
- },
486
- {
487
- "name": "stdout",
488
- "output_type": "stream",
489
- "text": [
490
- " "
491
- ]
492
- },
493
- {
494
- "data": {
495
- "application/vnd.jupyter.widget-view+json": {
496
- "model_id": "ae326a173a044b1494793e2a70d76a87",
497
- "version_major": 2,
498
- "version_minor": 0
499
- },
500
- "text/plain": [
501
- "HBox(children=(IntProgress(value=0, description='#0', max=24, style=ProgressStyle(description_width='initial')…"
502
- ]
503
- },
504
- "metadata": {},
505
- "output_type": "display_data"
506
- },
507
- {
508
- "data": {
509
- "application/vnd.jupyter.widget-view+json": {
510
- "model_id": "21ab1ef2af5a4a4fb23c68b0c5cf32f8",
511
- "version_major": 2,
512
- "version_minor": 0
513
- },
514
- "text/plain": [
515
- "HBox(children=(IntProgress(value=0, description='#1', max=24, style=ProgressStyle(description_width='initial')…"
516
- ]
517
- },
518
- "metadata": {},
519
- "output_type": "display_data"
520
- },
521
- {
522
- "data": {
523
- "application/vnd.jupyter.widget-view+json": {
524
- "model_id": "d331c5f4f888477daceffe370f6cd89f",
525
- "version_major": 2,
526
- "version_minor": 0
527
- },
528
- "text/plain": [
529
- "HBox(children=(IntProgress(value=0, description='#3', max=24, style=ProgressStyle(description_width='initial')…"
530
- ]
531
- },
532
- "metadata": {},
533
- "output_type": "display_data"
534
- },
535
- {
536
- "data": {
537
- "application/vnd.jupyter.widget-view+json": {
538
- "model_id": "6fa790118aa340e4afb9f83e71403a13",
539
- "version_major": 2,
540
- "version_minor": 0
541
- },
542
- "text/plain": [
543
- "HBox(children=(IntProgress(value=0, description='#2', max=24, style=ProgressStyle(description_width='initial')…"
544
- ]
545
- },
546
- "metadata": {},
547
- "output_type": "display_data"
548
- },
549
- {
550
- "data": {
551
- "application/vnd.jupyter.widget-view+json": {
552
- "model_id": "c8092e2f59a9404596dc2bab206edf2c",
553
- "version_major": 2,
554
- "version_minor": 0
555
- },
556
- "text/plain": [
557
- "HBox(children=(IntProgress(value=0, description='#5', max=24, style=ProgressStyle(description_width='initial')…"
558
- ]
559
- },
560
- "metadata": {},
561
- "output_type": "display_data"
562
- },
563
- {
564
- "data": {
565
- "application/vnd.jupyter.widget-view+json": {
566
- "model_id": "20f913f0caf8401098743b9e5051fc52",
567
- "version_major": 2,
568
- "version_minor": 0
569
- },
570
- "text/plain": [
571
- "HBox(children=(IntProgress(value=0, description='#4', max=24, style=ProgressStyle(description_width='initial')…"
572
- ]
573
- },
574
- "metadata": {},
575
- "output_type": "display_data"
576
- },
577
- {
578
- "data": {
579
- "application/vnd.jupyter.widget-view+json": {
580
- "model_id": "7c7e15e24384494cb49a72106ce41ccd",
581
- "version_major": 2,
582
- "version_minor": 0
583
- },
584
- "text/plain": [
585
- "HBox(children=(IntProgress(value=0, description='#6', max=24, style=ProgressStyle(description_width='initial')…"
586
- ]
587
- },
588
- "metadata": {},
589
- "output_type": "display_data"
590
- },
591
- {
592
- "data": {
593
- "application/vnd.jupyter.widget-view+json": {
594
- "model_id": "73245add55e24ee2a6dbe0713d5073d9",
595
- "version_major": 2,
596
- "version_minor": 0
597
- },
598
- "text/plain": [
599
- "HBox(children=(IntProgress(value=0, description='#7', max=24, style=ProgressStyle(description_width='initial')…"
600
- ]
601
- },
602
- "metadata": {},
603
- "output_type": "display_data"
604
- },
605
- {
606
- "name": "stdout",
607
- "output_type": "stream",
608
- "text": [
609
- "\n",
610
- "\n",
611
- "\n",
612
- "\n",
613
- "\n",
614
- "\n",
615
- "\n",
616
- "\n"
617
- ]
618
- }
619
- ],
620
- "source": [
621
- "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
622
- ]
623
- },
624
- {
625
- "cell_type": "code",
626
- "execution_count": 12,
627
- "metadata": {
628
- "ExecuteTime": {
629
- "end_time": "2021-03-17T11:14:12.721500Z",
630
- "start_time": "2021-03-17T11:14:08.198478Z"
631
- }
632
- },
633
- "outputs": [
634
- {
635
- "name": "stderr",
636
- "output_type": "stream",
637
- "text": [
638
- "Using custom data configuration el-ac779bf2c9f7c09b\n"
639
- ]
640
- },
641
- {
642
- "name": "stdout",
643
- "output_type": "stream",
644
- "text": [
645
- "Downloading and preparing dataset common_voice/el (download: 363.89 MiB, generated: 4.75 MiB, post-processed: Unknown size, total: 368.64 MiB) to /home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f...\n"
646
- ]
647
- },
648
- {
649
- "data": {
650
- "application/vnd.jupyter.widget-view+json": {
651
- "model_id": "",
652
- "version_major": 2,
653
- "version_minor": 0
654
- },
655
- "text/plain": [
656
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
657
- ]
658
- },
659
- "metadata": {},
660
- "output_type": "display_data"
661
- },
662
- {
663
- "name": "stdout",
664
- "output_type": "stream",
665
- "text": [
666
- "\r"
667
- ]
668
- },
669
- {
670
- "data": {
671
- "application/vnd.jupyter.widget-view+json": {
672
- "model_id": "",
673
- "version_major": 2,
674
- "version_minor": 0
675
- },
676
- "text/plain": [
677
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
678
- ]
679
- },
680
- "metadata": {},
681
- "output_type": "display_data"
682
- },
683
- {
684
- "name": "stdout",
685
- "output_type": "stream",
686
- "text": [
687
- "\r"
688
- ]
689
- },
690
- {
691
- "data": {
692
- "application/vnd.jupyter.widget-view+json": {
693
- "model_id": "",
694
- "version_major": 2,
695
- "version_minor": 0
696
- },
697
- "text/plain": [
698
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
699
- ]
700
- },
701
- "metadata": {},
702
- "output_type": "display_data"
703
- },
704
- {
705
- "name": "stdout",
706
- "output_type": "stream",
707
- "text": [
708
- "\r"
709
- ]
710
- },
711
- {
712
- "data": {
713
- "application/vnd.jupyter.widget-view+json": {
714
- "model_id": "",
715
- "version_major": 2,
716
- "version_minor": 0
717
- },
718
- "text/plain": [
719
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
720
- ]
721
- },
722
- "metadata": {},
723
- "output_type": "display_data"
724
- },
725
- {
726
- "name": "stdout",
727
- "output_type": "stream",
728
- "text": [
729
- "\r"
730
- ]
731
- },
732
- {
733
- "data": {
734
- "application/vnd.jupyter.widget-view+json": {
735
- "model_id": "",
736
- "version_major": 2,
737
- "version_minor": 0
738
- },
739
- "text/plain": [
740
- "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
741
- ]
742
- },
743
- "metadata": {},
744
- "output_type": "display_data"
745
- },
746
- {
747
- "name": "stdout",
748
- "output_type": "stream",
749
- "text": [
750
- "\r",
751
- "Dataset common_voice downloaded and prepared to /home/earendil/.cache/huggingface/datasets/common_voice/el-ac779bf2c9f7c09b/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f. Subsequent calls will reuse this data.\n"
752
- ]
753
- }
754
- ],
755
- "source": [
756
- "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
757
- ]
758
- },
759
- {
760
- "cell_type": "code",
761
- "execution_count": 11,
762
- "metadata": {
763
- "ExecuteTime": {
764
- "end_time": "2021-03-14T19:33:39.856174Z",
765
- "start_time": "2021-03-14T19:33:14.402825Z"
766
- }
767
- },
768
- "outputs": [],
769
- "source": [
770
- "# Change this value to try inference on different CommonVoice extracts\n",
771
- "example = 678\n",
772
- "\n",
773
- "input_dict = processor(common_voice_test[\"input_values\"][example], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
774
- "\n",
775
- "logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
776
- "\n",
777
- "pred_ids = torch.argmax(logits, dim=-1)"
778
- ]
779
- },
780
- {
781
- "cell_type": "code",
782
- "execution_count": 12,
783
- "metadata": {
784
- "ExecuteTime": {
785
- "end_time": "2021-03-14T19:33:39.887236Z",
786
- "start_time": "2021-03-14T19:33:39.881958Z"
787
- }
788
- },
789
- "outputs": [
790
- {
791
- "name": "stdout",
792
- "output_type": "stream",
793
- "text": [
794
- "Prediction:\n",
795
- "πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n",
796
- "\n",
797
- "Reference:\n",
798
- "πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς.\n"
799
- ]
800
- }
801
- ],
802
- "source": [
803
- "print(\"Prediction:\")\n",
804
- "print(processor.decode(pred_ids[0]))\n",
805
- "# πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n",
806
- "\n",
807
- "print(\"\\nReference:\")\n",
808
- "print(common_voice_test_transcription[\"sentence\"][example].lower())\n",
809
- "# πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς."
810
- ]
811
- },
812
- {
813
- "cell_type": "code",
814
- "execution_count": 13,
815
- "metadata": {
816
- "ExecuteTime": {
817
- "end_time": "2021-03-17T11:15:35.637739Z",
818
- "start_time": "2021-03-17T11:14:14.689842Z"
819
- }
820
- },
821
- "outputs": [
822
- {
823
- "data": {
824
- "application/vnd.jupyter.widget-view+json": {
825
- "model_id": "1f7ba9e12187401f870555d20a6a9458",
826
- "version_major": 2,
827
- "version_minor": 0
828
- },
829
- "text/plain": [
830
- "HBox(children=(IntProgress(value=0, max=1522), HTML(value='')))"
831
- ]
832
- },
833
- "metadata": {},
834
- "output_type": "display_data"
835
- },
836
- {
837
- "name": "stdout",
838
- "output_type": "stream",
839
- "text": [
840
- "\n"
841
- ]
842
- }
843
- ],
844
- "source": [
845
- "def map_to_result(batch):\n",
846
- " model.to(\"cuda\")\n",
847
- " input_values = processor(\n",
848
- " batch[\"input_values\"], \n",
849
- " sampling_rate=16_000, \n",
850
- " return_tensors=\"pt\"\n",
851
- " ).input_values.to(\"cuda\")\n",
852
- "\n",
853
- " with torch.no_grad():\n",
854
- " logits = model(input_values).logits\n",
855
- "\n",
856
- " pred_ids = torch.argmax(logits, dim=-1)\n",
857
- " batch[\"pred_str\"] = processor.batch_decode(pred_ids)[0]\n",
858
- "\n",
859
- " return batch\n",
860
- "\n",
861
- "results = common_voice_test.map(map_to_result)\n"
862
- ]
863
- },
864
- {
865
- "cell_type": "code",
866
- "execution_count": 16,
867
- "metadata": {
868
- "ExecuteTime": {
869
- "end_time": "2021-03-17T11:17:11.951524Z",
870
- "start_time": "2021-03-17T11:17:08.856552Z"
871
- }
872
- },
873
- "outputs": [
874
- {
875
- "name": "stdout",
876
- "output_type": "stream",
877
- "text": [
878
- "Test WER: 0.396\n"
879
- ]
880
- }
881
- ],
882
- "source": [
883
- "def compute_metrics(pred):\n",
884
- " pred_logits = pred.predictions\n",
885
- " pred_ids = np.argmax(pred_logits, axis=-1)\n",
886
- "\n",
887
- " pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
888
- "\n",
889
- " pred_str = processor.batch_decode(pred_ids)\n",
890
- " # we do not want to group tokens when computing the metrics\n",
891
- " label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
892
- "\n",
893
- " wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
894
- "\n",
895
- " return {\"wer\": wer}\n",
896
- "\n",
897
- "wer_metric = load_metric(\"wer\")\n",
898
- "\n",
899
- "print(\"Test WER: {:.3f}\".format(wer_metric.compute(predictions=results[\"pred_str\"], references= [item.lower() for item in common_voice_test_transcription['sentence']])))"
900
- ]
901
- },
902
- {
903
- "cell_type": "code",
904
- "execution_count": null,
905
- "metadata": {},
906
- "outputs": [],
907
- "source": []
908
- }
909
- ],
910
- "metadata": {
911
- "kernelspec": {
912
- "display_name": "cuda110",
913
- "language": "python",
914
- "name": "cuda110"
915
- },
916
- "language_info": {
917
- "codemirror_mode": {
918
- "name": "ipython",
919
- "version": 3
920
- },
921
- "file_extension": ".py",
922
- "mimetype": "text/x-python",
923
- "name": "python",
924
- "nbconvert_exporter": "python",
925
- "pygments_lexer": "ipython3",
926
- "version": "3.8.5"
927
- },
928
- "varInspector": {
929
- "cols": {
930
- "lenName": 16,
931
- "lenType": 16,
932
- "lenVar": 40
933
- },
934
- "kernels_config": {
935
- "python": {
936
- "delete_cmd_postfix": "",
937
- "delete_cmd_prefix": "del ",
938
- "library": "var_list.py",
939
- "varRefreshCmd": "print(var_dic_list())"
940
- },
941
- "r": {
942
- "delete_cmd_postfix": ") ",
943
- "delete_cmd_prefix": "rm(",
944
- "library": "var_list.r",
945
- "varRefreshCmd": "cat(var_dic_list()) "
946
- }
947
- },
948
- "types_to_exclude": [
949
- "module",
950
- "function",
951
- "builtin_function_or_method",
952
- "instance",
953
- "_Feature"
954
- ],
955
- "window_display": false
956
- }
957
- },
958
- "nbformat": 4,
959
- "nbformat_minor": 4
960
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.ipynb_checkpoints/Fine_Tune_XLSR_Wav2Vec2_on_Greek_ASR_with_🤗_Transformers-checkpoint.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -22,6 +22,9 @@ model-index:
22
  - name: Test WER
23
  type: wer
24
  value: 10.497628
 
 
 
25
  ---
26
 
27
  # Greek (el) version of the XLSR-Wav2Vec2 automatic speech recognition (ASR) model
@@ -204,6 +207,7 @@ Instructions and code to replicate the process are provided in the Fine_Tune_XLS
204
  | ----------- | ----------- |
205
  | Training Loss | 0.0545 |
206
  | Validation Loss | 0.1661 |
 
207
  | WER on CommonVoice Test (%) *| 10.4976 |
208
  * Reference transcripts were lower-cased and striped of punctuation and special characters.
209
 
 
22
  - name: Test WER
23
  type: wer
24
  value: 10.497628
25
+ - name: Test CER
26
+ type: cer
27
+ value: 2.875260
28
  ---
29
 
30
  # Greek (el) version of the XLSR-Wav2Vec2 automatic speech recognition (ASR) model
 
207
  | ----------- | ----------- |
208
  | Training Loss | 0.0545 |
209
  | Validation Loss | 0.1661 |
210
+ | CER on CommonVoice Test (%) *| 2.8753 |
211
  | WER on CommonVoice Test (%) *| 10.4976 |
212
  * Reference transcripts were lower-cased and striped of punctuation and special characters.
213
 
config.json CHANGED
@@ -36,7 +36,7 @@
36
  2
37
  ],
38
  "ctc_loss_reduction": "mean",
39
- "ctc_zero_infinity": false,
40
  "do_stable_layer_norm": true,
41
  "eos_token_id": 2,
42
  "feat_extract_activation": "gelu",
@@ -70,7 +70,7 @@
70
  "num_conv_pos_embeddings": 128,
71
  "num_feat_extract_layers": 7,
72
  "num_hidden_layers": 24,
73
- "pad_token_id": 52,
74
  "transformers_version": "4.4.0.dev0",
75
- "vocab_size": 53
76
  }
 
36
  2
37
  ],
38
  "ctc_loss_reduction": "mean",
39
+ "ctc_zero_infinity": true,
40
  "do_stable_layer_norm": true,
41
  "eos_token_id": 2,
42
  "feat_extract_activation": "gelu",
 
70
  "num_conv_pos_embeddings": 128,
71
  "num_feat_extract_layers": 7,
72
  "num_hidden_layers": 24,
73
+ "pad_token_id": 54,
74
  "transformers_version": "4.4.0.dev0",
75
+ "vocab_size": 55
76
  }