import os import requests import re from num2words import num2words import gradio as gr import torch import torchaudio from data.tokenizer import ( AudioTokenizer, TextTokenizer, ) from edit_utils_en import parse_edit_en from edit_utils_en import parse_tts_en from edit_utils_zh import parse_edit_zh from edit_utils_zh import parse_tts_zh from inference_scale import inference_one_sample import librosa import soundfile as sf from models import ssr import io import numpy as np import random import uuid import opencc import spaces import nltk nltk.download('punkt') DEMO_PATH = os.getenv("DEMO_PATH", "./demo") TMP_PATH = os.getenv("TMP_PATH", "./demo/temp") MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models") os.makedirs(MODELS_PATH, exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" # download wmencodec url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th" filename = os.path.join(MODELS_PATH, "wmencodec.th") response = requests.get(url, stream=True) response.raise_for_status() with open(filename, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) print(f"File downloaded to: {filename}") # download english model url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth" filename = os.path.join(MODELS_PATH, "English.th") response = requests.get(url, stream=True) response.raise_for_status() with open(filename, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) print(f"File downloaded to: {filename}") # download mandarin model url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth" filename = os.path.join(MODELS_PATH, "Mandarin.pth") response = requests.get(url, stream=True) response.raise_for_status() with open(filename, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) print(f"File downloaded to: {filename}") def get_random_string(): return "".join(str(uuid.uuid4()).split("-")) @spaces.GPU def seed_everything(seed): if seed != -1: os.environ['PYTHONHASHSEED'] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True def get_mask_interval(transcribe_state, word_span): print(transcribe_state) seg_num = len(transcribe_state['segments']) data = [] for i in range(seg_num): words = transcribe_state['segments'][i]['words'] for item in words: data.append([item['start'], item['end'], item['word']]) s, e = word_span[0], word_span[1] assert s <= e, f"s:{s}, e:{e}" assert s >= 0, f"s:{s}" assert e <= len(data), f"e:{e}" if e == 0: # start start = 0. end = float(data[0][0]) elif s == len(data): # end start = float(data[-1][1]) end = float(data[-1][1]) # don't know the end yet elif s == e: # insert start = float(data[s-1][1]) end = float(data[s][0]) else: start = float(data[s-1][1]) if s > 0 else float(data[s][0]) end = float(data[e][0]) if e < len(data) else float(data[-1][1]) return (start, end) def traditional_to_simplified(segments): converter = opencc.OpenCC('t2s') seg_num = len(segments) for i in range(seg_num): words = segments[i]['words'] for j in range(len(words)): segments[i]['words'][j]['word'] = converter.convert(segments[i]['words'][j]['word']) segments[i]['text'] = converter.convert(segments[i]['text']) return segments from whisperx import load_align_model, load_model, load_audio from whisperx import align as align_func # Load models text_tokenizer_en = TextTokenizer(backend="espeak") text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn') ssrspeech_fn_en = f"{MODELS_PATH}/English.pth" ckpt_en = torch.load(ssrspeech_fn_en) model_en = ssr.SSR_Speech(ckpt_en["config"]) model_en.load_state_dict(ckpt_en["model"]) config_en = model_en.args phn2num_en = ckpt_en["phn2num"] model_en.to(device) ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth" ckpt_zh = torch.load(ssrspeech_fn_zh) model_zh = ssr.SSR_Speech(ckpt_zh["config"]) model_zh.load_state_dict(ckpt_zh["model"]) config_zh = model_zh.args phn2num_zh = ckpt_zh["phn2num"] model_zh.to(device) encodec_fn = f"{MODELS_PATH}/wmencodec.th" ssrspeech_model_en = { "config": config_en, "phn2num": phn2num_en, "model": model_en, "text_tokenizer": text_tokenizer_en, "audio_tokenizer": AudioTokenizer(signature=encodec_fn) } ssrspeech_model_zh = { "config": config_zh, "phn2num": phn2num_zh, "model": model_zh, "text_tokenizer": text_tokenizer_zh, "audio_tokenizer": AudioTokenizer(signature=encodec_fn) } def get_transcribe_state(segments): transcript = " ".join([segment["text"] for segment in segments]) transcript = transcript[1:] if transcript[0] == " " else transcript return { "segments": segments, "transcript": transcript, } @spaces.GPU def transcribe_en(audio_path): language = "en" transcribe_model_name = "base.en" transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language) segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"] for segment in segments: segment['text'] = replace_numbers_with_words(segment['text']) _, segments = align_en(segments, audio_path) state = get_transcribe_state(segments) success_message = "Success: Transcribe completed successfully!" return [ state["transcript"], state['segments'], state, success_message ] @spaces.GPU def transcribe_zh(audio_path): language = "zh" transcribe_model_name = "base" transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language) segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"] for segment in segments: segment['text'] = replace_numbers_with_words(segment['text']) _, segments = align_zh(segments, audio_path) state = get_transcribe_state(segments) success_message = "Success: Transcribe completed successfully!" return [ state["transcript"], state['segments'], state, success_message ] @spaces.GPU def align_en(segments, audio_path): language = "en" align_model, metadata = load_align_model(language_code=language, device=device) audio = load_audio(audio_path) segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"] state = get_transcribe_state(segments) return state, segments @spaces.GPU def align_zh(segments, audio_path): language = "zh" align_model, metadata = load_align_model(language_code=language, device=device) audio = load_audio(audio_path) segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"] state = get_transcribe_state(segments) return state, segments def get_output_audio(audio_tensors, codec_audio_sr): result = torch.cat(audio_tensors, 1) buffer = io.BytesIO() torchaudio.save(buffer, result, int(codec_audio_sr), format="wav") buffer.seek(0) return buffer.read() def replace_numbers_with_words(sentence): sentence = re.sub(r'(\d+)', r' \1 ', sentence) # add spaces around numbers def replace_with_words(match): num = match.group(0) try: return num2words(num) # Convert numbers to words except: return num # In case num2words fails (unlikely with digits but just to be safe) return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers @spaces.GPU def run_edit_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript): codec_audio_sr = 16000 codec_sr = 50 top_k = 0 top_p = 0.8 temperature = 1 kvcache = 1 stop_repetition = 2 aug_text = True if aug_text == 1 else False seed_everything(seed) # resample audio audio, _ = librosa.load(audio_path, sr=16000) sf.write(audio_path, audio, 16000) # text normalization target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") [orig_transcript, segments, _, _] = transcribe_en(audio_path) orig_transcript = orig_transcript.lower() target_transcript = target_transcript.lower() transcribe_state,_ = align_en(segments, audio_path) print(orig_transcript) print(target_transcript) operations, orig_spans = parse_edit_en(orig_transcript, target_transcript) print(operations) print("orig_spans: ", orig_spans) if len(orig_spans) > 3: raise gr.Error("Current model only supports maximum 3 editings") starting_intervals = [] ending_intervals = [] for orig_span in orig_spans: start, end = get_mask_interval(transcribe_state, orig_span) starting_intervals.append(start) ending_intervals.append(end) print("intervals: ", starting_intervals, ending_intervals) info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate def combine_spans(spans, threshold=0.2): spans.sort(key=lambda x: x[0]) combined_spans = [] current_span = spans[0] for i in range(1, len(spans)): next_span = spans[i] if current_span[1] >= next_span[0] - threshold: current_span[1] = max(current_span[1], next_span[1]) else: combined_spans.append(current_span) current_span = next_span combined_spans.append(current_span) return combined_spans morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)] for start, end in zip(starting_intervals, ending_intervals)] # in seconds morphed_span = combine_spans(morphed_span, threshold=0.2) print("morphed_spans: ", morphed_span) mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span] mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr} new_audio = inference_one_sample( ssrspeech_model_en["model"], ssrspeech_model_en["config"], ssrspeech_model_en["phn2num"], ssrspeech_model_en["text_tokenizer"], ssrspeech_model_en["audio_tokenizer"], audio_path, orig_transcript, target_transcript, mask_interval, cfg_coef, cfg_stride, aug_text, False, True, False, device, decode_config ) audio_tensors = [] # save segments for comparison new_audio = new_audio[0].cpu() torchaudio.save(audio_path, new_audio, codec_audio_sr) audio_tensors.append(new_audio) output_audio = get_output_audio(audio_tensors, codec_audio_sr) success_message = "Success: Inference successfully!" return output_audio, success_message @spaces.GPU def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript): codec_audio_sr = 16000 codec_sr = 50 top_k = 0 top_p = 0.8 temperature = 1 kvcache = 1 stop_repetition = 2 aug_text = True if aug_text == 1 else False seed_everything(seed) # resample audio audio, _ = librosa.load(audio_path, sr=16000) sf.write(audio_path, audio, 16000) # text normalization target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") [orig_transcript, segments, _, _] = transcribe_en(audio_path) orig_transcript = orig_transcript.lower() target_transcript = target_transcript.lower() transcribe_state,_ = align_en(segments, audio_path) print(orig_transcript) print(target_transcript) info = torchaudio.info(audio_path) duration = info.num_frames / info.sample_rate cut_length = duration # Cut long audio for tts if duration > prompt_length: seg_num = len(transcribe_state['segments']) for i in range(seg_num): words = transcribe_state['segments'][i]['words'] for item in words: if item['end'] >= prompt_length: cut_length = min(item['end'], cut_length) audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length) sf.write(audio_path, audio, 16000) [orig_transcript, segments, _, _] = transcribe_en(audio_path) orig_transcript = orig_transcript.lower() target_transcript = target_transcript.lower() transcribe_state,_ = align_en(segments, audio_path) print(orig_transcript) target_transcript_copy = target_transcript # for tts cut out target_transcript_copy = target_transcript_copy.split(' ')[0] target_transcript = orig_transcript + ' ' + target_transcript print(target_transcript) info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate morphed_span = [(audio_dur, audio_dur)] # in seconds mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span] mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now print("mask_interval: ", mask_interval) decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr} new_audio = inference_one_sample( ssrspeech_model_en["model"], ssrspeech_model_en["config"], ssrspeech_model_en["phn2num"], ssrspeech_model_en["text_tokenizer"], ssrspeech_model_en["audio_tokenizer"], audio_path, orig_transcript, target_transcript, mask_interval, cfg_coef, cfg_stride, aug_text, False, True, True, device, decode_config ) audio_tensors = [] # save segments for comparison new_audio = new_audio[0].cpu() torchaudio.save(audio_path, new_audio, codec_audio_sr) [new_transcript, new_segments, _, _] = transcribe_en(audio_path) transcribe_state,_ = align_en(new_segments, audio_path) tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower() tmp2 = target_transcript_copy.lower() if tmp1 == tmp2: offset = transcribe_state['segments'][0]['words'][0]['start'] else: offset = transcribe_state['segments'][0]['words'][1]['start'] new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr)) audio_tensors.append(new_audio) output_audio = get_output_audio(audio_tensors, codec_audio_sr) success_message = "Success: Inference successfully!" return output_audio, success_message @spaces.GPU def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript): codec_audio_sr = 16000 codec_sr = 50 top_k = 0 top_p = 0.8 temperature = 1 kvcache = 1 stop_repetition = 2 aug_text = True if aug_text == 1 else False seed_everything(seed) # resample audio audio, _ = librosa.load(audio_path, sr=16000) sf.write(audio_path, audio, 16000) # text normalization target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") [orig_transcript, segments, _] = transcribe_zh(audio_path) converter = opencc.OpenCC('t2s') orig_transcript = converter.convert(orig_transcript) transcribe_state = align_zh(traditional_to_simplified(segments), audio_path) transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments']) print(orig_transcript) print(target_transcript) operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript) print(operations) print("orig_spans: ", orig_spans) if len(orig_spans) > 3: raise gr.Error("Current model only supports maximum 3 editings") starting_intervals = [] ending_intervals = [] for orig_span in orig_spans: start, end = get_mask_interval(transcribe_state, orig_span) starting_intervals.append(start) ending_intervals.append(end) print("intervals: ", starting_intervals, ending_intervals) info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate def combine_spans(spans, threshold=0.2): spans.sort(key=lambda x: x[0]) combined_spans = [] current_span = spans[0] for i in range(1, len(spans)): next_span = spans[i] if current_span[1] >= next_span[0] - threshold: current_span[1] = max(current_span[1], next_span[1]) else: combined_spans.append(current_span) current_span = next_span combined_spans.append(current_span) return combined_spans morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)] for start, end in zip(starting_intervals, ending_intervals)] # in seconds morphed_span = combine_spans(morphed_span, threshold=0.2) print("morphed_spans: ", morphed_span) mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span] mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr} new_audio = inference_one_sample( ssrspeech_model_zh["model"], ssrspeech_model_zh["config"], ssrspeech_model_zh["phn2num"], ssrspeech_model_zh["text_tokenizer"], ssrspeech_model_zh["audio_tokenizer"], audio_path, orig_transcript, target_transcript, mask_interval, cfg_coef, cfg_stride, aug_text, False, True, False, device, decode_config ) audio_tensors = [] # save segments for comparison new_audio = new_audio[0].cpu() torchaudio.save(audio_path, new_audio, codec_audio_sr) audio_tensors.append(new_audio) output_audio = get_output_audio(audio_tensors, codec_audio_sr) success_message = "Success: Inference successfully!" return output_audio, success_message @spaces.GPU def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript): codec_audio_sr = 16000 codec_sr = 50 top_k = 0 top_p = 0.8 temperature = 1 kvcache = 1 stop_repetition = 2 aug_text = True if aug_text == 1 else False seed_everything(seed) # resample audio audio, _ = librosa.load(audio_path, sr=16000) sf.write(audio_path, audio, 16000) # text normalization target_transcript = replace_numbers_with_words(transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") orig_transcript = replace_numbers_with_words(original_transcript).replace(" ", " ").replace(" ", " ").replace("\n", " ") [orig_transcript, segments, _] = transcribe_zh(audio_path) converter = opencc.OpenCC('t2s') orig_transcript = converter.convert(orig_transcript) transcribe_state = align_zh(traditional_to_simplified(segments), audio_path) transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments']) print(orig_transcript) print(target_transcript) info = torchaudio.info(audio_path) duration = info.num_frames / info.sample_rate cut_length = duration # Cut long audio for tts if duration > prompt_length: seg_num = len(transcribe_state['segments']) for i in range(seg_num): words = transcribe_state['segments'][i]['words'] for item in words: if item['end'] >= prompt_length: cut_length = min(item['end'], cut_length) audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length) sf.write(audio_path, audio, 16000) [orig_transcript, segments, _] = transcribe_zh(audio_path) converter = opencc.OpenCC('t2s') orig_transcript = converter.convert(orig_transcript) transcribe_state = align_zh(traditional_to_simplified(segments), audio_path) transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments']) print(orig_transcript) target_transcript_copy = target_transcript # for tts cut out target_transcript_copy = target_transcript_copy[0] target_transcript = orig_transcript + target_transcript print(target_transcript) info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate morphed_span = [(audio_dur, audio_dur)] # in seconds mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span] mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now print("mask_interval: ", mask_interval) decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr} new_audio = inference_one_sample( ssrspeech_model_zh["model"], ssrspeech_model_zh["config"], ssrspeech_model_zh["phn2num"], ssrspeech_model_zh["text_tokenizer"], ssrspeech_model_zh["audio_tokenizer"], audio_path, orig_transcript, target_transcript, mask_interval, cfg_coef, cfg_stride, aug_text, False, True, True, device, decode_config ) audio_tensors = [] # save segments for comparison new_audio = new_audio[0].cpu() torchaudio.save(audio_path, new_audio, codec_audio_sr) [new_transcript, new_segments, _] = transcribe_zh(audio_path) transcribe_state = align_zh(traditional_to_simplified(new_segments), audio_path) transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments']) tmp1 = transcribe_state['segments'][0]['words'][0]['word'] tmp2 = target_transcript_copy if tmp1 == tmp2: offset = transcribe_state['segments'][0]['words'][0]['start'] else: offset = transcribe_state['segments'][0]['words'][1]['start'] new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr)) audio_tensors.append(new_audio) output_audio = get_output_audio(audio_tensors, codec_audio_sr) success_message = "Success: Inference successfully!" return output_audio, success_message if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Ssrspeech gradio app.") parser.add_argument("--demo-path", default="./demo", help="Path to demo directory") parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory") parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory") parser.add_argument("--port", default=7860, type=int, help="App port") parser.add_argument("--share", action="store_true", help="Launch with public url") os.environ["USER"] = os.getenv("USER", "user") args = parser.parse_args() DEMO_PATH = args.demo_path TMP_PATH = args.tmp_path MODELS_PATH = args.models_path # app = get_app() # app.queue().launch(share=args.share, server_port=args.port) # CSS styling (optional) css = """ #col-container { margin: 0 auto; max-width: 1280px; } """ # Gradio Blocks layout with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: with gr.Column(elem_id="col-container"): gr.Markdown(""" # SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer Generate and edit speech from text. Adjust advanced settings for more control. Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/). """) # Tabs for Generate and Edit with gr.Tab("English Speech Editing"): with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True) with gr.Group(): original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug", info="Use whisperx model to get the transcript.") transcribe_btn = gr.Button(value="Transcribe") with gr.Column(scale=3): with gr.Group(): transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True) run_btn = gr.Button(value="Run") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio") with gr.Row(): with gr.Accordion("Advanced Settings", open=False): seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)") aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1, info="set to 1 to use classifer-free guidance, change if you don't like the results") cfg_coef = gr.Number(label="cfg_coef", value=1.5, info="cfg guidance scale, 1.5 is a good value, change if you don't like the results") cfg_stride = gr.Number(label="cfg_stride", value=5, info="cfg stride, 5 is a good value for English, change if you don't like the results") prompt_length = gr.Number(label="prompt_length", value=3, info="used for tts prompt, will automatically cut the prompt audio to this length") sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results") success_output = gr.HTML() semgents = gr.State() # not used state = gr.State() # not used transcribe_btn.click(fn=transcribe_en, inputs=[input_audio], outputs=[original_transcript, semgents, state, success_output]) run_btn.click(fn=run_edit_en, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output]) transcript.submit(fn=run_edit_en, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output] ) with gr.Tab("English TTS"): with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True) with gr.Group(): original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug", info="Use whisperx model to get the transcript.") transcribe_btn = gr.Button(value="Transcribe") with gr.Column(scale=3): with gr.Group(): transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True) run_btn = gr.Button(value="Run") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio") with gr.Row(): with gr.Accordion("Advanced Settings", open=False): seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)") aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1, info="set to 1 to use classifer-free guidance, change if you don't like the results") cfg_coef = gr.Number(label="cfg_coef", value=1.5, info="cfg guidance scale, 1.5 is a good value, change if you don't like the results") cfg_stride = gr.Number(label="cfg_stride", value=5, info="cfg stride, 5 is a good value for English, change if you don't like the results") prompt_length = gr.Number(label="prompt_length", value=3, info="used for tts prompt, will automatically cut the prompt audio to this length") sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results") success_output = gr.HTML() semgents = gr.State() # not used state = gr.State() # not used transcribe_btn.click(fn=transcribe_en, inputs=[input_audio], outputs=[original_transcript, semgents, state, success_output]) run_btn.click(fn=run_tts_en, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output]) transcript.submit(fn=run_tts_en, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output] ) with gr.Tab("Mandarin Speech Editing"): with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True) with gr.Group(): original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug", info="Use whisperx model to get the transcript.") transcribe_btn = gr.Button(value="Transcribe") with gr.Column(scale=3): with gr.Group(): transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True) run_btn = gr.Button(value="Run") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio") with gr.Row(): with gr.Accordion("Advanced Settings", open=False): seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)") aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1, info="set to 1 to use classifer-free guidance, change if you don't like the results") cfg_coef = gr.Number(label="cfg_coef", value=1.5, info="cfg guidance scale, 1.5 is a good value, change if you don't like the results") cfg_stride = gr.Number(label="cfg_stride", value=1, info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results") prompt_length = gr.Number(label="prompt_length", value=3, info="used for tts prompt, will automatically cut the prompt audio to this length") sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results") success_output = gr.HTML() semgents = gr.State() # not used state = gr.State() # not used transcribe_btn.click(fn=transcribe_zh, inputs=[input_audio], outputs=[original_transcript, semgents, state, success_output]) run_btn.click(fn=run_edit_zh, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output]) transcript.submit(fn=run_edit_zh, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output] ) with gr.Tab("Mandarin TTS"): with gr.Row(): with gr.Column(scale=2): input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True) with gr.Group(): original_transcript = gr.Textbox(label="Original transcript", lines=5, value="Debug", info="Use whisperx model to get the transcript.") transcribe_btn = gr.Button(value="Transcribe") with gr.Column(scale=3): with gr.Group(): transcript = gr.Textbox(label="Text", lines=7, value="Debug", interactive=True) run_btn = gr.Button(value="Run") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio") with gr.Row(): with gr.Accordion("Advanced Settings", open=False): seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)") aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1, info="set to 1 to use classifer-free guidance, change if you don't like the results") cfg_coef = gr.Number(label="cfg_coef", value=1.5, info="cfg guidance scale, 1.5 is a good value, change if you don't like the results") cfg_stride = gr.Number(label="cfg_stride", value=1, info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results") prompt_length = gr.Number(label="prompt_length", value=3, info="used for tts prompt, will automatically cut the prompt audio to this length") sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results") success_output = gr.HTML() semgents = gr.State() # not used state = gr.State() # not used transcribe_btn.click(fn=transcribe_zh, inputs=[input_audio], outputs=[original_transcript, semgents, state, success_output]) run_btn.click(fn=run_tts_zh, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output]) transcript.submit(fn=run_tts_zh, inputs=[ seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, input_audio, original_transcript, transcript, ], outputs=[output_audio, success_output] ) # Launch the Gradio demo demo.launch()