csukuangfj committed on
Commit
a97e72d
·
1 Parent(s): 588da9c

minor fixes.

Browse files
Files changed (3) hide show
  1. app.py +89 -36
  2. model.py +159 -22
  3. offline_asr.py +40 -32
app.py CHANGED
@@ -19,6 +19,7 @@
19
  # References:
20
  # https://gradio.app/docs/#dropdown
21
 
 
22
  import os
23
  import time
24
  from datetime import datetime
@@ -26,43 +27,43 @@ from datetime import datetime
26
  import gradio as gr
27
  import torchaudio
28
 
29
- from model import (
30
- get_gigaspeech_pre_trained_model,
31
- sample_rate,
32
- get_wenetspeech_pre_trained_model,
33
- )
34
 
35
- models = {
36
- "Chinese": get_wenetspeech_pre_trained_model(),
37
- "English": get_gigaspeech_pre_trained_model(),
38
- }
39
 
40
 
41
  def convert_to_wav(in_filename: str) -> str:
42
  """Convert the input audio file to a wave file"""
43
  out_filename = in_filename + ".wav"
44
- print(f"Converting '{in_filename}' to '{out_filename}'")
45
  _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' '{out_filename}'")
46
  return out_filename
47
 
48
 
49
- demo = gr.Blocks()
50
-
 
 
 
 
 
 
 
 
 
 
51
 
52
- def process(in_filename: str, language: str) -> str:
53
- print("in_filename", in_filename)
54
- print("language", language)
55
  filename = convert_to_wav(in_filename)
56
 
57
  now = datetime.now()
58
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
59
- print(f"Started at {date_time}")
60
 
61
  start = time.time()
62
  wave, wave_sample_rate = torchaudio.load(filename)
63
 
64
  if wave_sample_rate != sample_rate:
65
- print(
66
  f"Expected sample rate: {sample_rate}. Given: {wave_sample_rate}. "
67
  f"Resampling to {sample_rate}."
68
  )
@@ -74,7 +75,11 @@ def process(in_filename: str, language: str) -> str:
74
  )
75
  wave = wave[0] # use only the first channel.
76
 
77
- hyp = models[language].decode_waves([wave])[0]
 
 
 
 
78
 
79
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
80
  end = time.time()
@@ -82,11 +87,10 @@ def process(in_filename: str, language: str) -> str:
82
  duration = wave.shape[0] / sample_rate
83
  rtf = (end - start) / duration
84
 
