ttop324 commited on
Commit
af74539
·
1 Parent(s): 9e975ee

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +106 -44
README.md CHANGED
@@ -1,50 +1,112 @@
1
  ---
 
 
 
 
 
2
  tags:
3
- - generated_from_trainer
 
 
 
 
4
  model-index:
5
  - name: wav2vec2-live-japanese
6
- results: []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ---
8
-
9
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
10
- should probably proofread and complete it, then remove this comment. -->
11
-
12
  # wav2vec2-live-japanese
13
-
14
- This model was trained from scratch on the None dataset.
15
-
16
- ## Model description
17
-
18
- More information needed
19
-
20
- ## Intended uses & limitations
21
-
22
- More information needed
23
-
24
- ## Training and evaluation data
25
-
26
- More information needed
27
-
28
- ## Training procedure
29
-
30
- ### Training hyperparameters
31
-
32
- The following hyperparameters were used during training:
33
- - learning_rate: 0.0003
34
- - train_batch_size: 3
35
- - eval_batch_size: 2
36
- - seed: 42
37
- - gradient_accumulation_steps: 2
38
- - total_train_batch_size: 6
39
- - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
40
- - lr_scheduler_type: linear
41
- - lr_scheduler_warmup_steps: 500
42
- - num_epochs: 50
43
- - mixed_precision_training: Native AMP
44
-
45
- ### Framework versions
46
-
47
- - Transformers 4.11.2
48
- - Pytorch 1.9.1
49
- - Datasets 1.11.0
50
- - Tokenizers 0.10.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: ja
3
+ datasets:
4
+ - common_voice
5
+ metrics:
6
+ - wer
7
  tags:
8
+ - audio
9
+ - automatic-speech-recognition
10
+ - speech
11
+ - xlsr-fine-tuning-week
12
+ license: apache-2.0
13
  model-index:
14
  - name: wav2vec2-live-japanese
15
+ results:
16
+ - task:
17
+ name: Speech Recognition
18
+ type: automatic-speech-recognition
19
+ dataset:
20
+ name: Common Voice Japanese
21
+ type: common_voice
22
+ args: ja
23
+ metrics:
24
+ - name: Test WER
25
+ type: wer
26
+ value: 22.08%
27
+ - name: Test CER
28
+ type: cer
29
+ value: 10.08%
30
  ---
 
 
 
 
31
  # wav2vec2-live-japanese
32
+ https://github.com/ttop32/wav2vec2-live-japanese-translator
33
+ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Japanese hiragana using the
34
+ - [common_voice](https://huggingface.co/datasets/common_voice)
35
+ - [JSUT](https://sites.google.com/site/shinnosuketakamichi/publication/jsut)
36
+ - [CSS10](https://github.com/Kyubyong/css10)
37
+ - [TEDxJP-10K](https://github.com/laboroai/TEDxJP-10K)
38
+ - [JVS](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus)
39
+ ## Inference
40
+ ```python
41
+ #usage
42
+ import torch
43
+ import torchaudio
44
+ from datasets import load_dataset
45
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
46
+ model = Wav2Vec2ForCTC.from_pretrained("ttop324/wav2vec2-live-japanese")
47
+ processor = Wav2Vec2Processor.from_pretrained("ttop324/wav2vec2-live-japanese")
48
+ test_dataset = load_dataset("common_voice", "ja", split="test")
49
+ # Preprocessing the datasets.
50
+ # We need to read the aduio files as arrays
51
+ def speech_file_to_array_fn(batch):
52
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
53
+ batch["speech"] = torchaudio.functional.resample(speech_array, sampling_rate, 16000)[0].numpy()
54
+ return batch
55
+ test_dataset = test_dataset.map(speech_file_to_array_fn)
56
+ inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
57
+ with torch.no_grad():
58
+ logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
59
+ predicted_ids = torch.argmax(logits, dim=-1)
60
+ print("Prediction:", processor.batch_decode(predicted_ids))
61
+ print("Reference:", test_dataset[:2]["sentence"])
62
+ ```
63
+ ## Evaluation
64
+ ```python
65
+ import torch
66
+ import torchaudio
67
+ from datasets import load_dataset, load_metric
68
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
69
+ import re
70
+ import pykakasi
71
+ import MeCab
72
+ wer = load_metric("wer")
73
+ cer = load_metric("cer")
74
+ model = Wav2Vec2ForCTC.from_pretrained("ttop324/wav2vec2-live-japanese").to("cuda")
75
+ processor = Wav2Vec2Processor.from_pretrained("ttop324/wav2vec2-live-japanese")
76
+ test_dataset = load_dataset("common_voice", "ja", split="test")
77
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\‘\”\�‘、。.!,・―─~「」『』\\\\※\[\]\{\}「」〇?…]'
78
+ wakati = MeCab.Tagger("-Owakati")
79
+ kakasi = pykakasi.kakasi()
80
+ kakasi.setMode("J","H") # kanji to hiragana
81
+ kakasi.setMode("K","H") # katakana to hiragana
82
+ conv = kakasi.getConverter()
83
+ FULLWIDTH_TO_HALFWIDTH = str.maketrans(
84
+ ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!゛#$%&()*+、ー。/:;〈=〉?@[]^_‘{|}~',
85
+ ' 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~',
86
+ )
87
+ def fullwidth_to_halfwidth(s):
88
+ return s.translate(FULLWIDTH_TO_HALFWIDTH)
89
+ def preprocessData(batch):
90
+ batch["sentence"] = fullwidth_to_halfwidth(batch["sentence"])
91
+ batch["sentence"] = re.sub(chars_to_ignore_regex,' ', batch["sentence"]).lower() #remove special char
92
+ batch["sentence"] = wakati.parse(batch["sentence"]) #add space
93
+ batch["sentence"] = conv.do(batch["sentence"]) #covert to hiragana
94
+ batch["sentence"] = " ".join(batch["sentence"].split())+" " #remove multiple space
95
+
96
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
97
+ batch["speech"] = torchaudio.functional.resample(speech_array, sampling_rate, 16000)[0].numpy()
98
+ return batch
99
+ test_dataset = test_dataset.map(preprocessData)
100
+ # Preprocessing the datasets.
101
+ # We need to read the aduio files as arrays
102
+ def evaluate(batch):
103
+ inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
104
+ with torch.no_grad():
105
+ logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
106
+ pred_ids = torch.argmax(logits, dim=-1)
107
+ batch["pred_strings"] = processor.batch_decode(pred_ids)
108
+ return batch
109
+ result = test_dataset.map(evaluate, batched=True, batch_size=8)
110
+ print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
111
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"])))
112
+ ```