Update app.py

This commit removes the Mandarin (zh) code path and the "Load models" button from the SSR-Speech demo app: the English checkpoint, the WhisperX transcribe/align models, and the wmencodec.th codec are now loaded once at import time instead of on demand.

app.py CHANGED
@@ -8,9 +8,7 @@ from data.tokenizer import (
     AudioTokenizer,
     TextTokenizer,
 )
-from edit_utils_zh import parse_edit_zh
 from edit_utils_en import parse_edit_en
-from edit_utils_zh import parse_tts_zh
 from edit_utils_en import parse_tts_en
 from inference_scale import inference_one_sample
 import librosa
@@ -29,7 +27,6 @@ DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
 TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
 MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
 device = "cuda" if torch.cuda.is_available() else "cpu"
-transcribe_model, align_model, ssrspeech_model = None, None, None
 
 def get_random_string():
     return "".join(str(uuid.uuid4()).split("-"))
@@ -124,56 +121,36 @@ class WhisperxModel:
             segment['text'] = replace_numbers_with_words(segment['text'])
         return self.align_model.align(segments, audio_path)
 
-    model.to(device)
-    encodec_fn = f"{MODELS_PATH}/wmencodec.th"
-    if not os.path.exists(encodec_fn):
-        os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
-    ssrspeech_model = {
-        "config": config,
-        "phn2num": phn2num,
-        "model": model,
-        "text_tokenizer": text_tokenizer,
-        "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
-    }
-    success_message = "<span style='color:green;'>Success: Models loading completed successfully!</span>"
-    return [
-        gr.Accordion(),
-        success_message
-    ]
+ssrspeech_model_name = "English"
+text_tokenizer = TextTokenizer(backend="espeak")
+language = "en"
+transcribe_model_name = "base.en"
+
+align_model = WhisperxAlignModel(language)
+transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)
+
+ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
+if not os.path.exists(ssrspeech_fn):
+    os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O " + ssrspeech_fn)
+
+ckpt = torch.load(ssrspeech_fn)
+model = ssr.SSR_Speech(ckpt["config"])
+model.load_state_dict(ckpt["model"])
+config = model.args
+phn2num = ckpt["phn2num"]
+model.to(device)
+
+encodec_fn = f"{MODELS_PATH}/wmencodec.th"
+if not os.path.exists(encodec_fn):
+    os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O " + encodec_fn)
+
+ssrspeech_model = {
+    "config": config,
+    "phn2num": phn2num,
+    "model": model,
+    "text_tokenizer": text_tokenizer,
+    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
+}
 
 def get_transcribe_state(segments):
     transcript = " ".join([segment["text"] for segment in segments])
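With load_models gone, everything above runs at import time, before the Gradio UI is even constructed. As a hedged aside, the same checkpoint fetch could use huggingface_hub instead of shelling out to wget; a minimal sketch, with repo and file names taken from the hunk above and default caching assumed:

import torch
from huggingface_hub import hf_hub_download

# Fetches once and caches (by default under ~/.cache/huggingface),
# replacing the os.path.exists + os.system("wget ...") guard above.
ssrspeech_fn = hf_hub_download(repo_id="westbrook/SSR-Speech-English",
                               filename="English.pth")
ckpt = torch.load(ssrspeech_fn)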
@@ -185,8 +162,6 @@ def get_transcribe_state(segments):
 
 @spaces.GPU
 def transcribe(audio_path):
-    global transcribe_model
-
     if transcribe_model is None:
         raise gr.Error("Transcription model not loaded")
 
@@ -202,7 +177,6 @@ def transcribe(audio_path):
 
 @spaces.GPU
 def align(segments, audio_path):
-    global align_model
     if align_model is None:
         raise gr.Error("Align model not loaded")
 
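The `global transcribe_model` and `global align_model` declarations dropped in the two hunks above were only ever needed for rebinding; a function that merely reads a module-level name resolves it without any declaration. A minimal self-contained sketch of the distinction:

model = None

def use_model():
    # Reading the module-level name: no `global` needed.
    if model is None:
        raise RuntimeError("model not loaded")

def load_model():
    # Rebinding the module-level name: `global` required.
    global model
    model = object()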
@@ -230,21 +204,15 @@ def replace_numbers_with_words(sentence):
     return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
 
 @spaces.GPU
-def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_k, top_p, temperature,
+def run(seed, sub_amount, codec_audio_sr, codec_sr, top_k, top_p, temperature,
         stop_repetition, kvcache, silence_tokens, aug_text, cfg_coef, prompt_length,
         audio_path, original_transcript, transcript, mode):
 
-    global transcribe_model, align_model, ssrspeech_model
     aug_text = True if aug_text == 1 else False
     if ssrspeech_model is None:
         raise gr.Error("ssrspeech model not loaded")
 
     seed_everything(seed)
-
-    if ssrspeech_model_choice == "English":
-        language = "en"
-    elif ssrspeech_model_choice == "Mandarin":
-        language = "zh"
 
     # resample audio
     audio, _ = librosa.load(audio_path, sr=16000)
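For context on the `re.sub(r'\b\d+\b', ...)` line kept above: a minimal sketch of how such a helper can behave. The digit-by-digit spelling here is purely illustrative; the real `replace_with_words` closure in app.py is outside this diff and may use a library such as num2words:

import re

def replace_numbers_with_words(sentence: str) -> str:
    words = ["zero", "one", "two", "three", "four",
             "five", "six", "seven", "eight", "nine"]

    def replace_with_words(match):
        # Spell out each digit of the matched number: "42" -> "four two".
        return " ".join(words[int(d)] for d in match.group())

    return re.sub(r"\b\d+\b", replace_with_words, sentence)

print(replace_numbers_with_words("track 42"))  # -> "track four two"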
@@ -255,15 +223,9 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
 
     [orig_transcript, segments, _] = transcribe(audio_path)
-        transcribe_state = align(traditional_to_simplified(segments), audio_path)
-        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-    elif language == 'en':
-        orig_transcript = orig_transcript.lower()
-        target_transcript = target_transcript.lower()
-        transcribe_state = align(segments, audio_path)
+    orig_transcript = orig_transcript.lower()
+    target_transcript = target_transcript.lower()
+    transcribe_state = align(segments, audio_path)
     print(orig_transcript)
     print(target_transcript)
 
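The surviving path lowercases both transcripts and calls align() for word-level timestamps. For reference, a hedged sketch of the upstream whisperx calls that the WhisperxAlignModel wrapper presumably makes; the wrapper itself is defined outside this diff, so treat the exact plumbing as an assumption:

import whisperx

device = "cpu"
audio = whisperx.load_audio("sample.wav")
model_a, metadata = whisperx.load_align_model(language_code="en", device=device)
segments = [{"start": 0.0, "end": 1.5, "text": "hello world"}]
aligned = whisperx.align(segments, model_a, metadata, audio, device)
# Each aligned segment carries word-level timing:
print(aligned["segments"][0]["words"][0])  # e.g. {"word": "hello", "start": ..., "end": ...}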
@@ -284,26 +246,18 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     sf.write(audio_path, audio, 16000)
     [orig_transcript, segments, _] = transcribe(audio_path)
 
-        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-    elif language == 'en':
-        orig_transcript = orig_transcript.lower()
-        target_transcript = target_transcript.lower()
-        transcribe_state = align(segments, audio_path)
+
+    orig_transcript = orig_transcript.lower()
+    target_transcript = target_transcript.lower()
+    transcribe_state = align(segments, audio_path)
     print(orig_transcript)
     target_transcript_copy = target_transcript # for tts cut out
-    if language == 'en':
-        target_transcript_copy = target_transcript_copy.split(' ')[0]
-    elif language == 'zh':
-        target_transcript_copy = target_transcript_copy[0]
-    target_transcript = orig_transcript + ' ' + target_transcript if language == 'en' else orig_transcript + target_transcript
+    target_transcript_copy = target_transcript_copy.split(' ')[0]
+    target_transcript = orig_transcript + ' ' + target_transcript
     print(target_transcript)
 
     if mode == "Edit":
-        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript) if language == 'en' else parse_edit_zh(orig_transcript, target_transcript)
+        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
         print(operations)
         print("orig_spans: ", orig_spans)
 
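The TTS path now always prepends the original transcript as the acoustic prompt and keeps only the first target word for the later cut-out search. A tiny illustration with made-up strings:

# Illustrative values only.
orig_transcript = "the quick brown fox"
target_transcript = "jumps over the lazy dog"

target_transcript_copy = target_transcript.split(' ')[0]      # "jumps"
target_transcript = orig_transcript + ' ' + target_transcript
print(target_transcript)  # "the quick brown fox jumps over the lazy dog"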
@@ -371,15 +325,9 @@ def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_
     torchaudio.save(audio_path, new_audio, codec_audio_sr)
     if tts: # remove the start parts
         [new_transcript, new_segments, _] = transcribe(audio_path)
-            tmp1 = transcribe_state['segments'][0]['words'][0]['word']
-            tmp2 = target_transcript_copy
-        elif language == 'en':
-            transcribe_state = align(new_segments, audio_path)
-            tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
-            tmp2 = target_transcript_copy.lower()
+        transcribe_state = align(new_segments, audio_path)
+        tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
+        tmp2 = target_transcript_copy.lower()
         if tmp1 == tmp2:
             offset = transcribe_state['segments'][0]['words'][0]['start']
         else:
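When the first word of the re-transcribed output matches `target_transcript_copy`, its start time becomes the offset at which the prompt audio is cut away. A minimal sketch of that trim, assuming `new_audio` is a (channels, samples) tensor sampled at `codec_audio_sr`:

import torch

def cut_tts_prefix(new_audio: torch.Tensor, offset: float, sr: int) -> torch.Tensor:
    # Drop everything before the first target word; offset is in seconds.
    start = int(offset * sr)
    return new_audio[:, start:]

trimmed = cut_tts_prefix(torch.zeros(1, 32000), offset=1.0, sr=16000)
print(trimmed.shape)  # torch.Size([1, 16000])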
@@ -406,15 +354,6 @@ demo_text = {
 
 def get_app():
     with gr.Blocks() as app:
-        with gr.Row():
-            with gr.Column(scale=2):
-                load_models_btn = gr.Button(value="Load models")
-            with gr.Column(scale=5):
-                with gr.Accordion("Select models", open=False) as models_selector:
-                    with gr.Row():
-                        ssrspeech_model_choice = gr.Radio(label="ssrspeech model", value="English",
-                                                          choices=["English", "Mandarin"])
-
         with gr.Row():
             with gr.Column(scale=2):
                 input_audio = gr.Audio(value=f"{DEMO_PATH}/5895_34622_000026_000002.wav", label="Input Audio", type="filepath", interactive=True)
@@ -458,10 +397,6 @@ def get_app():
 
         success_output = gr.HTML()
 
-        load_models_btn.click(fn=load_models,
-                              inputs=[ssrspeech_model_choice],
-                              outputs=[models_selector, success_output])
-
         semgents = gr.State() # not used
         transcribe_btn.click(fn=transcribe,
                              inputs=[input_audio],
@@ -469,7 +404,7 @@ def get_app():
 
         run_btn.click(fn=run,
                       inputs=[
-                          seed, sub_amount, ssrspeech_model_choice,
+                          seed, sub_amount,
                           codec_audio_sr, codec_sr,
                           top_k, top_p, temperature, stop_repetition, kvcache, silence_tokens,
                           aug_text, cfg_coef, prompt_length,
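Dropping `ssrspeech_model_choice` here mirrors the new run() signature: Blocks passes the `inputs` components to fn positionally, so the list must stay in the same order as the parameters. A self-contained sketch of the pattern with a trimmed-down run():

import gradio as gr

def run(seed, sub_amount, codec_audio_sr):
    # Parameters arrive in the same order as the inputs list below.
    return f"seed={seed}, sub_amount={sub_amount}, sr={codec_audio_sr}"

with gr.Blocks() as app:
    seed = gr.Number(value=0, label="seed")
    sub_amount = gr.Number(value=0.12, label="sub_amount")
    codec_audio_sr = gr.Number(value=16000, label="codec_audio_sr")
    output = gr.Textbox(label="result")
    run_btn = gr.Button("Run")
    run_btn.click(fn=run, inputs=[seed, sub_amount, codec_audio_sr], outputs=[output])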