dragonSwing committed
Commit d326cd2 · Parent: 5b9a0b4

Remove LM model in inference

Files changed (1): app.py (+49 −29)
app.py CHANGED
@@ -5,24 +5,28 @@ from pyctcdecode import build_ctcdecoder
 from speechbrain.pretrained import EncoderASR
 from transformers.file_utils import cached_path, hf_bucket_url
 
-cache_dir = './cache/'
-lm_file = hf_bucket_url(
-    "dragonSwing/wav2vec2-base-vn-270h", filename='4gram.zip')
-lm_file = cached_path(lm_file, cache_dir=cache_dir)
-with zipfile.ZipFile(lm_file, 'r') as zip_ref:
-    zip_ref.extractall(cache_dir)
-lm_file = cache_dir + 'lm.binary'
-vocab_file = cache_dir + 'vocab-260000.txt'
-model = EncoderASR.from_hparams(source="dragonSwing/wav2vec2-base-vn-270h",
-                                savedir="./pretrained/wav2vec-vi-asr"
-                                )
+
+def download_lm(cache_dir="./cache"):
+    cache_dir = "./cache/"
+    lm_file = hf_bucket_url("dragonSwing/wav2vec2-base-vn-270h", filename="4gram.zip")
+    lm_file = cached_path(lm_file, cache_dir=cache_dir)
+    with zipfile.ZipFile(lm_file, "r") as zip_ref:
+        zip_ref.extractall(cache_dir)
+    lm_file = cache_dir + "lm.binary"
+    vocab_file = cache_dir + "vocab-260000.txt"
+    return lm_file, vocab_file
+
+
+model = EncoderASR.from_hparams(
+    source="dragonSwing/wav2vec2-base-vn-270h", savedir="./pretrained/wav2vec-vi-asr"
+)
 
 
 def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
     unigrams = None
     if vocab_path is not None:
         unigrams = []
-        with open(vocab_path, encoding='utf-8') as f:
+        with open(vocab_path, encoding="utf-8") as f:
             for line in f:
                 unigrams.append(line.strip())
 
@@ -40,21 +44,28 @@ def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
     return decoder
 
 
-ngram_lm_model = get_decoder_ngram_model(model.tokenizer, lm_file, vocab_file)
+# ngram_lm_model = get_decoder_ngram_model(model.tokenizer, lm_file, vocab_file)
 
 
-def transcribe_file(path, max_seconds=20):
+def transcribe_file(path, max_seconds=20, lm_model=None):
     waveform = model.load_audio(path)
     if max_seconds > 0:
-        waveform = waveform[:max_seconds*16000]
+        waveform = waveform[: max_seconds * 16000]
     batch = waveform.unsqueeze(0)
     rel_length = torch.tensor([1.0])
-    with torch.no_grad():
-        logits = model(batch, rel_length)
-    text_batch = [ngram_lm_model.decode(
-        logit.detach().cpu().numpy(), beam_width=500) for logit in logits]
-    return text_batch[0]
-
+    if lm_model:
+        with torch.no_grad():
+            logits = model(batch, rel_length)
+        text_batch = [
+            lm_model.decode(logit.detach().cpu().numpy(), beam_width=500)
+            for logit in logits
+        ]
+        return text_batch[0]
+    else:
+        text_batch, _ = model.transcribe_batch(
+            batch, rel_length
+        )
+        return text_batch[0]
 
 def speech_recognize(file_upload, file_mic):
     if file_upload is not None:
@@ -68,17 +79,26 @@ def speech_recognize(file_upload, file_mic):
     return text
 
 
-inputs = [gr.inputs.Audio(source="upload", type='filepath', optional=True), gr.inputs.Audio(
-    source="microphone", type='filepath', optional=True)]
+inputs = [
+    gr.inputs.Audio(source="upload", type="filepath", optional=True),
+    gr.inputs.Audio(source="microphone", type="filepath", optional=True),
+]
 outputs = gr.outputs.Textbox(label="Output Text")
 title = "wav2vec2-base-vietnamese-270h"
 description = "Gradio demo for a wav2vec2 base vietnamese speech recognition. To use it, simply upload your audio, click one of the examples to load them, or record from your own microphone. Read more at the links below. Currently supports 16_000hz audio files"
 article = "<p style='text-align: center'><a href='https://huggingface.co/dragonSwing/wav2vec2-base-vn-270h' target='_blank'>Pretrained model</a></p>"
 examples = [
-    ['example1.wav', 'example1.wav'],
-    ['example2.mp3', 'example2.mp3'],
-    ['example3.mp3', 'example3.mp3'],
-    ['example4.wav', 'example4.wav'],
+    ["example1.wav", "example1.wav"],
+    ["example2.mp3", "example2.mp3"],
+    ["example3.mp3", "example3.mp3"],
+    ["example4.wav", "example4.wav"],
 ]
-gr.Interface(speech_recognize, inputs, outputs, title=title,
-             description=description, article=article, examples=examples).launch()
+gr.Interface(
+    speech_recognize,
+    inputs,
+    outputs,
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+).launch()
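
In effect, the commit keeps the language-model plumbing but stops paying for it at startup: the KenLM download and the pyctcdecode decoder are no longer built eagerly, and `transcribe_file` now defaults to SpeechBrain's built-in greedy CTC decoding via `model.transcribe_batch`, using the n-gram beam search only when a decoder is passed in. A minimal sketch of how a caller could still opt back into LM decoding, assuming the definitions from `app.py` are in scope (importing `app.py` directly would also launch the Gradio interface, since `.launch()` runs at module level):

```python
# Sketch only: opting back into LM beam-search decoding with the helpers
# this commit keeps (but no longer calls) in app.py.

# Fetch the KenLM binary and unigram vocabulary; this no longer runs at startup.
lm_file, vocab_file = download_lm()

# Build the pyctcdecode beam-search decoder over the model's CTC vocabulary.
ngram_lm_model = get_decoder_ngram_model(model.tokenizer, lm_file, vocab_file)

# New default path: greedy CTC via SpeechBrain's model.transcribe_batch.
print(transcribe_file("example1.wav"))

# Optional path: LM-rescored beam search, as before this commit.
print(transcribe_file("example1.wav", lm_model=ngram_lm_model))
```

One wrinkle worth noting: the new `download_lm` immediately reassigns its `cache_dir` parameter to `"./cache/"`, so the argument is effectively ignored and the default cache location is fixed.

The diff elides the middle of `get_decoder_ngram_model` (old lines 29-39), where the decoder is actually constructed. For orientation only, a typical body for such a function using pyctcdecode's `build_ctcdecoder` looks like the following; this is an illustrative reconstruction, not the file's actual code, and the label-cleanup details may differ:

```python
from pyctcdecode import build_ctcdecoder


def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
    # Illustrative only: build a beam-search decoder from a wav2vec2
    # tokenizer, a KenLM binary, and an optional unigram vocabulary.
    unigrams = None
    if vocab_path is not None:
        unigrams = []
        with open(vocab_path, encoding="utf-8") as f:
            for line in f:
                unigrams.append(line.strip())
    # Order the CTC labels by token id so column i of the logits maps to labels[i].
    vocab_dict = tokenizer.get_vocab()
    labels = [token for token, _ in sorted(vocab_dict.items(), key=lambda kv: kv[1])]
    # build_ctcdecoder takes the label list, a KenLM model path, and optional unigrams.
    decoder = build_ctcdecoder(labels, kenlm_model_path=ngram_lm_path, unigrams=unigrams)
    return decoder
```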