85
- print(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s")
86
- print(f"Duration {duration: .3f} s")
87
- print(f"RTF {rtf: .3f}")
88
- print("hyp")
89
- print(hyp)
90
 
91
  return hyp
92
 
@@ -103,51 +107,100 @@ See more information by visiting the following links:
103
  - <https://github.com/lhotse-speech/lhotse>
104
  """
105
 
 
 
 
 
 
 
 
 
 
 
 
106
  with demo:
107
  gr.Markdown(title)
108
- gr.Markdown(description)
109
- language_choices = list(models.keys())
110
- language = gr.inputs.Radio(
111
  label="Language",
112
  choices=language_choices,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  )
114
 
115
  with gr.Tabs():
116
  with gr.TabItem("Upload from disk"):
117
- uploaded_file = gr.inputs.Audio(
118
  source="upload", # Choose between "microphone", "upload"
119
  type="filepath",
120
  optional=False,
121
  label="Upload from disk",
122
  )
123
  upload_button = gr.Button("Submit for recognition")
124
- uploaded_output = gr.outputs.Textbox(
125
- label="Recognized speech from uploaded file"
126
- )
127
 
128
  with gr.TabItem("Record from microphone"):
129
- microphone = gr.inputs.Audio(
130
  source="microphone", # Choose between "microphone", "upload"
131
  type="filepath",
132
  optional=False,
133
  label="Record from microphone",
134
  )
135
- recorded_output = gr.outputs.Textbox(
136
- label="Recognized speech from recordings"
137
- )
138
 
139
  record_button = gr.Button("Submit for recognition")
 
140
 
141
  upload_button.click(
142
  process,
143
- inputs=[uploaded_file, language],
 
 
 
 
 
 
144
  outputs=uploaded_output,
145
  )
146
  record_button.click(
147
  process,
148
- inputs=[microphone, language],
 
 
 
 
 
 
149
  outputs=recorded_output,
150
  )
 
151
 
152
  if __name__ == "__main__":
 
 
 
 
153
  demo.launch()
 
19
  # References:
20
  # https://gradio.app/docs/#dropdown
21
 
22
+ import logging
23
  import os
24
  import time
25
  from datetime import datetime
 
27
  import gradio as gr
28
  import torchaudio
29
 
30
+ from model import get_pretrained_model, language_to_models, sample_rate
 
 
 
 
31
 
32
+ languages = sorted(language_to_models.keys())
 
 
 
33
 
34
 
35
  def convert_to_wav(in_filename: str) -> str:
36
  """Convert the input audio file to a wave file"""
37
  out_filename = in_filename + ".wav"
38
+ logging.info(f"Converting '{in_filename}' to '{out_filename}'")
39
  _ = os.system(f"ffmpeg -hide_banner -i '{in_filename}' '{out_filename}'")
40
  return out_filename
41
 
42
 
43
+ def process(
44
+ in_filename: str,
45
+ language: str,
46
+ repo_id: str,
47
+ decoding_method: str,
48
+ num_active_paths: int,
49
+ ) -> str:
50
+ logging.info(f"in_filename: {in_filename}")
51
+ logging.info(f"language: {language}")
52
+ logging.info(f"repo_id: {repo_id}")
53
+ logging.info(f"decoding_method: {decoding_method}")
54
+ logging.info(f"num_active_paths: {num_active_paths}")
55
 
 
 
 
56
  filename = convert_to_wav(in_filename)
57
 
58
  now = datetime.now()
59
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
60
+ logging.info(f"Started at {date_time}")
61
 
62
  start = time.time()
63
  wave, wave_sample_rate = torchaudio.load(filename)
64
 
65
  if wave_sample_rate != sample_rate:
66
+ logging.info(
67
  f"Expected sample rate: {sample_rate}. Given: {wave_sample_rate}. "
68
  f"Resampling to {sample_rate}."
69
  )
 
75
  )
76
  wave = wave[0] # use only the first channel.
77
 
78
+ hyp = get_pretrained_model(repo_id).decode_waves(
79
+ [wave],
80
+ decoding_method=decoding_method,
81
+ num_active_paths=num_active_paths,
82
+ )[0]
83
 
84
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
85
  end = time.time()
 
87
  duration = wave.shape[0] / sample_rate
88
  rtf = (end - start) / duration
89
 
90
+ logging.info(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s")
91
+ logging.info(f"Duration {duration: .3f} s")
92
+ logging.info(f"RTF {rtf: .3f}")
93
+ logging.info(f"hyp:\n{hyp}")
 
94
 
95
  return hyp
96
 
 
107
  - <https://github.com/lhotse-speech/lhotse>
108
  """
109
 
110
+
111
+ def update_model_dropdown(language: str):
112
+ if language in language_to_models:
113
+ choices = language_to_models[language]
114
+ return gr.Dropdown.update(choices=choices, value=choices[0])
115
+
116
+ raise ValueError(f"Unsupported language: {language}")
117
+
118
+
119
+ demo = gr.Blocks()
120
+
121
  with demo:
122
  gr.Markdown(title)
123
+ language_choices = list(language_to_models.keys())
124
+
125
+ language_radio = gr.Radio(
126
  label="Language",
127
  choices=language_choices,
128
+ value=language_choices[0],
129
+ )
130
+ model_dropdown = gr.Dropdown(
131
+ choices=language_to_models[language_choices[0]],
132
+ label="Select a model",
133
+ value=language_to_models[language_choices[0]][0],
134
+ )
135
+
136
+ language_radio.change(
137
+ update_model_dropdown,
138
+ inputs=language_radio,
139
+ outputs=model_dropdown,
140
+ )
141
+
142
+ decoding_method_radio = gr.Radio(
143
+ label="Decoding method",
144
+ choices=["greedy_search", "modified_beam_search"],
145
+ value="greedy_search",
146
+ )
147
+
148
+ num_active_paths_slider = gr.Slider(
149
+ minimum=1,
150
+ value=4,
151
+ step=1,
152
+ label="Number of active paths for modified_beam_search",
153
  )
154
 
155
  with gr.Tabs():
156
  with gr.TabItem("Upload from disk"):
157
+ uploaded_file = gr.Audio(
158
  source="upload", # Choose between "microphone", "upload"
159
  type="filepath",
160
  optional=False,
161
  label="Upload from disk",
162
  )
163
  upload_button = gr.Button("Submit for recognition")
164
+ uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
 
 
165
 
166
  with gr.TabItem("Record from microphone"):
167
+ microphone = gr.Audio(
168
  source="microphone", # Choose between "microphone", "upload"
169
  type="filepath",
170
  optional=False,
171
  label="Record from microphone",
172
  )
 
 
 
173
 
174
  record_button = gr.Button("Submit for recognition")
175
+ recorded_output = gr.Textbox(label="Recognized speech from recordings")
176
 
177
  upload_button.click(
178
  process,
179
+ inputs=[
180
+ uploaded_file,
181
+ language_radio,
182
+ model_dropdown,
183
+ decoding_method_radio,
184
+ num_active_paths_slider,
185
+ ],
186
  outputs=uploaded_output,
187
  )
188
  record_button.click(
189
  process,
190
+ inputs=[
191
+ microphone,
192
+ language_radio,
193
+ model_dropdown,
194
+ decoding_method_radio,
195
+ num_active_paths_slider,
196
+ ],
197
  outputs=recorded_output,
198
  )
199
+ gr.Markdown(description)
200
 
201
  if __name__ == "__main__":
202
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
203
+
204
+ logging.basicConfig(format=formatter, level=logging.INFO)
205
+
206
  demo.launch()
model.py CHANGED
@@ -23,52 +23,189 @@ from offline_asr import OfflineAsr
23
  sample_rate = 16000
24
 
25
 
26
- @lru_cache(maxsize=1)
27
- def get_gigaspeech_pre_trained_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  nn_model_filename = hf_hub_download(
29
- # It is converted from https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2
30
- repo_id="csukuangfj/icefall-asr-gigaspeech-pruned-transducer-stateless2",
31
- filename="cpu_jit-epoch-29-avg-11-torch-1.10.0.pt",
32
- subfolder="exp",
33
  )
 
 
34
 
 
 
 
 
 
35
  bpe_model_filename = hf_hub_download(
36
- repo_id="wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
37
- filename="bpe.model",
38
- subfolder="data/lang_bpe_500",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  )
 
40
 
41
  return OfflineAsr(
42
  nn_model_filename=nn_model_filename,
43
  bpe_model_filename=bpe_model_filename,
44
  token_filename=None,
45
- decoding_method="greedy_search",
46
- num_active_paths=4,
47
  sample_rate=sample_rate,
48
  device="cpu",
49
  )
50
 
51
 
52
- @lru_cache(maxsize=1)
53
- def get_wenetspeech_pre_trained_model():
54
- nn_model_filename = hf_hub_download(
55
- repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
57
- subfolder="exp",
58
  )
 
59
 
60
- token_filename = hf_hub_download(
61
- repo_id="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
62
- filename="tokens.txt",
63
- subfolder="data/lang_char",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
 
65
 
66
  return OfflineAsr(
67
  nn_model_filename=nn_model_filename,
68
  bpe_model_filename=None,
69
  token_filename=token_filename,
70
- decoding_method="greedy_search",
71
- num_active_paths=4,
72
  sample_rate=sample_rate,
73
  device="cpu",
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  sample_rate = 16000
24
 
25
 
26
+ @lru_cache(maxsize=30)
27
+ def get_pretrained_model(repo_id: str) -> OfflineAsr:
28
+ if repo_id in chinese_models:
29
+ return chinese_models[repo_id](repo_id)
30
+ elif repo_id in english_models:
31
+ return english_models[repo_id](repo_id)
32
+ elif repo_id in chinese_english_mixed_models:
33
+ return chinese_english_mixed_models[repo_id](repo_id)
34
+ else:
35
+ raise ValueError(f"Unsupported repo_id: {repo_id}")
36
+
37
+
38
+ def _get_nn_model_filename(
39
+ repo_id: str,
40
+ filename: str,
41
+ subfolder: str = "exp",
42
+ ) -> str:
43
  nn_model_filename = hf_hub_download(
44
+ repo_id=repo_id,
45
+ filename=filename,
46
+ subfolder=subfolder,
 
47
  )
48
+ return nn_model_filename
49
+
50
 
51
+ def _get_bpe_model_filename(
52
+ repo_id: str,
53
+ filename: str = "bpe.model",
54
+ subfolder: str = "data/lang_bpe_500",
55
+ ) -> str:
56
  bpe_model_filename = hf_hub_download(
57
+ repo_id=repo_id,
58
+ filename=filename,
59
+ subfolder=subfolder,
60
+ )
61
+ return bpe_model_filename
62
+
63
+
64
+ def _get_token_filename(
65
+ repo_id: str,
66
+ filename: str = "tokens.txt",
67
+ subfolder: str = "data/lang_char",
68
+ ) -> str:
69
+ token_filename = hf_hub_download(
70
+ repo_id=repo_id,
71
+ filename=filename,
72
+ subfolder=subfolder,
73
+ )
74
+ return token_filename
75
+
76
+
77
+ @lru_cache(maxsize=10)
78
+ def _get_aishell2_pretrained_model(repo_id: str) -> OfflineAsr:
79
+ assert repo_id in [
80
+ # context-size 1
81
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12", # noqa
82
+ # context-size 2
83
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12", # noqa
84
+ ]
85
+
86
+ nn_model_filename = _get_nn_model_filename(
87
+ repo_id=repo_id,
88
+ filename="cpu_jit.pt",
89
+ )
90
+ token_filename = _get_token_filename(repo_id=repo_id)
91
+
92
+ return OfflineAsr(
93
+ nn_model_filename=nn_model_filename,
94
+ bpe_model_filename=None,
95
+ token_filename=token_filename,
96
+ sample_rate=sample_rate,
97
+ device="cpu",
98
+ )
99
+
100
+
101
+ @lru_cache(maxsize=10)
102
+ def _get_gigaspeech_pre_trained_model(repo_id: str) -> OfflineAsr:
103
+ assert repo_id in [
104
+ "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2",
105
+ ]
106
+
107
+ nn_model_filename = _get_nn_model_filename(
108
+ # It is converted from https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2 # noqa
109
+ repo_id="csukuangfj/icefall-asr-gigaspeech-pruned-transducer-stateless2", # noqa
110
+ filename="cpu_jit-epoch-29-avg-11-torch-1.10.0.pt",
111
  )
112
+ bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
113
 
114
  return OfflineAsr(
115
  nn_model_filename=nn_model_filename,
116
  bpe_model_filename=bpe_model_filename,
117
  token_filename=None,
 
 
118
  sample_rate=sample_rate,
119
  device="cpu",
120
  )
121
 
122
 
123
+ @lru_cache(maxsize=10)
124
+ def _get_librispeech_pre_trained_model(repo_id: str) -> OfflineAsr:
125
+ assert repo_id in [
126
+ "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13", # noqa
127
+ ]
128
+
129
+ nn_model_filename = _get_nn_model_filename(
130
+ repo_id=repo_id,
131
+ filename="cpu_jit.pt",
132
+ )
133
+ bpe_model_filename = _get_bpe_model_filename(repo_id=repo_id)
134
+
135
+ return OfflineAsr(
136
+ nn_model_filename=nn_model_filename,
137
+ bpe_model_filename=bpe_model_filename,
138
+ token_filename=None,
139
+ sample_rate=sample_rate,
140
+ device="cpu",
141
+ )
142
+
143
+
144
+ @lru_cache(maxsize=10)
145
+ def _get_wenetspeech_pre_trained_model(repo_id: str):
146
+ assert repo_id in [
147
+ "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
148
+ ]
149
+
150
+ nn_model_filename = _get_nn_model_filename(
151
+ repo_id=repo_id,
152
  filename="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
 
153
  )
154
+ token_filename = _get_token_filename(repo_id=repo_id)
155
 
156
+ return OfflineAsr(
157
+ nn_model_filename=nn_model_filename,
158
+ bpe_model_filename=None,
159
+ token_filename=token_filename,
160
+ sample_rate=sample_rate,
161
+ device="cpu",
162
+ )
163
+
164
+
165
+ @lru_cache(maxsize=10)
166
+ def _get_tal_csasr_pre_trained_model(repo_id: str):
167
+ assert repo_id in [
168
+ "luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5",
169
+ ]
170
+
171
+ nn_model_filename = _get_nn_model_filename(
172
+ repo_id=repo_id,
173
+ filename="cpu_jit.pt",
174
  )
175
+ token_filename = _get_token_filename(repo_id=repo_id)
176
 
177
  return OfflineAsr(
178
  nn_model_filename=nn_model_filename,
179
  bpe_model_filename=None,
180
  token_filename=token_filename,
 
 
181
  sample_rate=sample_rate,
182
  device="cpu",
183
  )
184
+
185
+
186
+ chinese_models = {
187
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12": _get_aishell2_pretrained_model, # noqa
188
+ "yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-B-2022-07-12": _get_aishell2_pretrained_model, # noqa
189
+ "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2": _get_wenetspeech_pre_trained_model, # noqa
190
+ }
191
+
192
+ english_models = {
193
+ "wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2": _get_gigaspeech_pre_trained_model, # noqa
194
+ "csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13": _get_librispeech_pre_trained_model, # noqa
195
+ }
196
+
197
+ chinese_english_mixed_models = {
198
+ "luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5": _get_tal_csasr_pre_trained_model, # noqa
199
+ }
200
+
201
+ all_models = {
202
+ **chinese_models,
203
+ **english_models,
204
+ **chinese_english_mixed_models,
205
+ }
206
+
207
+ language_to_models = {
208
+ "Chinese": sorted(chinese_models.keys()),
209
+ "English": sorted(english_models.keys()),
210
+ "Chinese+English": sorted(chinese_english_mixed_models.keys()),
211
+ }
offline_asr.py CHANGED
@@ -206,10 +206,10 @@ class OfflineAsr(object):
206
  def __init__(
207
  self,
208
  nn_model_filename: str,
209
- bpe_model_filename: Optional[str],
210
- token_filename: Optional[str],
211
- decoding_method: str,
212
- num_active_paths: int,
213
  sample_rate: int = 16000,
214
  device: Union[str, torch.device] = "cpu",
215
  ):
@@ -223,14 +223,6 @@ class OfflineAsr(object):
223
  token_filename:
224
  Path to tokens.txt. If it is None, you have to provide
225
  `bpe_model_filename`.
226
- decoding_method:
227
- The decoding method to use. Currently, only greedy_search and
228
- modified_beam_search are implemented.
229
- num_active_paths:
230
- Used only when decoding_method is modified_beam_search.
231
- It specifies number of active paths for each utterance. Due to
232
- merging paths with identical token sequences, the actual number
233
- may be less than "num_active_paths".
234
  sample_rate:
235
  Expected sample rate of the feature extractor.
236
  device:
@@ -246,6 +238,7 @@ class OfflineAsr(object):
246
  self.sp = spm.SentencePieceProcessor()
247
  self.sp.load(bpe_model_filename)
248
  else:
 
249
  self.token_table = k2.SymbolTable.from_file(token_filename)
250
 
251
  self.feature_extractor = self._build_feature_extractor(
@@ -253,24 +246,6 @@ class OfflineAsr(object):
253
  device=device,
254
  )
255
 
256
- assert decoding_method in (
257
- "greedy_search",
258
- "modified_beam_search",
259
- ), decoding_method
260
- if decoding_method == "greedy_search":
261
- nn_and_decoding_func = run_model_and_do_greedy_search
262
- elif decoding_method == "modified_beam_search":
263
- nn_and_decoding_func = functools.partial(
264
- run_model_and_do_modified_beam_search,
265
- num_active_paths=num_active_paths,
266
- )
267
- else:
268
- raise ValueError(
269
- f"Unsupported decoding_method: {decoding_method} "
270
- "Please use greedy_search or modified_beam_search"
271
- )
272
-
273
- self.nn_and_decoding_func = nn_and_decoding_func
274
  self.device = device
275
 
276
  def _build_feature_extractor(
@@ -299,7 +274,12 @@ class OfflineAsr(object):
299
 
300
  return fbank
301
 
302
- def decode_waves(self, waves: List[torch.Tensor]) -> List[List[str]]:
 
 
 
 
 
303
  """
304
  Args:
305
  waves:
@@ -313,20 +293,48 @@ class OfflineAsr(object):
313
  then the given waves have to contain samples in this range.
314
 
315
  All models trained in icefall use the normalized range [-1, 1].
 
 
 
 
 
 
 
 
316
  Returns:
317
  Return a list of decoded results. `ans[i]` contains the decoded
318
  results for `wavs[i]`.
319
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  waves = [w.to(self.device) for w in waves]
321
  features = self.feature_extractor(waves)
322
 
323
- tokens = self.nn_and_decoding_func(self.model, features)
324
 
325
  if hasattr(self, "sp"):
326
  results = self.sp.decode(tokens)
327
  else:
328
  results = [[self.token_table[i] for i in hyp] for hyp in tokens]
 
329
  results = ["".join(r) for r in results]
 
330
 
331
  return results
332
 
 
206
  def __init__(
207
  self,
208
  nn_model_filename: str,
209
+ bpe_model_filename: Optional[str] = None,
210
+ token_filename: Optional[str] = None,
211
+ decoding_method: str = "greedy_search",
212
+ num_active_paths: int = 4,
213
  sample_rate: int = 16000,
214
  device: Union[str, torch.device] = "cpu",
215
  ):
 
223
  token_filename:
224
  Path to tokens.txt. If it is None, you have to provide
225
  `bpe_model_filename`.
 
 
 
 
 
 
 
 
226
  sample_rate:
227
  Expected sample rate of the feature extractor.
228
  device:
 
238
  self.sp = spm.SentencePieceProcessor()
239
  self.sp.load(bpe_model_filename)
240
  else:
241
+ assert token_filename is not None, token_filename
242
  self.token_table = k2.SymbolTable.from_file(token_filename)
243
 
244
  self.feature_extractor = self._build_feature_extractor(
 
246
  device=device,
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  self.device = device
250
 
251
  def _build_feature_extractor(
 
274
 
275
  return fbank
276
 
277
+ def decode_waves(
278
+ self,
279
+ waves: List[torch.Tensor],
280
+ decoding_method: str,
281
+ num_active_paths: int,
282
+ ) -> List[List[str]]:
283
  """
284
  Args:
285
  waves:
 
293
  then the given waves have to contain samples in this range.
294
 
295
  All models trained in icefall use the normalized range [-1, 1].
296
+ decoding_method:
297
+ The decoding method to use. Currently, only greedy_search and
298
+ modified_beam_search are implemented.
299
+ num_active_paths:
300
+ Used only when decoding_method is modified_beam_search.
301
+ It specifies number of active paths for each utterance. Due to
302
+ merging paths with identical token sequences, the actual number
303
+ may be less than "num_active_paths".
304
  Returns:
305
  Return a list of decoded results. `ans[i]` contains the decoded
306
  results for `wavs[i]`.
307
  """
308
+ assert decoding_method in (
309
+ "greedy_search",
310
+ "modified_beam_search",
311
+ ), decoding_method
312
+
313
+ if decoding_method == "greedy_search":
314
+ nn_and_decoding_func = run_model_and_do_greedy_search
315
+ elif decoding_method == "modified_beam_search":
316
+ nn_and_decoding_func = functools.partial(
317
+ run_model_and_do_modified_beam_search,
318
+ num_active_paths=num_active_paths,
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ f"Unsupported decoding_method: {decoding_method} "
323
+ "Please use greedy_search or modified_beam_search"
324
+ )
325
+
326
  waves = [w.to(self.device) for w in waves]
327
  features = self.feature_extractor(waves)
328
 
329
+ tokens = nn_and_decoding_func(self.model, features)
330
 
331
  if hasattr(self, "sp"):
332
  results = self.sp.decode(tokens)
333
  else:
334
  results = [[self.token_table[i] for i in hyp] for hyp in tokens]
335
+ blank = chr(0x2581)
336
  results = ["".join(r) for r in results]
337
+ results = [r.replace(blank, " ") for r in results]
338
 
339
  return results
340