OpenSound committed
Commit de943de · 1 Parent(s): 54f231c

Update app.py

Files changed (1): app.py (+45 −110)

app.py CHANGED
@@ -8,9 +8,7 @@ from data.tokenizer import (
     AudioTokenizer,
     TextTokenizer,
 )
-from edit_utils_zh import parse_edit_zh
 from edit_utils_en import parse_edit_en
-from edit_utils_zh import parse_tts_zh
 from edit_utils_en import parse_tts_en
 from inference_scale import inference_one_sample
 import librosa
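Both remaining parsers come from edit_utils_en; the Mandarin counterparts (parse_edit_zh, parse_tts_zh) are removed along with every zh code path below. A minimal usage sketch with hypothetical transcripts, the signature taken from the Edit-mode call later in this diff:

from edit_utils_en import parse_edit_en

# Hypothetical inputs; judging by the names at the call site, parse_edit_en
# compares the two transcripts and returns the edit operations plus the
# source spans they touch.
orig = "hello world how are you"
target = "hello there how are you"
operations, orig_spans = parse_edit_en(orig, target)
print(operations, orig_spans)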
@@ -29,7 +27,6 @@ DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
 TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
 MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
 device = "cuda" if torch.cuda.is_available() else "cpu"
-transcribe_model, align_model, ssrspeech_model = None, None, None
 
 def get_random_string():
     return "".join(str(uuid.uuid4()).split("-"))
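The module-level None placeholders go away because the models are now constructed at import time (next hunk). The surrounding context also shows that every path is environment-configurable, so a deployment can relocate the model cache without touching the code, e.g.:

import os

# Set before app.py is imported; os.getenv falls back to the defaults shown above.
os.environ["MODELS_PATH"] = "/data/pretrained_models"  # hypothetical location
os.environ["TMP_PATH"] = "/data/tmp"                   # hypothetical location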
@@ -124,56 +121,36 @@ class WhisperxModel:
             segment['text'] = replace_numbers_with_words(segment['text'])
         return self.align_model.align(segments, audio_path)
 
-@spaces.GPU
-def load_models(ssrspeech_model_name):
-    global transcribe_model, align_model, ssrspeech_model
-
-    alignment_model_name = "whisperX"
-    whisper_backend_name = "whisperX"
-    if ssrspeech_model_name == "English":
-        ssrspeech_model_name = "English"
-        text_tokenizer = TextTokenizer(backend="espeak")
-        language = "en"
-        transcribe_model_name = "base.en"
-
-    elif ssrspeech_model_name == "Mandarin":
-        ssrspeech_model_name = "Mandarin"
-        text_tokenizer = TextTokenizer(backend="espeak", language='cmn')
-        language = "zh"
-        transcribe_model_name = "base"
-
-    align_model = WhisperxAlignModel(language)
-    transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)
-
-    ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
-    if not os.path.exists(ssrspeech_fn):
-        os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)
-    print(transcribe_model, align_model)
-    ckpt = torch.load(ssrspeech_fn)
-    model = ssr.SSR_Speech(ckpt["config"])
-    model.load_state_dict(ckpt["model"])
-    config = model.args
-    phn2num = ckpt["phn2num"]
-    model.to(device)
-
-    encodec_fn = f"{MODELS_PATH}/wmencodec.th"
-    if not os.path.exists(encodec_fn):
-        os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
-
-    ssrspeech_model = {
-        "config": config,
-        "phn2num": phn2num,
-        "model": model,
-        "text_tokenizer": text_tokenizer,
-        "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
-    }
-    success_message = "<span style='color:green;'>Success: Models loading completed successfully!</span>"
-
-    return [
-        gr.Accordion(),
-        success_message
-    ]
-
+ssrspeech_model_name = "English"
+text_tokenizer = TextTokenizer(backend="espeak")
+language = "en"
+transcribe_model_name = "base.en"
+
+align_model = WhisperxAlignModel(language)
+transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)
+
+ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
+if not os.path.exists(ssrspeech_fn):
+    os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)
+
+ckpt = torch.load(ssrspeech_fn)
+model = ssr.SSR_Speech(ckpt["config"])
+model.load_state_dict(ckpt["model"])
+config = model.args
+phn2num = ckpt["phn2num"]
+model.to(device)
+
+encodec_fn = f"{MODELS_PATH}/wmencodec.th"
+if not os.path.exists(encodec_fn):
+    os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
+
+ssrspeech_model = {
+    "config": config,
+    "phn2num": phn2num,
+    "model": model,
+    "text_tokenizer": text_tokenizer,
+    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
+}
 
 def get_transcribe_state(segments):
     transcript = " ".join([segment["text"] for segment in segments])
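Note that both checkpoint downloads shell out to wget, which assumes the binary is present on the host. A hypothetical alternative using huggingface_hub (not what this commit ships) would reuse the local HF cache instead:

from huggingface_hub import hf_hub_download

# repo_id and filename are taken from the wget URL in the hunk above;
# the returned path points into the local Hugging Face cache.
ckpt_path = hf_hub_download(
    repo_id="westbrook/SSR-Speech-English",
    filename="English.pth",
)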
@@ -185,8 +162,6 @@ def get_transcribe_state(segments):
 
 @spaces.GPU
 def transcribe(audio_path):
-    global transcribe_model
-
     if transcribe_model is None:
         raise gr.Error("Transcription model not loaded")
 
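With the models built at import time, the global declaration is redundant (module-level names are readable without it) and the None guard can no longer fire; the same holds for align() in the next hunk. For comparison, a lazy-initialization pattern that the removed load_models button approximated, sketched with this file's names (illustrative only):

_transcribe_model = None

def get_transcribe_model():
    # Build the model on first use instead of at import.
    global _transcribe_model
    if _transcribe_model is None:
        _transcribe_model = WhisperxModel("base.en", align_model, "en")
    return _transcribe_model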
@@ -202,7 +177,6 @@ def transcribe(audio_path):
 
 @spaces.GPU
 def align(segments, audio_path):
-    global align_model
     if align_model is None:
         raise gr.Error("Align model not loaded")
 
@@ -230,21 +204,15 @@ def replace_numbers_with_words(sentence):
     return re.sub(r'\b\d+\b', replace_with_words, sentence)  # Regular expression that matches numbers
 
 @spaces.GPU
-def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_k, top_p, temperature,
+def run(seed, sub_amount, codec_audio_sr, codec_sr, top_k, top_p, temperature,
         stop_repetition, kvcache, silence_tokens, aug_text, cfg_coef, prompt_length,
         audio_path, original_transcript, transcript, mode):
 
-    global transcribe_model, align_model, ssrspeech_model
     aug_text = True if aug_text == 1 else False
     if ssrspeech_model is None:
         raise gr.Error("ssrspeech model not loaded")
 
     seed_everything(seed)
-
-    if ssrspeech_model_choice == "English":
-        language = "en"
-    elif ssrspeech_model_choice == "Mandarin":
-        language = "zh"
 
     # resample audio
     audio, _ = librosa.load(audio_path, sr=16000)
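The context line at the top of this hunk shows the digit-matching regex but not the replace_with_words callback it relies on. A plausible stand-in backed by the num2words package (an assumption; the repo's actual helper is not visible in this diff):

import re
from num2words import num2words  # assumed dependency, not shown in this diff

def replace_numbers_with_words(sentence):
    # \b\d+\b matches standalone runs of digits, mirroring the re.sub above.
    return re.sub(r'\b\d+\b', lambda m: num2words(int(m.group())), sentence)

print(replace_numbers_with_words("track 42 of 100"))  # track forty-two of one hundred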
@@ -255,15 +223,9 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
 
     [orig_transcript, segments, _] = transcribe(audio_path)
-    if language == 'zh':
-        converter = opencc.OpenCC('t2s')
-        orig_transcript = converter.convert(orig_transcript)
-        transcribe_state = align(traditional_to_simplified(segments), audio_path)
-        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-    elif language == 'en':
-        orig_transcript = orig_transcript.lower()
-        target_transcript = target_transcript.lower()
-        transcribe_state = align(segments, audio_path)
+    orig_transcript = orig_transcript.lower()
+    target_transcript = target_transcript.lower()
+    transcribe_state = align(segments, audio_path)
     print(orig_transcript)
     print(target_transcript)
 
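The removed Mandarin branch normalized Whisper's Traditional-Chinese output to Simplified before alignment. For reference, the opencc API it used (values here are illustrative):

import opencc

converter = opencc.OpenCC('t2s')  # Traditional -> Simplified, as in the removed lines
print(converter.convert('漢語'))   # -> '汉语'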
@@ -284,26 +246,18 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
         sf.write(audio_path, audio, 16000)
         [orig_transcript, segments, _] = transcribe(audio_path)
 
-        if language == 'zh':
-            converter = opencc.OpenCC('t2s')
-            orig_transcript = converter.convert(orig_transcript)
-            transcribe_state = align(traditional_to_simplified(segments), audio_path)
-            transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-        elif language == 'en':
-            orig_transcript = orig_transcript.lower()
-            target_transcript = target_transcript.lower()
-            transcribe_state = align(segments, audio_path)
+
+        orig_transcript = orig_transcript.lower()
+        target_transcript = target_transcript.lower()
+        transcribe_state = align(segments, audio_path)
         print(orig_transcript)
         target_transcript_copy = target_transcript  # for tts cut out
-        if language == 'en':
-            target_transcript_copy = target_transcript_copy.split(' ')[0]
-        elif language == 'zh':
-            target_transcript_copy = target_transcript_copy[0]
-        target_transcript = orig_transcript + ' ' + target_transcript if language == 'en' else orig_transcript + target_transcript
+        target_transcript_copy = target_transcript_copy.split(' ')[0]
+        target_transcript = orig_transcript + ' ' + target_transcript
         print(target_transcript)
 
     if mode == "Edit":
-        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript) if language == 'en' else parse_edit_zh(orig_transcript, target_transcript)
+        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
         print(operations)
         print("orig_spans: ", orig_spans)
 
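In TTS mode the code keeps only the first word of the target text (target_transcript_copy) as a marker for the later cut-out, then prepends the reference transcript to form the full generation text. In miniature, with hypothetical strings:

orig_transcript = "the quick brown fox"        # hypothetical reference transcript
target_transcript = "jumps over the lazy dog"  # hypothetical target text

first_word = target_transcript.split(' ')[0]            # "jumps", the cut marker
full_text = orig_transcript + ' ' + target_transcript   # text handed to the model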
@@ -371,15 +325,9 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     torchaudio.save(audio_path, new_audio, codec_audio_sr)
     if tts:  # remove the start parts
         [new_transcript, new_segments, _] = transcribe(audio_path)
-        if language == 'zh':
-            transcribe_state = align(traditional_to_simplified(new_segments), audio_path)
-            transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-            tmp1 = transcribe_state['segments'][0]['words'][0]['word']
-            tmp2 = target_transcript_copy
-        elif language == 'en':
-            transcribe_state = align(new_segments, audio_path)
-            tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
-            tmp2 = target_transcript_copy.lower()
+        transcribe_state = align(new_segments, audio_path)
+        tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
+        tmp2 = target_transcript_copy.lower()
         if tmp1 == tmp2:
             offset = transcribe_state['segments'][0]['words'][0]['start']
         else:
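The cut-out re-aligns the generated audio, checks that its first word matches the saved marker, and uses that word's start time as the trim offset. A minimal sketch of the trimming step (names and handling are illustrative, not this file's exact code):

import torchaudio

def trim_prompt(audio_path, offset_seconds):
    # Drop everything before the first generated word.
    wav, sr = torchaudio.load(audio_path)
    return wav[:, int(offset_seconds * sr):], sr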
@@ -406,15 +354,6 @@ demo_text = {
 
 def get_app():
     with gr.Blocks() as app:
-        with gr.Row():
-            with gr.Column(scale=2):
-                load_models_btn = gr.Button(value="Load models")
-            with gr.Column(scale=5):
-                with gr.Accordion("Select models", open=False) as models_selector:
-                    with gr.Row():
-                        ssrspeech_model_choice = gr.Radio(label="ssrspeech model", value="English",
-                                                          choices=["English", "Mandarin"])
-
         with gr.Row():
             with gr.Column(scale=2):
                 input_audio = gr.Audio(value=f"{DEMO_PATH}/5895_34622_000026_000002.wav", label="Input Audio", type="filepath", interactive=True)
@@ -458,10 +397,6 @@ def get_app():
 
         success_output = gr.HTML()
 
-        load_models_btn.click(fn=load_models,
-                              inputs=[ssrspeech_model_choice],
-                              outputs=[models_selector, success_output])
-
         semgents = gr.State()  # not used
         transcribe_btn.click(fn=transcribe,
                              inputs=[input_audio],
@@ -469,7 +404,7 @@ def get_app():
 
         run_btn.click(fn=run,
                       inputs=[
-                          seed, sub_amount, ssrspeech_model_choice,
+                          seed, sub_amount,
                           codec_audio_sr, codec_sr,
                           top_k, top_p, temperature, stop_repetition, kvcache, silence_tokens,
                           aug_text, cfg_coef, prompt_length,
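Gradio passes the components listed in inputs to fn positionally, so this list must stay in lockstep with run()'s parameter order; dropping ssrspeech_model_choice here mirrors the parameter removed from run() earlier in the diff. A self-contained illustration of that contract (hypothetical components, not this app's):

import gradio as gr

def add(a, b):
    # a and b arrive in the order the inputs list declares them.
    return a + b

with gr.Blocks() as demo:
    x = gr.Number(value=1)
    y = gr.Number(value=2)
    out = gr.Number()
    gr.Button("go").click(fn=add, inputs=[x, y], outputs=[out])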