OpenSound committed
Commit 915b86a · 1 Parent(s): 554cf55

Update app.py

Files changed (1): app.py (+22 -16)
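The change removes whisper_model, align_model, and ssrspeech_model from the Gradio event wiring: load_models() now populates module-level globals (transcribe_model, align_model, ssrspeech_model), and transcribe(), align(), and run() read those globals instead of receiving the models as arguments.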
app.py CHANGED
@@ -126,7 +126,10 @@ class WhisperxModel:
 
 @spaces.GPU
 def load_models(ssrspeech_model_name):
-
+    global transcribe_model, align_model, ssrspeech_model
+
+    alignment_model_name = "whisperX"
+    whisper_backend_name = "whisperX"
     if ssrspeech_model_name == "English":
         ssrspeech_model_name = "English"
         text_tokenizer = TextTokenizer(backend="espeak")
@@ -168,7 +171,6 @@ def load_models(ssrspeech_model_name):
 
     return [
         gr.Accordion(),
-        whisper_model, align_model, ssrspeech_model,
         success_message
     ]
 
@@ -182,7 +184,9 @@ def get_transcribe_state(segments):
     }
 
 @spaces.GPU
-def transcribe(audio_path, transcribe_model):
+def transcribe(audio_path):
+    global transcribe_model
+
     if transcribe_model is None:
         raise gr.Error("Transcription model not loaded")
 
@@ -197,7 +201,8 @@ def transcribe(audio_path, transcribe_model):
 
 
 @spaces.GPU
-def align(segments, audio_path, align_model):
+def align(segments, audio_path):
+    global align_model
     if align_model is None:
         raise gr.Error("Align model not loaded")
 
@@ -227,8 +232,9 @@ def replace_numbers_with_words(sentence):
 @spaces.GPU
 def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_k, top_p, temperature,
         stop_repetition, kvcache, silence_tokens, aug_text, cfg_coef, prompt_length,
-        audio_path, original_transcript, transcript, mode, whisper_model, align_model, ssrspeech_model):
+        audio_path, original_transcript, transcript, mode):
 
+    global transcribe_model, align_model, ssrspeech_model
     aug_text = True if aug_text == 1 else False
     if ssrspeech_model is None:
         raise gr.Error("ssrspeech model not loaded")
@@ -248,16 +254,16 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
     orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
 
-    [orig_transcript, segments, _] = transcribe(audio_path, whisper_model)
+    [orig_transcript, segments, _] = transcribe(audio_path)
     if language == 'zh':
         converter = opencc.OpenCC('t2s')
         orig_transcript = converter.convert(orig_transcript)
-        transcribe_state = align(traditional_to_simplified(segments), audio_path, align_model)
+        transcribe_state = align(traditional_to_simplified(segments), audio_path)
         transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
     elif language == 'en':
         orig_transcript = orig_transcript.lower()
         target_transcript = target_transcript.lower()
-        transcribe_state = align(segments, audio_path, align_model)
+        transcribe_state = align(segments, audio_path)
     print(orig_transcript)
     print(target_transcript)
 
@@ -276,17 +282,17 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
 
     audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
     sf.write(audio_path, audio, 16000)
-    [orig_transcript, segments, _] = transcribe(audio_path, whisper_model)
+    [orig_transcript, segments, _] = transcribe(audio_path)
 
     if language == 'zh':
         converter = opencc.OpenCC('t2s')
         orig_transcript = converter.convert(orig_transcript)
-        transcribe_state = align(traditional_to_simplified(segments), audio_path, align_model)
+        transcribe_state = align(traditional_to_simplified(segments), audio_path)
         transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
     elif language == 'en':
         orig_transcript = orig_transcript.lower()
         target_transcript = target_transcript.lower()
-        transcribe_state = align(segments, audio_path, align_model)
+        transcribe_state = align(segments, audio_path)
     print(orig_transcript)
     target_transcript_copy = target_transcript # for tts cut out
     if language == 'en':
@@ -364,14 +370,14 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     new_audio = new_audio[0].cpu()
     torchaudio.save(audio_path, new_audio, codec_audio_sr)
     if tts: # remove the start parts
-        [new_transcript, new_segments, _] = transcribe(audio_path, whisper_model)
+        [new_transcript, new_segments, _] = transcribe(audio_path)
         if language == 'zh':
-            transcribe_state = align(traditional_to_simplified(new_segments), audio_path, align_model)
+            transcribe_state = align(traditional_to_simplified(new_segments), audio_path)
             transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
             tmp1 = transcribe_state['segments'][0]['words'][0]['word']
             tmp2 = target_transcript_copy
         elif language == 'en':
-            transcribe_state = align(new_segments, audio_path, align_model)
+            transcribe_state = align(new_segments, audio_path)
             tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
             tmp2 = target_transcript_copy.lower()
         if tmp1 == tmp2:
@@ -454,7 +460,7 @@ def get_app():
 
     load_models_btn.click(fn=load_models,
                           inputs=[ssrspeech_model_choice],
-                          outputs=[models_selector, whisper_model, align_model, ssrspeech_model, success_output])
+                          outputs=[models_selector, success_output])
 
     semgents = gr.State() # not used
     transcribe_btn.click(fn=transcribe,
@@ -468,7 +474,7 @@ def get_app():
                    top_k, top_p, temperature, stop_repetition, kvcache, silence_tokens,
                    aug_text, cfg_coef, prompt_length,
                    input_audio, original_transcript, transcript,
-                   mode, whisper_model, align_model, ssrspeech_model
+                   mode
                ],
                outputs=[output_audio, success_output])
 
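For reference, below is a minimal sketch of the pattern this commit adopts: the models live in module-level globals that the loader mutates and the handlers read, so the Gradio event wiring only carries plain UI values. This is an illustrative stub, not the Space's actual code; the component names (choice, status, load_btn) and the object() placeholders are hypothetical, and the real load_models() builds WhisperX and SSR-Speech models instead.

import gradio as gr

# Module-level slots for the heavy models; None until load_models() runs.
transcribe_model, align_model, ssrspeech_model = None, None, None

def load_models(model_choice):
    # Mutate the globals instead of returning model objects to Gradio.
    global transcribe_model, align_model, ssrspeech_model
    transcribe_model, align_model, ssrspeech_model = object(), object(), object()  # stubs
    return f"{model_choice} models loaded"

def transcribe(audio_path):
    # Read the global; no model object travels through inputs=[...].
    global transcribe_model
    if transcribe_model is None:
        raise gr.Error("Transcription model not loaded")
    return f"(transcript of {audio_path})"

with gr.Blocks() as demo:
    choice = gr.Dropdown(["English", "Mandarin"], value="English")
    status = gr.Textbox(label="Status")
    load_btn = gr.Button("Load models")
    # Only UI values cross the event boundary; models never appear in outputs.
    load_btn.click(fn=load_models, inputs=[choice], outputs=[status])

A likely motivation, though the commit message does not say so: on ZeroGPU Spaces, functions decorated with @spaces.GPU execute in a separate GPU worker, so threading live model objects through Gradio inputs/outputs is fragile there, while module-level globals set by load_models() are visible to every handler in the process.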