OpenSound committed on
Commit 4622e54 · 1 Parent(s): 1bba970

Update app.py

Files changed (1):
  1. app.py +674 -204
app.py CHANGED
@@ -10,6 +10,8 @@ from data.tokenizer import (
  )
  from edit_utils_en import parse_edit_en
  from edit_utils_en import parse_tts_en
  from inference_scale import inference_one_sample
  import librosa
  import soundfile as sf
@@ -70,37 +72,57 @@ def get_mask_interval(transcribe_state, word_span):

  return (start, end)

- from whisperx import load_align_model, load_model, load_audio
- from whisperx import align as align_func

- ssrspeech_model_name = "English"
- text_tokenizer = TextTokenizer(backend="espeak")
- language = "en"
- transcribe_model_name = "base.en"

- ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
- if not os.path.exists(ssrspeech_fn):
-     os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)

- ckpt = torch.load(ssrspeech_fn)
- model = ssr.SSR_Speech(ckpt["config"])
- model.load_state_dict(ckpt["model"])
- config = model.args
- phn2num = ckpt["phn2num"]
- model.to(device)

  encodec_fn = f"{MODELS_PATH}/wmencodec.th"
- if not os.path.exists(encodec_fn):
-     os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
-
- ssrspeech_model = {
-     "config": config,
-     "phn2num": phn2num,
-     "model": model,
-     "text_tokenizer": text_tokenizer,
      "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
  }

  def get_transcribe_state(segments):
@@ -112,7 +134,9 @@ def get_transcribe_state(segments):
      }

  @spaces.GPU
- def transcribe(audio_path):
      transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
      segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
      for segment in segments:
@@ -126,9 +150,37 @@ def transcribe(audio_path):
          state, success_message
      ]

  @spaces.GPU
- def align(segments, audio_path):
      align_model, metadata = load_align_model(language_code=language, device=device)
      audio = load_audio(audio_path)
      segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
@@ -155,14 +207,18 @@ def replace_numbers_with_words(sentence):
      return re.sub(r'\b\d+\b', replace_with_words, sentence)  # Regular expression that matches numbers

  @spaces.GPU
- def run(seed, sub_amount, codec_audio_sr, codec_sr, top_k, top_p, temperature,
-         stop_repetition, kvcache, silence_tokens, aug_text, cfg_coef, prompt_length,
-         audio_path, original_transcript, transcript, mode):

-     aug_text = True if aug_text == 1 else False
-     if ssrspeech_model is None:
-         raise gr.Error("ssrspeech model not loaded")

      seed_everything(seed)

      # resample audio
@@ -173,118 +229,269 @@ def run(seed, sub_amount, codec_audio_sr, codec_sr, top_k, top_p, temperature,
173
  target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
174
  orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
175
 
176
- [orig_transcript, segments, _, _] = transcribe(audio_path)
177
  orig_transcript = orig_transcript.lower()
178
  target_transcript = target_transcript.lower()
179
- transcribe_state,_ = align(segments, audio_path)
180
  print(orig_transcript)
181
  print(target_transcript)
182
 
183
- if mode == "TTS":
184
- info = torchaudio.info(audio_path)
185
- duration = info.num_frames / info.sample_rate
186
- cut_length = duration
187
- # Cut long audio for tts
188
- if duration > prompt_length:
189
- seg_num = len(transcribe_state['segments'])
190
- for i in range(seg_num):
191
- words = transcribe_state['segments'][i]['words']
192
- for item in words:
193
- if item['end'] >= prompt_length:
194
- cut_length = min(item['end'], cut_length)
195
-
196
- audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
197
- sf.write(audio_path, audio, 16000)
198
- [orig_transcript, segments, _, _] = transcribe(audio_path)
199
-
200
-
201
- orig_transcript = orig_transcript.lower()
202
- target_transcript = target_transcript.lower()
203
- transcribe_state,_ = align(segments, audio_path)
204
- print(orig_transcript)
205
- target_transcript_copy = target_transcript # for tts cut out
206
- target_transcript_copy = target_transcript_copy.split(' ')[0]
207
- target_transcript = orig_transcript + ' ' + target_transcript
208
- print(target_transcript)
209
-
210
- if mode == "Edit":
211
- operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
212
- print(operations)
213
- print("orig_spans: ", orig_spans)
214
-
215
- if len(orig_spans) > 3:
216
- raise gr.Error("Current model only supports maximum 3 editings")
217
-
218
- starting_intervals = []
219
- ending_intervals = []
220
- for orig_span in orig_spans:
221
- start, end = get_mask_interval(transcribe_state, orig_span)
222
- starting_intervals.append(start)
223
- ending_intervals.append(end)
224
-
225
- print("intervals: ", starting_intervals, ending_intervals)
226
-
227
- info = torchaudio.info(audio_path)
228
- audio_dur = info.num_frames / info.sample_rate
229
-
230
- def combine_spans(spans, threshold=0.2):
231
- spans.sort(key=lambda x: x[0])
232
- combined_spans = []
233
- current_span = spans[0]
234
-
235
- for i in range(1, len(spans)):
236
- next_span = spans[i]
237
- if current_span[1] >= next_span[0] - threshold:
238
- current_span[1] = max(current_span[1], next_span[1])
239
- else:
240
- combined_spans.append(current_span)
241
- current_span = next_span
242
- combined_spans.append(current_span)
243
- return combined_spans
244
 
245
- morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
246
- for start, end in zip(starting_intervals, ending_intervals)] # in seconds
247
- morphed_span = combine_spans(morphed_span, threshold=0.2)
248
- print("morphed_spans: ", morphed_span)
249
- mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
250
- mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
      else:
-         info = torchaudio.info(audio_path)
-         audio_dur = info.num_frames / info.sample_rate

-         morphed_span = [(audio_dur, audio_dur)] # in seconds
-         mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
-         mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
-         print("mask_interval: ", mask_interval)
      decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}

-     tts = True if mode == "TTS" else False
      new_audio = inference_one_sample(
-         ssrspeech_model["model"],
-         ssrspeech_model["config"],
-         ssrspeech_model["phn2num"],
-         ssrspeech_model["text_tokenizer"],
-         ssrspeech_model["audio_tokenizer"],
          audio_path, orig_transcript, target_transcript, mask_interval,
-         cfg_coef, aug_text, False, True, tts,
          device, decode_config
      )
      audio_tensors = []
      # save segments for comparison
      new_audio = new_audio[0].cpu()
      torchaudio.save(audio_path, new_audio, codec_audio_sr)
-     if tts: # remove the start parts
-         [new_transcript, new_segments, _, _] = transcribe(audio_path)
-         transcribe_state,_ = align(new_segments, audio_path)
-         tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
-         tmp2 = target_transcript_copy.lower()
-         if tmp1 == tmp2:
-             offset = transcribe_state['segments'][0]['words'][0]['start']
-         else:
-             offset = transcribe_state['segments'][0]['words'][1]['start']
-
-         new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
      audio_tensors.append(new_audio)
      output_audio = get_output_audio(audio_tensors, codec_audio_sr)
@@ -292,88 +499,112 @@ def run(seed, sub_amount, codec_audio_sr, codec_sr, top_k, top_p, temperature,
      return output_audio, success_message


- demo_original_transcript = "Gwynplain had besides for his work and for his feats of strength, round his neck and over his shoulders, an esclavine of leather."

- demo_text = {
-     "TTS": {
-         "regular": "Gwynplain had besides for his work and for his feats of strength, I cannot believe that the same model can also do text to speech synthesis too!"
-     },
-     "Edit": {
-         "regular": "Gwynplain had besides for his work and feats of strength, hanging from his neck and shoulders, an esclavine of leather."
-     },
- }


- def get_app():
-     with gr.Blocks() as app:
-         gr.Markdown("""
-         # SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer
-         Generate and edit speech from text. Adjust advanced settings for more control.
-
-         Learn more about 🟣**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
-
-         🚀 The **SSR-Speech (Mandarin)** demo is now live! Try it on [🤗SSR-Speech-Mandarin Space](https://huggingface.co/spaces/OpenSound/SSR-Speech-Mandarin).
-         """)
-         with gr.Row():
-             with gr.Column(scale=2):
-                 input_audio = gr.Audio(value=f"{DEMO_PATH}/5895_34622_000026_000002.wav", label="Input Audio", type="filepath", interactive=True)
-                 with gr.Group():
-                     original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
-                                                      info="Use whisperx model to get the transcript.")
-                     transcribe_btn = gr.Button(value="Transcribe")
-
-             with gr.Column(scale=3):
-                 with gr.Group():
-                     transcript = gr.Textbox(label="Text", lines=7, value=demo_text["Edit"]["regular"])
-
-                     with gr.Row():
-                         mode = gr.Radio(label="Mode", choices=["Edit", "TTS"], value="Edit")
-
-                     run_btn = gr.Button(value="Run")
-
-             with gr.Column(scale=2):
-                 output_audio = gr.Audio(label="Output Audio")
-
-         with gr.Row():
-             with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
-                 stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=2,
-                                            info="if there are long silence in the generated audio, reduce the stop_repetition to 2 or 1. -1 = disabled")
-                 seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
-                 kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
-                                    info="set to 0 to use less VRAM, but with slower inference")
-                 aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
-                                     info="set to 1 to use cfg")
-                 cfg_coef = gr.Number(label="cfg_coef", value=1.5,
-                                      info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
-                 prompt_length = gr.Number(label="prompt_length", value=3,
-                                           info="used for tts prompt, will automatically cut the prompt audio to this length")
-                 sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
-                 top_p = gr.Number(label="top_p", value=0.8, info="0.9 is a good value, 0.8 is also good")
-                 temperature = gr.Number(label="temperature", value=1, info="haven't try other values, do not change")
-                 top_k = gr.Number(label="top_k", value=0, info="0 means we don't use topk sampling, because we use topp sampling")
-                 codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000, info='encodec specific, do not change')
-                 codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, do not change')
-                 silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change")
-
-         success_output = gr.HTML()
-
-         semgents = gr.State() # not used
-         transcribe_btn.click(fn=transcribe,
-                              inputs=[input_audio],
-                              outputs=[original_transcript, semgents, success_output])
-
-         run_btn.click(fn=run,
-                       inputs=[
-                           seed, sub_amount,
-                           codec_audio_sr, codec_sr,
-                           top_k, top_p, temperature, stop_repetition, kvcache, silence_tokens,
-                           aug_text, cfg_coef, prompt_length,
-                           input_audio, original_transcript, transcript,
-                           mode
-                       ],
-                       outputs=[output_audio, success_output])

-     return app


  if __name__ == "__main__":
@@ -393,5 +624,244 @@ if __name__ == "__main__":
      TMP_PATH = args.tmp_path
      MODELS_PATH = args.models_path

-     app = get_app()
-     app.queue().launch(share=args.share, server_port=args.port)
  )
  from edit_utils_en import parse_edit_en
  from edit_utils_en import parse_tts_en
+ from edit_utils_zh import parse_edit_zh
+ from edit_utils_zh import parse_tts_zh
  from inference_scale import inference_one_sample
  import librosa
  import soundfile as sf
 

  return (start, end)

+ def traditional_to_simplified(segments):
+     converter = opencc.OpenCC('t2s')
+     seg_num = len(segments)
+     for i in range(seg_num):
+         words = segments[i]['words']
+         for j in range(len(words)):
+             segments[i]['words'][j]['word'] = converter.convert(segments[i]['words'][j]['word'])
+         segments[i]['text'] = converter.convert(segments[i]['text'])
+     return segments

+ from whisperx import load_align_model, load_model, load_audio
+ from whisperx import align as align_func

+ # Load models
+ text_tokenizer_en = TextTokenizer(backend="espeak")
+ text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
+
+ ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
+ ckpt_en = torch.load(ssrspeech_fn_en)
+ model_en = ssr.SSR_Speech(ckpt_en["config"])
+ model_en.load_state_dict(ckpt_en["model"])
+ config_en = model_en.args
+ phn2num_en = ckpt_en["phn2num"]
+ model_en.to(device)
+
+ ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
+ ckpt_zh = torch.load(ssrspeech_fn_zh)
+ model_zh = ssr.SSR_Speech(ckpt_zh["config"])
+ model_zh.load_state_dict(ckpt_zh["model"])
+ config_zh = model_zh.args
+ phn2num_zh = ckpt_zh["phn2num"]
+ model_zh.to(device)

  encodec_fn = f"{MODELS_PATH}/wmencodec.th"
+
+ ssrspeech_model_en = {
+     "config": config_en,
+     "phn2num": phn2num_en,
+     "model": model_en,
+     "text_tokenizer": text_tokenizer_en,
      "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
  }

+ ssrspeech_model_zh = {
+     "config": config_zh,
+     "phn2num": phn2num_zh,
+     "model": model_zh,
+     "text_tokenizer": text_tokenizer_zh,
+     "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
+ }

  def get_transcribe_state(segments):
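Note on the Mandarin additions above: the new traditional_to_simplified helper applies OpenCC's 't2s' profile to whisperx output before alignment. A minimal standalone sketch of that conversion, assuming the opencc package imported by the app is installed:

import opencc

converter = opencc.OpenCC('t2s')      # Traditional -> Simplified, same profile as in traditional_to_simplified
print(converter.convert("漢字轉換"))   # prints: 汉字转换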
 
      }

  @spaces.GPU
+ def transcribe_en(audio_path):
+     language = "en"
+     transcribe_model_name = "base.en"
      transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
      segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
      for segment in segments:

          state, success_message
      ]

+ @spaces.GPU
+ def transcribe_zh(audio_path):
+     language = "zh"
+     transcribe_model_name = "base"
+     transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
+     segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
+     for segment in segments:
+         segment['text'] = replace_numbers_with_words(segment['text'])
+     _, segments = align_zh(segments, audio_path)
+     state = get_transcribe_state(segments)
+     success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
+
+     return [
+         state["transcript"], state['segments'],
+         state, success_message
+     ]
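Both transcribe wrappers above return a four-element list (transcript, segments, state, success message), so callers unpack four values. A small illustrative sketch; "prompt.wav" is a placeholder path, not part of the repository:

# Illustrative only: unpack the four values returned by transcribe_en / transcribe_zh.
transcript, segments, state, message = transcribe_en("prompt.wav")
print(transcript)   # plain-text transcript from whisperx
print(message)      # HTML status string shown in the demo UI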
 
  @spaces.GPU
+ def align_en(segments, audio_path):
+     language = "en"
+     align_model, metadata = load_align_model(language_code=language, device=device)
+     audio = load_audio(audio_path)
+     segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
+     state = get_transcribe_state(segments)
+
+     return state, segments
+
+
+ @spaces.GPU
+ def align_zh(segments, audio_path):
+     language = "zh"
      align_model, metadata = load_align_model(language_code=language, device=device)
      audio = load_audio(audio_path)
      segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
 
207
  return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
208
 
209
  @spaces.GPU
210
+ def run_edit_en(seed, sub_amount, aug_text, cfg_coef, prompt_length,
211
+ audio_path, original_transcript, transcript):
 
212
 
213
+ codec_audio_sr = 16000
214
+ codec_sr = 50
215
+ top_k = 0
216
+ top_p = 0.8
217
+ temperature = 1
218
+ kvcache = 1
219
+ stop_repetition = 2
220
 
221
+ aug_text = True if aug_text == 1 else False
222
  seed_everything(seed)
223
 
224
  # resample audio
 
229
  target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
230
  orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
231
 
232
+ [orig_transcript, segments, _, _] = transcribe_en(audio_path)
233
  orig_transcript = orig_transcript.lower()
234
  target_transcript = target_transcript.lower()
235
+ transcribe_state,_ = align_en(segments, audio_path)
236
  print(orig_transcript)
237
  print(target_transcript)
238
 
239
+ operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
240
+ print(operations)
241
+ print("orig_spans: ", orig_spans)
242
+
243
+ if len(orig_spans) > 3:
244
+ raise gr.Error("Current model only supports maximum 3 editings")
245
 
246
+ starting_intervals = []
247
+ ending_intervals = []
248
+ for orig_span in orig_spans:
249
+ start, end = get_mask_interval(transcribe_state, orig_span)
250
+ starting_intervals.append(start)
251
+ ending_intervals.append(end)
252
+
253
+ print("intervals: ", starting_intervals, ending_intervals)
254
+
255
+ info = torchaudio.info(audio_path)
256
+ audio_dur = info.num_frames / info.sample_rate
257
+
258
+ def combine_spans(spans, threshold=0.2):
259
+ spans.sort(key=lambda x: x[0])
260
+ combined_spans = []
261
+ current_span = spans[0]
262
+
263
+ for i in range(1, len(spans)):
264
+ next_span = spans[i]
265
+ if current_span[1] >= next_span[0] - threshold:
266
+ current_span[1] = max(current_span[1], next_span[1])
267
+ else:
268
+ combined_spans.append(current_span)
269
+ current_span = next_span
270
+ combined_spans.append(current_span)
271
+ return combined_spans
272
+
273
+ morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
274
+ for start, end in zip(starting_intervals, ending_intervals)] # in seconds
275
+ morphed_span = combine_spans(morphed_span, threshold=0.2)
276
+ print("morphed_spans: ", morphed_span)
277
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
278
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
279
+
280
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
281
+
282
+ new_audio = inference_one_sample(
283
+ ssrspeech_model_en["model"],
284
+ ssrspeech_model_en["config"],
285
+ ssrspeech_model_en["phn2num"],
286
+ ssrspeech_model_en["text_tokenizer"],
287
+ ssrspeech_model_en["audio_tokenizer"],
288
+ audio_path, orig_transcript, target_transcript, mask_interval,
289
+ cfg_coef, aug_text, False, True, False,
290
+ device, decode_config
291
+ )
292
+ audio_tensors = []
293
+ # save segments for comparison
294
+ new_audio = new_audio[0].cpu()
295
+ torchaudio.save(audio_path, new_audio, codec_audio_sr)
296
+
297
+ audio_tensors.append(new_audio)
298
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
299
+
300
+ success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
301
+ return output_audio, success_message
302
+
303
+
304
+ @spaces.GPU
305
+ def run_tts_en(seed, sub_amount, aug_text, cfg_coef, prompt_length,
306
+ audio_path, original_transcript, transcript):
307
+
308
+ codec_audio_sr = 16000
309
+ codec_sr = 50
310
+ top_k = 0
311
+ top_p = 0.8
312
+ temperature = 1
313
+ kvcache = 1
314
+ stop_repetition = 2
315
+
316
+ aug_text = True if aug_text == 1 else False
317
+ seed_everything(seed)
318
+
319
+ # resample audio
320
+ audio, _ = librosa.load(audio_path, sr=16000)
321
+ sf.write(audio_path, audio, 16000)
322
+
323
+ # text normalization
324
+ target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
325
+ orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
326
+
327
+ [orig_transcript, segments, _, _] = transcribe_en(audio_path)
328
+ orig_transcript = orig_transcript.lower()
329
+ target_transcript = target_transcript.lower()
330
+ transcribe_state,_ = align_en(segments, audio_path)
331
+ print(orig_transcript)
332
+ print(target_transcript)
333
+
334
+
335
+ info = torchaudio.info(audio_path)
336
+ duration = info.num_frames / info.sample_rate
337
+ cut_length = duration
338
+ # Cut long audio for tts
339
+ if duration > prompt_length:
340
+ seg_num = len(transcribe_state['segments'])
341
+ for i in range(seg_num):
342
+ words = transcribe_state['segments'][i]['words']
343
+ for item in words:
344
+ if item['end'] >= prompt_length:
345
+ cut_length = min(item['end'], cut_length)
346
+
347
+ audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
348
+ sf.write(audio_path, audio, 16000)
349
+ [orig_transcript, segments, _, _] = transcribe_en(audio_path)
350
+
351
+
352
+ orig_transcript = orig_transcript.lower()
353
+ target_transcript = target_transcript.lower()
354
+ transcribe_state,_ = align_en(segments, audio_path)
355
+ print(orig_transcript)
356
+ target_transcript_copy = target_transcript # for tts cut out
357
+ target_transcript_copy = target_transcript_copy.split(' ')[0]
358
+ target_transcript = orig_transcript + ' ' + target_transcript
359
+ print(target_transcript)
360
+
361
+
362
+ info = torchaudio.info(audio_path)
363
+ audio_dur = info.num_frames / info.sample_rate
364
+
365
+ morphed_span = [(audio_dur, audio_dur)] # in seconds
366
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
367
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
368
+ print("mask_interval: ", mask_interval)
369
+
370
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
371
+
372
+ new_audio = inference_one_sample(
373
+ ssrspeech_model_en["model"],
374
+ ssrspeech_model_en["config"],
375
+ ssrspeech_model_en["phn2num"],
376
+ ssrspeech_model_en["text_tokenizer"],
377
+ ssrspeech_model_en["audio_tokenizer"],
378
+ audio_path, orig_transcript, target_transcript, mask_interval,
379
+ cfg_coef, aug_text, False, True, True,
380
+ device, decode_config
381
+ )
382
+ audio_tensors = []
383
+ # save segments for comparison
384
+ new_audio = new_audio[0].cpu()
385
+ torchaudio.save(audio_path, new_audio, codec_audio_sr)
386
+
387
+ [new_transcript, new_segments, _, _] = transcribe_en(audio_path)
388
+ transcribe_state,_ = align_en(new_segments, audio_path)
389
+ tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
390
+ tmp2 = target_transcript_copy.lower()
391
+ if tmp1 == tmp2:
392
+ offset = transcribe_state['segments'][0]['words'][0]['start']
393
  else:
394
+ offset = transcribe_state['segments'][0]['words'][1]['start']
395
+
396
+ new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
397
+ audio_tensors.append(new_audio)
398
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
399
+
400
+ success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
401
+ return output_audio, success_message
402
+
403
+
404
+ @spaces.GPU
405
+ def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, prompt_length,
406
+ audio_path, original_transcript, transcript):
407
+
408
+ codec_audio_sr = 16000
409
+ codec_sr = 50
410
+ top_k = 0
411
+ top_p = 0.8
412
+ temperature = 1
413
+ kvcache = 1
414
+ stop_repetition = 2
415
+
416
+ aug_text = True if aug_text == 1 else False
417
+
418
+ seed_everything(seed)
419
+
420
+ # resample audio
421
+ audio, _ = librosa.load(audio_path, sr=16000)
422
+ sf.write(audio_path, audio, 16000)
423
+
424
+ # text normalization
425
+ target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
426
+ orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
427
+
428
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
429
+
430
+ converter = opencc.OpenCC('t2s')
431
+ orig_transcript = converter.convert(orig_transcript)
432
+ transcribe_state = align_zh(traditional_to_simplified(segments), audio_path)
433
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
434
+
435
+ print(orig_transcript)
436
+ print(target_transcript)
437
+
438
+ operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
439
+ print(operations)
440
+ print("orig_spans: ", orig_spans)
441
+
442
+ if len(orig_spans) > 3:
443
+ raise gr.Error("Current model only supports maximum 3 editings")
444
 
445
+ starting_intervals = []
446
+ ending_intervals = []
447
+ for orig_span in orig_spans:
448
+ start, end = get_mask_interval(transcribe_state, orig_span)
449
+ starting_intervals.append(start)
450
+ ending_intervals.append(end)
451
+
452
+ print("intervals: ", starting_intervals, ending_intervals)
453
 
454
+ info = torchaudio.info(audio_path)
455
+ audio_dur = info.num_frames / info.sample_rate
456
+
457
+ def combine_spans(spans, threshold=0.2):
458
+ spans.sort(key=lambda x: x[0])
459
+ combined_spans = []
460
+ current_span = spans[0]
461
+
462
+ for i in range(1, len(spans)):
463
+ next_span = spans[i]
464
+ if current_span[1] >= next_span[0] - threshold:
465
+ current_span[1] = max(current_span[1], next_span[1])
466
+ else:
467
+ combined_spans.append(current_span)
468
+ current_span = next_span
469
+ combined_spans.append(current_span)
470
+ return combined_spans
471
+
472
+ morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
473
+ for start, end in zip(starting_intervals, ending_intervals)] # in seconds
474
+ morphed_span = combine_spans(morphed_span, threshold=0.2)
475
+ print("morphed_spans: ", morphed_span)
476
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
477
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
478
+
479
  decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
480
 
 
481
  new_audio = inference_one_sample(
482
+ ssrspeech_model_zh["model"],
483
+ ssrspeech_model_zh["config"],
484
+ ssrspeech_model_zh["phn2num"],
485
+ ssrspeech_model_zh["text_tokenizer"],
486
+ ssrspeech_model_zh["audio_tokenizer"],
487
  audio_path, orig_transcript, target_transcript, mask_interval,
488
+ cfg_coef, aug_text, False, True, False,
489
  device, decode_config
490
  )
491
  audio_tensors = []
492
  # save segments for comparison
493
  new_audio = new_audio[0].cpu()
494
  torchaudio.save(audio_path, new_audio, codec_audio_sr)
495
  audio_tensors.append(new_audio)
496
  output_audio = get_output_audio(audio_tensors, codec_audio_sr)
497
 
 
499
  return output_audio, success_message
500
 
501
 
502
+ @spaces.GPU
503
+ def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, prompt_length,
504
+ audio_path, original_transcript, transcript):
505
+
506
+ codec_audio_sr = 16000
507
+ codec_sr = 50
508
+ top_k = 0
509
+ top_p = 0.8
510
+ temperature = 1
511
+ kvcache = 1
512
+ stop_repetition = 2
513
+
514
+ aug_text = True if aug_text == 1 else False
515
+
516
+ seed_everything(seed)
517
 
518
+ # resample audio
519
+ audio, _ = librosa.load(audio_path, sr=16000)
520
+ sf.write(audio_path, audio, 16000)
521
+
522
+ # text normalization
523
+ target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
524
+ orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ")
 
525
 
526
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
527
 
528
+ converter = opencc.OpenCC('t2s')
529
+ orig_transcript = converter.convert(orig_transcript)
530
+ transcribe_state = align_zh(traditional_to_simplified(segments), audio_path)
531
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
532
+
533
+ print(orig_transcript)
534
+ print(target_transcript)
535
+
536
+ info = torchaudio.info(audio_path)
537
+ duration = info.num_frames / info.sample_rate
538
+ cut_length = duration
539
+ # Cut long audio for tts
540
+ if duration > prompt_length:
541
+ seg_num = len(transcribe_state['segments'])
542
+ for i in range(seg_num):
543
+ words = transcribe_state['segments'][i]['words']
544
+ for item in words:
545
+ if item['end'] >= prompt_length:
546
+ cut_length = min(item['end'], cut_length)
547
+
548
+ audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
549
+ sf.write(audio_path, audio, 16000)
550
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
551
+
552
+
553
+ converter = opencc.OpenCC('t2s')
554
+ orig_transcript = converter.convert(orig_transcript)
555
+ transcribe_state = align_zh(traditional_to_simplified(segments), audio_path)
556
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
557
+
558
+ print(orig_transcript)
559
+ target_transcript_copy = target_transcript # for tts cut out
560
+ target_transcript_copy = target_transcript_copy[0]
561
+ target_transcript = orig_transcript + target_transcript
562
+ print(target_transcript)
563
+
564
+
565
+ info = torchaudio.info(audio_path)
566
+ audio_dur = info.num_frames / info.sample_rate
567
+
568
+ morphed_span = [(audio_dur, audio_dur)] # in seconds
569
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
570
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
571
+ print("mask_interval: ", mask_interval)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
574
+
575
+ new_audio = inference_one_sample(
576
+ ssrspeech_model_zh["model"],
577
+ ssrspeech_model_zh["config"],
578
+ ssrspeech_model_zh["phn2num"],
579
+ ssrspeech_model_zh["text_tokenizer"],
580
+ ssrspeech_model_zh["audio_tokenizer"],
581
+ audio_path, orig_transcript, target_transcript, mask_interval,
582
+ cfg_coef, aug_text, False, True, True,
583
+ device, decode_config
584
+ )
585
+ audio_tensors = []
586
+ # save segments for comparison
587
+ new_audio = new_audio[0].cpu()
588
+ torchaudio.save(audio_path, new_audio, codec_audio_sr)
589
+
590
+ [new_transcript, new_segments, _, _] = transcribe_zh(audio_path)
591
+
592
+ transcribe_state = align_zh(traditional_to_simplified(new_segments), audio_path)
593
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
594
+ tmp1 = transcribe_state['segments'][0]['words'][0]['word']
595
+ tmp2 = target_transcript_copy
596
+
597
+ if tmp1 == tmp2:
598
+ offset = transcribe_state['segments'][0]['words'][0]['start']
599
+ else:
600
+ offset = transcribe_state['segments'][0]['words'][1]['start']
601
+
602
+ new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
603
+ audio_tensors.append(new_audio)
604
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
605
+
606
+ success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
607
+ return output_audio, success_message
608
 
609
 
610
  if __name__ == "__main__":
 
624
  TMP_PATH = args.tmp_path
625
  MODELS_PATH = args.models_path
626
 
627
+ # app = get_app()
628
+ # app.queue().launch(share=args.share, server_port=args.port)
629
+
630
+ # CSS styling (optional)
631
+ css = """
632
+ #col-container {
633
+ margin: 0 auto;
634
+ max-width: 1280px;
635
+ }
636
+ """
637
+
638
+ # Gradio Blocks layout
639
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
640
+ with gr.Column(elem_id="col-container"):
641
+ gr.Markdown("""
642
+ # SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer
643
+ Generate and edit speech from text. Adjust advanced settings for more control.
644
+
645
+ Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
646
+ """)
647
+
648
+
649
+ # Tabs for Generate and Edit
650
+ with gr.Tab("English Speech Editing"):
651
+
652
+ with gr.Row():
653
+ with gr.Column(scale=2):
654
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
655
+ with gr.Group():
656
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug",
657
+ info="Use whisperx model to get the transcript.")
658
+ transcribe_btn = gr.Button(value="Transcribe")
659
+
660
+ with gr.Column(scale=3):
661
+ with gr.Group():
662
+ transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True)
663
+ run_btn = gr.Button(value="Run")
664
+
665
+ with gr.Column(scale=2):
666
+ output_audio = gr.Audio(label="Output Audio")
667
+
668
+ with gr.Row():
669
+ with gr.Accordion("Advanced Settings", open=False):
670
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
671
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
672
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
673
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
674
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
675
+ prompt_length = gr.Number(label="prompt_length", value=3,
676
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
677
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
678
+
679
+ success_output = gr.HTML()
680
+
681
+ semgents = gr.State() # not used
682
+ state = gr.State() # not used
683
+ transcribe_btn.click(fn=transcribe_en,
684
+ inputs=[input_audio],
685
+ outputs=[original_transcript, semgents, state, success_output])
686
+
687
+ run_btn.click(fn=run_edit_en,
688
+ inputs=[
689
+ seed, sub_amount,
690
+ aug_text, cfg_coef, prompt_length,
691
+ input_audio, original_transcript, transcript,
692
+ ],
693
+ outputs=[output_audio, success_output])
694
+
695
+ transcript.submit(fn=run_edit_en,
696
+ inputs=[
697
+ seed, sub_amount,
698
+ aug_text, cfg_coef, prompt_length,
699
+ input_audio, original_transcript, transcript,
700
+ ],
701
+ outputs=[output_audio, success_output]
702
+ )
703
+
704
+ with gr.Tab("English TTS"):
705
+
706
+ with gr.Row():
707
+ with gr.Column(scale=2):
708
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
709
+ with gr.Group():
710
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug",
711
+ info="Use whisperx model to get the transcript.")
712
+ transcribe_btn = gr.Button(value="Transcribe")
713
+
714
+ with gr.Column(scale=3):
715
+ with gr.Group():
716
+ transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True)
717
+ run_btn = gr.Button(value="Run")
718
+
719
+ with gr.Column(scale=2):
720
+ output_audio = gr.Audio(label="Output Audio")
721
+
722
+ with gr.Row():
723
+ with gr.Accordion("Advanced Settings", open=False):
724
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
725
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
726
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
727
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
728
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
729
+ prompt_length = gr.Number(label="prompt_length", value=3,
730
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
731
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
732
+
733
+ success_output = gr.HTML()
734
+
735
+ semgents = gr.State() # not used
736
+ state = gr.State() # not used
737
+ transcribe_btn.click(fn=transcribe_en,
738
+ inputs=[input_audio],
739
+ outputs=[original_transcript, semgents, state, success_output])
740
+
741
+ run_btn.click(fn=run_tts_en,
742
+ inputs=[
743
+ seed, sub_amount,
744
+ aug_text, cfg_coef, prompt_length,
745
+ input_audio, original_transcript, transcript,
746
+ ],
747
+ outputs=[output_audio, success_output])
748
+
749
+ transcript.submit(fn=run_tts_en,
750
+ inputs=[
751
+ seed, sub_amount,
752
+ aug_text, cfg_coef, prompt_length,
753
+ input_audio, original_transcript, transcript,
754
+ ],
755
+ outputs=[output_audio, success_output]
756
+ )
757
+
758
+ with gr.Tab("Mandarin Speech Editing"):
759
+
760
+ with gr.Row():
761
+ with gr.Column(scale=2):
762
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
763
+ with gr.Group():
764
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug",
765
+ info="Use whisperx model to get the transcript.")
766
+ transcribe_btn = gr.Button(value="Transcribe")
767
+
768
+ with gr.Column(scale=3):
769
+ with gr.Group():
770
+ transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True)
771
+ run_btn = gr.Button(value="Run")
772
+
773
+ with gr.Column(scale=2):
774
+ output_audio = gr.Audio(label="Output Audio")
775
+
776
+ with gr.Row():
777
+ with gr.Accordion("Advanced Settings", open=False):
778
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
779
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
780
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
781
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
782
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
783
+ prompt_length = gr.Number(label="prompt_length", value=3,
784
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
785
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
786
+
787
+ success_output = gr.HTML()
788
+
789
+ semgents = gr.State() # not used
790
+ state = gr.State() # not used
791
+ transcribe_btn.click(fn=transcribe_zh,
792
+ inputs=[input_audio],
793
+ outputs=[original_transcript, semgents, state, success_output])
794
+
795
+ run_btn.click(fn=run_edit_zh,
796
+ inputs=[
797
+ seed, sub_amount,
798
+ aug_text, cfg_coef, prompt_length,
799
+ input_audio, original_transcript, transcript,
800
+ ],
801
+ outputs=[output_audio, success_output])
802
+
803
+ transcript.submit(fn=run_edit_zh,
804
+ inputs=[
805
+ seed, sub_amount,
806
+ aug_text, cfg_coef, prompt_length,
807
+ input_audio, original_transcript, transcript,
808
+ ],
809
+ outputs=[output_audio, success_output]
810
+ )
811
+
812
+ with gr.Tab("Mandarin TTS"):
813
+
814
+ with gr.Row():
815
+ with gr.Column(scale=2):
816
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
817
+ with gr.Group():
818
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug",
819
+ info="Use whisperx model to get the transcript.")
820
+ transcribe_btn = gr.Button(value="Transcribe")
821
+
822
+ with gr.Column(scale=3):
823
+ with gr.Group():
824
+ transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True)
825
+ run_btn = gr.Button(value="Run")
826
+
827
+ with gr.Column(scale=2):
828
+ output_audio = gr.Audio(label="Output Audio")
829
+
830
+ with gr.Row():
831
+ with gr.Accordion("Advanced Settings", open=False):
832
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
833
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
834
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
835
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
836
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
837
+ prompt_length = gr.Number(label="prompt_length", value=3,
838
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
839
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
840
+
841
+ success_output = gr.HTML()
842
+
843
+ semgents = gr.State() # not used
844
+ state = gr.State() # not used
845
+ transcribe_btn.click(fn=transcribe_zh,
846
+ inputs=[input_audio],
847
+ outputs=[original_transcript, semgents, state, success_output])
848
+
849
+ run_btn.click(fn=run_tts_zh,
850
+ inputs=[
851
+ seed, sub_amount,
852
+ aug_text, cfg_coef, prompt_length,
853
+ input_audio, original_transcript, transcript,
854
+ ],
855
+ outputs=[output_audio, success_output])
856
+
857
+ transcript.submit(fn=run_tts_zh,
858
+ inputs=[
859
+ seed, sub_amount,
860
+ aug_text, cfg_coef, prompt_length,
861
+ input_audio, original_transcript, transcript,
862
+ ],
863
+ outputs=[output_audio, success_output]
864
+ )
865
+
866
+ # Launch the Gradio demo
867
+ demo.launch()
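For reference, the combine_spans helper duplicated inside run_edit_en and run_edit_zh merges editing spans that lie within the threshold of each other before they are converted to codec-frame mask intervals. A standalone check of that logic, copied from the diff with illustrative values:

def combine_spans(spans, threshold=0.2):
    spans.sort(key=lambda x: x[0])
    combined_spans = []
    current_span = spans[0]
    for i in range(1, len(spans)):
        next_span = spans[i]
        if current_span[1] >= next_span[0] - threshold:
            current_span[1] = max(current_span[1], next_span[1])
        else:
            combined_spans.append(current_span)
            current_span = next_span
    combined_spans.append(current_span)
    return combined_spans

# Spans closer than 0.2 s are merged into a single mask interval (times in seconds).
print(combine_spans([[1.0, 2.0], [2.1, 3.0], [5.0, 6.0]]))   # [[1.0, 3.0], [5.0, 6.0]]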