import os
# Run the repository's setup script before the remaining imports execute.
# NOTE(review): setup.sh is not visible here (presumably it installs or prepares
# dependencies). os.system's exit status is discarded, so a failing script does
# not stop start-up — confirm that this is intended.
os.system("bash setup.sh")
import requests
import re
from num2words import num2words
import gradio as gr
import torch
import torchaudio
from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from edit_utils_en import parse_edit_en
from edit_utils_en import parse_tts_en
from edit_utils_zh import parse_edit_zh
from edit_utils_zh import parse_tts_zh
from inference_scale import inference_one_sample
import librosa
import soundfile as sf
from models import ssr
import io
import numpy as np
import random
import uuid
import opencc
import spaces
import nltk
nltk.download('punkt')
DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
os.makedirs(MODELS_PATH, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"


def _download_if_missing(filename, url, found_message):
    """Download ``url`` to ``filename`` unless ``filename`` already exists.

    The response body is streamed into ``<filename>.part``. That file is renamed
    into place only after the whole body has been written, so an interrupted
    download cannot leave a truncated checkpoint that the existence check would
    accept on every later start.

    Args:
        filename: Destination path.
        url: Source URL.
        found_message: Printed instead of downloading when ``filename`` exists.

    Raises:
        requests.HTTPError: On a non-2xx response.
        requests.RequestException: On connection failures or a stalled transfer.
    """
    if os.path.exists(filename):
        print(found_message)
        return
    partial_path = filename + ".part"
    # `timeout` bounds the connect step and each stalled read. The `with` block
    # closes the streamed connection even if writing to disk fails.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
    os.replace(partial_path, filename)
    print(f"File downloaded to: {filename}")


_download_if_missing(
    os.path.join(MODELS_PATH, "wmencodec.th"),
    "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th",
    "wmencodec model found",
)
_download_if_missing(
    os.path.join(MODELS_PATH, "English.pth"),
    "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth",
    "english model found",
)
_download_if_missing(
    os.path.join(MODELS_PATH, "Mandarin.pth"),
    "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth",
    "mandarin model found",
)
def get_random_string():
    """Return a fresh random identifier: a UUID4 as 32 lowercase hex digits, no dashes."""
    return uuid.uuid4().hex
@spaces.GPU
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    ``seed == -1`` means "do not seed": the function returns immediately and
    leaves every RNG and cuDNN flag untouched.

    NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    affects child processes, not the hash randomisation of this process.
    """
    if seed == -1:
        return
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def get_mask_interval(transcribe_state, word_span):
    """Map a word-index span of an aligned transcript to a time interval in seconds.

    Args:
        transcribe_state: Dict whose ``'segments'`` list holds aligned segments.
            Every entry of ``segment['words']`` must carry ``'start'``, ``'end'``
            and ``'word'``.
        word_span: ``(s, e)`` indices into the flattened word list, with
            ``0 <= s <= e <= n_words``.

    Returns:
        ``(start, end)`` as floats:

        * ``e == 0``: from 0 to the first word's start;
        * ``s == n_words``: a zero-length interval at the last word's end;
        * ``s == e``: the gap between word ``s-1`` and word ``s``;
        * otherwise: from the end of word ``s-1`` (or the start of word 0 when
          ``s == 0``) to the start of word ``e`` (or the last word's end when
          ``e == n_words``).

    Raises:
        ValueError: If the span is out of order or out of range, or the
            transcript contains no words.
    """
    print(transcribe_state)
    data = [
        (item['start'], item['end'], item['word'])
        for segment in transcribe_state['segments']
        for item in segment['words']
    ]
    n_words = len(data)
    s, e = word_span[0], word_span[1]
    # Explicit checks instead of asserts: asserts are stripped under `python -O`,
    # and a bad span would then silently index with negative/wrapped positions.
    if not 0 <= s <= e <= n_words:
        raise ValueError(f"invalid word span s:{s}, e:{e} for {n_words} words")
    if n_words == 0:
        raise ValueError("transcript has no aligned words")

    if e == 0:  # start
        return (0., float(data[0][0]))
    if s == n_words:  # end; the end time is not known yet
        last_end = float(data[-1][1])
        return (last_end, last_end)
    if s == e:  # insert
        return (float(data[s - 1][1]), float(data[s][0]))
    start = float(data[s - 1][1]) if s > 0 else float(data[s][0])
    end = float(data[e][0]) if e < n_words else float(data[-1][1])
    return (start, end)
def traditional_to_simplified(segments):
    """Convert segment and word texts from Traditional to Simplified Chinese, in place.

    Each ``segment['words'][j]['word']`` and each ``segment['text']`` is rewritten
    with OpenCC's ``t2s`` profile. The same (mutated) list is returned so calls
    can be chained.
    """
    to_simplified = opencc.OpenCC('t2s').convert
    for segment in segments:
        for word in segment['words']:
            word['word'] = to_simplified(word['word'])
        segment['text'] = to_simplified(segment['text'])
    return segments
from whisperx import load_align_model, load_model, load_audio
from whisperx import align as align_func
# Load models
text_tokenizer_en = TextTokenizer(backend="espeak")
text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn-latn-pinyin')

# Checkpoints are deserialised onto the CPU (map_location="cpu") and only then
# moved to `device`. Without this, a checkpoint saved with CUDA tensors cannot be
# loaded when device == "cpu". On GPU it would also make the weights travel
# GPU -> CPU -> GPU (the model is presumably constructed on the CPU).
# NOTE(review): the checkpoints also hold non-tensor entries ("config", "phn2num").
# PyTorch releases whose torch.load defaults to weights_only=True may refuse
# them — confirm against the pinned torch version.
ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
ckpt_en = torch.load(ssrspeech_fn_en, map_location="cpu")
model_en = ssr.SSR_Speech(ckpt_en["config"])
model_en.load_state_dict(ckpt_en["model"])
config_en = model_en.args
phn2num_en = ckpt_en["phn2num"]
model_en.to(device)

ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
ckpt_zh = torch.load(ssrspeech_fn_zh, map_location="cpu")
model_zh = ssr.SSR_Speech(ckpt_zh["config"])
model_zh.load_state_dict(ckpt_zh["model"])
config_zh = model_zh.args
phn2num_zh = ckpt_zh["phn2num"]
model_zh.to(device)

encodec_fn = f"{MODELS_PATH}/wmencodec.th"
# Everything inference_one_sample needs for one language, bundled per model.
ssrspeech_model_en = {
    "config": config_en,
    "phn2num": phn2num_en,
    "model": model_en,
    "text_tokenizer": text_tokenizer_en,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
ssrspeech_model_zh = {
    "config": config_zh,
    "phn2num": phn2num_zh,
    "model": model_zh,
    "text_tokenizer": text_tokenizer_zh,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
def get_transcribe_state(segments):
    """Bundle aligned segments with their joined transcript.

    Segment texts are joined with single spaces, and one leading space (if any)
    is dropped from the result.

    Args:
        segments: List of dicts, each with a ``"text"`` key.

    Returns:
        ``{"segments": segments, "transcript": str}``. The transcript is ``""``
        when there are no segments (e.g. silent audio).
    """
    transcript = " ".join(segment["text"] for segment in segments)
    # startswith() instead of transcript[0]: indexing raised IndexError on "".
    if transcript.startswith(" "):
        transcript = transcript[1:]
    return {
        "segments": segments,
        "transcript": transcript,
    }
@spaces.GPU
def transcribe_en(audio_path):
    """Transcribe English speech with WhisperX ``medium.en`` and word-align it.

    Digits in every recognised segment are spelled out with
    replace_numbers_with_words before the segments are aligned by align_en.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        ``[transcript, segments, state, success_message]``, where ``state`` is
        ``get_transcribe_state(segments)`` for the aligned segments.
    """
    language = "en"
    transcribe_model = load_model(
        "medium.en", device,
        asr_options={"suppress_numerals": True, "max_new_tokens": None,
                     "clip_timestamps": None, "hallucination_silence_threshold": None},
        language=language,
    )
    segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
    # Release the ASR model before alignment loads another model, and so that
    # empty_cache() below can actually return its memory.
    del transcribe_model
    for segment in segments:
        segment['text'] = replace_numbers_with_words(segment['text'])
    _, segments = align_en(segments, audio_path)
    state = get_transcribe_state(segments)
    torch.cuda.empty_cache()
    return [
        state["transcript"], state['segments'],
        state, "Success: Transcribe completed successfully!"
    ]
@spaces.GPU
def transcribe_zh(audio_path):
    """Transcribe Mandarin speech with WhisperX ``medium`` and word-align it.

    Only the joined transcript is converted to Simplified Chinese here. The
    returned segments keep whatever script the recogniser produced, and callers
    in this file run traditional_to_simplified on them.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        ``[transcript, segments, state, success_message]``, where ``state`` is
        ``get_transcribe_state(segments)`` with its transcript converted t2s.
    """
    language = "zh"
    transcribe_model = load_model(
        "medium", device,
        asr_options={"suppress_numerals": True, "max_new_tokens": None,
                     "clip_timestamps": None, "hallucination_silence_threshold": None},
        language=language,
    )
    segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
    # Release the ASR model before alignment loads another model, and so that
    # empty_cache() below can actually return its memory.
    del transcribe_model
    _, segments = align_zh(segments, audio_path)
    state = get_transcribe_state(segments)
    state["transcript"] = opencc.OpenCC('t2s').convert(state["transcript"])
    torch.cuda.empty_cache()
    return [
        state["transcript"], state['segments'],
        state, "Success: Transcribe completed successfully!"
    ]
@spaces.GPU
def align_en(segments, audio_path):
    """Word-align ``segments`` against ``audio_path`` with the English WhisperX aligner.

    Returns:
        ``(state, aligned_segments)``, where ``state`` is
        ``get_transcribe_state(aligned_segments)``.
    """
    align_model, metadata = load_align_model(language_code="en", device=device)
    audio = load_audio(audio_path)
    segments = align_func(segments, align_model, metadata, audio, device,
                          return_char_alignments=False)["segments"]
    # Drop the aligner before empty_cache(); while it is still referenced its
    # memory cannot be released.
    del align_model, metadata, audio
    state = get_transcribe_state(segments)
    torch.cuda.empty_cache()
    return state, segments
@spaces.GPU
def align_zh(segments, audio_path):
    """Word-align ``segments`` against ``audio_path`` with the Mandarin WhisperX aligner.

    Returns:
        ``(state, aligned_segments)``, where ``state`` is
        ``get_transcribe_state(aligned_segments)``.
    """
    align_model, metadata = load_align_model(language_code="zh", device=device)
    audio = load_audio(audio_path)
    segments = align_func(segments, align_model, metadata, audio, device,
                          return_char_alignments=False)["segments"]
    # Drop the aligner before empty_cache(); while it is still referenced its
    # memory cannot be released.
    del align_model, metadata, audio
    state = get_transcribe_state(segments)
    torch.cuda.empty_cache()
    return state, segments
def get_output_audio(audio_tensors, codec_audio_sr):
    """Concatenate waveform tensors along dim 1 and return them as WAV file bytes.

    Assumes each tensor is laid out as (channels, samples) — TODO confirm; dim 1
    is the axis joined here. ``codec_audio_sr`` is the sample rate written into
    the WAV header.
    """
    waveform = torch.cat(audio_tensors, 1)
    with io.BytesIO() as wav_buffer:
        torchaudio.save(wav_buffer, waveform, int(codec_audio_sr), format="wav")
        return wav_buffer.getvalue()
def replace_numbers_with_words(sentence):
    """Spell out every run of digits in ``sentence`` using ``num2words``.

    Each digit run is first padded with a space on both sides, so a number glued
    to letters (``"abc123"``) becomes its own word. The padding can leave doubled
    spaces in the result.

    A number that ``num2words`` cannot convert is left as digits.
    """
    sentence = re.sub(r'(\d+)', r' \1 ', sentence)  # add spaces around numbers

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)  # Convert numbers to words
        except (ArithmeticError, ValueError, NotImplementedError):
            # Best effort: num2words rejects some inputs (presumably values
            # outside its supported range), so keep the digits. The original
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            return num

    return re.sub(r'\b\d+\b', replace_with_words, sentence)  # Regular expression that matches numbers
@spaces.GPU
def run_edit_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
                audio_path, transcript):
    """Edit English speech in ``audio_path`` so that it says ``transcript``.

    Steps:

    1. Resample the input to 16 kHz, overwriting ``audio_path``.
    2. Transcribe and word-align it with WhisperX.
    3. Diff the recognised text against the target with ``parse_edit_en``.
    4. Turn the differing word spans into codec-frame mask intervals and
       regenerate them with the English SSR-Speech model.

    Args:
        seed: RNG seed forwarded to seed_everything.
        sub_amount: Seconds of margin added on each side of every edited span.
        aug_text: ``1`` enables the model's ``aug_text`` option (labelled
            classifier-free guidance in the UI); any other value disables it.
        cfg_coef: Guidance scale forwarded to ``inference_one_sample``.
        cfg_stride: Guidance stride forwarded to ``inference_one_sample``.
        prompt_length: Unused for editing; kept so all run_* callbacks share one
            signature.
        audio_path: Path of the audio to edit.
        transcript: Desired (target) transcript.

    Returns:
        ``(wav_bytes, success_message)``.

    Raises:
        gr.Error: If no edit span is found, or more than 3 are found.
    """
    codec_audio_sr = 16000
    codec_sr = 50  # converts seconds to codec-frame indices below
    decode_config = {'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2,
                     'kvcache': 1, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
    aug_text = aug_text == 1
    seed_everything(seed)

    # Resample the input to the codec rate, in place.
    audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
    sf.write(audio_path, audio, codec_audio_sr)

    # Text normalization: spell out digits, then collapse every whitespace run
    # (newlines and the padding spaces added around numbers) into one space.
    target_transcript = " ".join(replace_numbers_with_words(transcript).split()).lower()

    orig_transcript, segments, _, _ = transcribe_en(audio_path)
    orig_transcript = orig_transcript.lower()
    transcribe_state, _ = align_en(segments, audio_path)
    print(orig_transcript)
    print(target_transcript)

    operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
    print(operations)
    print("orig_spans: ", orig_spans)
    if len(orig_spans) == 0:
        # Nothing to regenerate; combine_spans below would fail on an empty list.
        raise gr.Error("No edit found between the original and the target transcript")
    if len(orig_spans) > 3:
        raise gr.Error("Current model only supports maximum 3 editings")

    intervals = [get_mask_interval(transcribe_state, span) for span in orig_spans]
    print("intervals: ", [start for start, _ in intervals], [end for _, end in intervals])

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    def combine_spans(spans, threshold=0.2):
        """Sort ``spans`` in place and merge neighbours less than ``threshold`` s apart."""
        spans.sort(key=lambda x: x[0])
        combined_spans = [spans[0]]
        for span in spans[1:]:
            if combined_spans[-1][1] >= span[0] - threshold:
                combined_spans[-1][1] = max(combined_spans[-1][1], span[1])
            else:
                combined_spans.append(span)
        return combined_spans

    # Widen each span by sub_amount (clamped to the clip), then merge overlaps.
    morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                    for start, end in intervals]  # in seconds
    morphed_span = combine_spans(morphed_span, threshold=0.2)
    print("morphed_spans: ", morphed_span)
    mask_interval = torch.LongTensor(
        [[round(start * codec_sr), round(end * codec_sr)] for start, end in morphed_span]
    )  # [M, 2]

    new_audio = inference_one_sample(
        ssrspeech_model_en["model"],
        ssrspeech_model_en["config"],
        ssrspeech_model_en["phn2num"],
        ssrspeech_model_en["text_tokenizer"],
        ssrspeech_model_en["audio_tokenizer"],
        audio_path, orig_transcript, target_transcript, mask_interval,
        cfg_coef, cfg_stride, aug_text, False, True, False,
        device, decode_config
    )
    output_audio = get_output_audio([new_audio[0].cpu()], codec_audio_sr)
    torch.cuda.empty_cache()
    return output_audio, "Success: Inference successfully!"
@spaces.GPU
def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
               audio_path, transcript):
    """Synthesise English ``transcript`` in the voice of the prompt ``audio_path``.

    Steps:

    1. Resample the prompt to 16 kHz, overwriting ``audio_path``.
    2. If needed, cut it at a word boundary near ``prompt_length`` seconds into a
       separate ``*_tmp`` file.
    3. Run the model with "prompt transcript + new text" and a zero-length mask
       at the end of the prompt.
    4. Write the generation to its own temporary file, re-transcribe it, and
       return the audio from the detected start of the new text.

    Args:
        seed: RNG seed forwarded to seed_everything.
        sub_amount: Unused for TTS; kept for a uniform callback signature.
        aug_text: ``1`` enables the model's ``aug_text`` option.
        cfg_coef: Guidance scale forwarded to ``inference_one_sample``.
        cfg_stride: Guidance stride forwarded to ``inference_one_sample``.
        prompt_length: Desired prompt length in seconds. A longer prompt is cut
            at the end of the first word that ends at or after this time.
        audio_path: Path of the voice prompt.
        transcript: Text to synthesise.

    Returns:
        ``(wav_bytes, success_message)``.

    Raises:
        gr.Error: If ``transcript`` is empty after normalization.
    """
    codec_audio_sr = 16000
    codec_sr = 50  # converts seconds to codec-frame indices below
    decode_config = {'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2,
                     'kvcache': 1, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
    aug_text = aug_text == 1
    seed_everything(seed)

    # Resample the prompt to the codec rate, in place.
    audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
    sf.write(audio_path, audio, codec_audio_sr)

    # Text normalization: spell out digits, then collapse every whitespace run
    # (newlines and the padding spaces added around numbers) into one space.
    target_transcript = " ".join(replace_numbers_with_words(transcript).split()).lower()
    if not target_transcript:
        raise gr.Error("Please enter the text to synthesize")

    orig_transcript, segments, _, _ = transcribe_en(audio_path)
    orig_transcript = orig_transcript.lower()
    transcribe_state, _ = align_en(segments, audio_path)
    print(orig_transcript)
    print(target_transcript)

    # Cut long prompts at the end of the first word ending at/after prompt_length.
    info = torchaudio.info(audio_path)
    duration = info.num_frames / info.sample_rate
    cut_length = duration
    if duration > prompt_length:
        word_ends = [
            word['end']
            for segment in transcribe_state['segments']
            for word in segment['words']
            if 'end' in word and word['end'] >= prompt_length
        ]
        cut_length = min([duration, *word_ends])
    if cut_length < duration:
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr, duration=cut_length)
        # splitext, not str.replace('.', ...): only the extension may change,
        # never dots in directory names.
        root, ext = os.path.splitext(audio_path)
        audio_path = f"{root}_tmp{ext}"
        sf.write(audio_path, audio, codec_audio_sr)
        # Only the transcript of the cut prompt is needed below.
        orig_transcript, _, _, _ = transcribe_en(audio_path)
        orig_transcript = orig_transcript.lower()
        print(orig_transcript)

    first_target_word = target_transcript.split(' ')[0]
    target_transcript = orig_transcript + ' ' + target_transcript
    print(target_transcript)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate
    end_frame = round(audio_dur * codec_sr)
    # Zero-length mask interval placed at the very end of the prompt.
    mask_interval = torch.LongTensor([[end_frame, end_frame]])  # [1, 2]
    print("mask_interval: ", mask_interval)

    new_audio = inference_one_sample(
        ssrspeech_model_en["model"],
        ssrspeech_model_en["config"],
        ssrspeech_model_en["phn2num"],
        ssrspeech_model_en["text_tokenizer"],
        ssrspeech_model_en["audio_tokenizer"],
        audio_path, orig_transcript, target_transcript, mask_interval,
        cfg_coef, cfg_stride, aug_text, False, True, True,
        device, decode_config
    )
    new_audio = new_audio[0].cpu()

    # Write the generation to its own uniquely named file. Saving it over
    # audio_path would clobber the user's prompt whenever no cut was made.
    gen_path = f"{os.path.splitext(audio_path)[0]}_gen_{get_random_string()}.wav"
    torchaudio.save(gen_path, new_audio, codec_audio_sr)
    _, new_segments, _, _ = transcribe_en(gen_path)
    transcribe_state, _ = align_en(new_segments, gen_path)

    # Choose where the returned audio starts: at the first recognised word if it
    # equals the first target word, otherwise at the second word. Fall back to the
    # first word, or to 0 when nothing was recognised, instead of IndexError.
    words = [word for segment in transcribe_state['segments']
             for word in segment['words'] if 'start' in word]
    if not words:
        offset = 0.0
    elif words[0]['word'].lower() == first_target_word or len(words) == 1:
        offset = words[0]['start']
    else:
        offset = words[1]['start']
    new_audio, _ = torchaudio.load(gen_path, frame_offset=int(offset * codec_audio_sr))
    os.remove(gen_path)

    output_audio = get_output_audio([new_audio], codec_audio_sr)
    torch.cuda.empty_cache()
    return output_audio, "Success: Inference successfully!"
@spaces.GPU
def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
                audio_path, transcript):
    """Edit Mandarin speech in ``audio_path`` so that it says ``transcript``.

    Same pipeline as run_edit_en, with two differences:

    * the Mandarin transcriber, aligner and model are used;
    * the recognised transcript and segments are converted to Simplified Chinese
      with OpenCC before diffing with ``parse_edit_zh``.

    Args:
        seed: RNG seed forwarded to seed_everything.
        sub_amount: Seconds of margin added on each side of every edited span.
        aug_text: ``1`` enables the model's ``aug_text`` option.
        cfg_coef: Guidance scale forwarded to ``inference_one_sample``.
        cfg_stride: Guidance stride forwarded to ``inference_one_sample``.
        prompt_length: Unused for editing; kept for a uniform callback signature.
        audio_path: Path of the audio to edit.
        transcript: Desired (target) transcript.

    Returns:
        ``(wav_bytes, success_message)``.

    Raises:
        gr.Error: If no edit span is found, or more than 3 are found.
    """
    codec_audio_sr = 16000
    codec_sr = 50  # converts seconds to codec-frame indices below
    decode_config = {'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2,
                     'kvcache': 1, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
    aug_text = aug_text == 1
    seed_everything(seed)

    # Resample the input to the codec rate, in place.
    audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
    sf.write(audio_path, audio, codec_audio_sr)

    # Text normalization: collapse every whitespace run (incl. newlines) to one space.
    target_transcript = " ".join(transcript.split())

    orig_transcript, segments, _, _ = transcribe_zh(audio_path)
    converter = opencc.OpenCC('t2s')
    orig_transcript = converter.convert(orig_transcript)
    transcribe_state, _ = align_zh(traditional_to_simplified(segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    print(orig_transcript)
    print(target_transcript)

    operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
    print(operations)
    print("orig_spans: ", orig_spans)
    if len(orig_spans) == 0:
        # Nothing to regenerate; combine_spans below would fail on an empty list.
        raise gr.Error("No edit found between the original and the target transcript")
    if len(orig_spans) > 3:
        raise gr.Error("Current model only supports maximum 3 editings")

    intervals = [get_mask_interval(transcribe_state, span) for span in orig_spans]
    print("intervals: ", [start for start, _ in intervals], [end for _, end in intervals])

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    def combine_spans(spans, threshold=0.2):
        """Sort ``spans`` in place and merge neighbours less than ``threshold`` s apart."""
        spans.sort(key=lambda x: x[0])
        combined_spans = [spans[0]]
        for span in spans[1:]:
            if combined_spans[-1][1] >= span[0] - threshold:
                combined_spans[-1][1] = max(combined_spans[-1][1], span[1])
            else:
                combined_spans.append(span)
        return combined_spans

    # Widen each span by sub_amount (clamped to the clip), then merge overlaps.
    morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                    for start, end in intervals]  # in seconds
    morphed_span = combine_spans(morphed_span, threshold=0.2)
    print("morphed_spans: ", morphed_span)
    mask_interval = torch.LongTensor(
        [[round(start * codec_sr), round(end * codec_sr)] for start, end in morphed_span]
    )  # [M, 2]

    new_audio = inference_one_sample(
        ssrspeech_model_zh["model"],
        ssrspeech_model_zh["config"],
        ssrspeech_model_zh["phn2num"],
        ssrspeech_model_zh["text_tokenizer"],
        ssrspeech_model_zh["audio_tokenizer"],
        audio_path, orig_transcript, target_transcript, mask_interval,
        cfg_coef, cfg_stride, aug_text, False, True, False,
        device, decode_config
    )
    output_audio = get_output_audio([new_audio[0].cpu()], codec_audio_sr)
    torch.cuda.empty_cache()
    return output_audio, "Success: Inference successfully!"
@spaces.GPU
def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
               audio_path, transcript):
    """Synthesise Mandarin ``transcript`` in the voice of the prompt ``audio_path``.

    Same pipeline as run_tts_en, with these differences:

    * the Mandarin transcriber, aligner and model are used;
    * transcripts and segments are converted to Simplified Chinese;
    * the prompt text and new text are concatenated without a separator;
    * the start of the new speech is detected by comparing against the first
      character of the target.

    Args:
        seed: RNG seed forwarded to seed_everything.
        sub_amount: Unused for TTS; kept for a uniform callback signature.
        aug_text: ``1`` enables the model's ``aug_text`` option.
        cfg_coef: Guidance scale forwarded to ``inference_one_sample``.
        cfg_stride: Guidance stride forwarded to ``inference_one_sample``.
        prompt_length: Desired prompt length in seconds. A longer prompt is cut
            at the end of the first word that ends at or after this time.
        audio_path: Path of the voice prompt.
        transcript: Text to synthesise.

    Returns:
        ``(wav_bytes, success_message)``.

    Raises:
        gr.Error: If ``transcript`` is empty after normalization (the original
            code raised IndexError on ``target_transcript[0]``).
    """
    codec_audio_sr = 16000
    codec_sr = 50  # converts seconds to codec-frame indices below
    decode_config = {'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2,
                     'kvcache': 1, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
    aug_text = aug_text == 1
    seed_everything(seed)

    # Resample the prompt to the codec rate, in place.
    audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
    sf.write(audio_path, audio, codec_audio_sr)

    # Text normalization: collapse every whitespace run (incl. newlines) to one space.
    target_transcript = " ".join(transcript.split())
    if not target_transcript:
        raise gr.Error("Please enter the text to synthesize")

    converter = opencc.OpenCC('t2s')
    orig_transcript, segments, _, _ = transcribe_zh(audio_path)
    orig_transcript = converter.convert(orig_transcript)
    transcribe_state, _ = align_zh(traditional_to_simplified(segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    print(orig_transcript)
    print(target_transcript)

    # Cut long prompts at the end of the first word ending at/after prompt_length.
    info = torchaudio.info(audio_path)
    duration = info.num_frames / info.sample_rate
    cut_length = duration
    if duration > prompt_length:
        word_ends = [
            word['end']
            for segment in transcribe_state['segments']
            for word in segment['words']
            if 'end' in word and word['end'] >= prompt_length
        ]
        cut_length = min([duration, *word_ends])
    if cut_length < duration:
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr, duration=cut_length)
        # splitext, not str.replace('.', ...): only the extension may change,
        # never dots in directory names.
        root, ext = os.path.splitext(audio_path)
        audio_path = f"{root}_tmp{ext}"
        sf.write(audio_path, audio, codec_audio_sr)
        # Only the transcript of the cut prompt is needed below.
        orig_transcript, _, _, _ = transcribe_zh(audio_path)
        orig_transcript = converter.convert(orig_transcript)
        print(orig_transcript)

    first_target_char = target_transcript[0]
    target_transcript = orig_transcript + target_transcript
    print(target_transcript)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate
    end_frame = round(audio_dur * codec_sr)
    # Zero-length mask interval placed at the very end of the prompt.
    mask_interval = torch.LongTensor([[end_frame, end_frame]])  # [1, 2]
    print("mask_interval: ", mask_interval)

    new_audio = inference_one_sample(
        ssrspeech_model_zh["model"],
        ssrspeech_model_zh["config"],
        ssrspeech_model_zh["phn2num"],
        ssrspeech_model_zh["text_tokenizer"],
        ssrspeech_model_zh["audio_tokenizer"],
        audio_path, orig_transcript, target_transcript, mask_interval,
        cfg_coef, cfg_stride, aug_text, False, True, True,
        device, decode_config
    )
    new_audio = new_audio[0].cpu()

    # Write the generation to its own uniquely named file. Saving it over
    # audio_path would clobber the user's prompt whenever no cut was made.
    gen_path = f"{os.path.splitext(audio_path)[0]}_gen_{get_random_string()}.wav"
    torchaudio.save(gen_path, new_audio, codec_audio_sr)
    _, new_segments, _, _ = transcribe_zh(gen_path)
    transcribe_state, _ = align_zh(traditional_to_simplified(new_segments), gen_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])

    # Choose where the returned audio starts: at the first recognised word if it
    # equals the first target character, otherwise at the second word. Fall back
    # to the first word, or to 0 when nothing was recognised, instead of IndexError.
    words = [word for segment in transcribe_state['segments']
             for word in segment['words'] if 'start' in word]
    if not words:
        offset = 0.0
    elif words[0]['word'] == first_target_char or len(words) == 1:
        offset = words[0]['start']
    else:
        offset = words[1]['start']
    new_audio, _ = torchaudio.load(gen_path, frame_offset=int(offset * codec_audio_sr))
    os.remove(gen_path)

    output_audio = get_output_audio([new_audio], codec_audio_sr)
    torch.cuda.empty_cache()
    return output_audio, "Success: Inference successfully!"
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Ssrspeech gradio app.")
    parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
    parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
    parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory")
    parser.add_argument("--port", default=7860, type=int, help="App port")
    parser.add_argument("--share", action="store_true", help="Launch with public url")
    os.environ["USER"] = os.getenv("USER", "user")
    args = parser.parse_args()
    # NOTE(review): the models were already downloaded and loaded at import time
    # from the MODELS_PATH env var, so --models-path cannot affect them here.
    # Nothing in this block reads TMP_PATH. Only DEMO_PATH is used below.
    DEMO_PATH = args.demo_path
    TMP_PATH = args.tmp_path
    MODELS_PATH = args.models_path

    # CSS styling (optional)
    css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
"""

    def build_tab(title, audio_file, original_text, target_text,
                  cfg_stride_value, cfg_stride_info, transcribe_fn, run_fn):
        """Build one demo tab: inputs, advanced settings and event wiring.

        All four tabs share this layout. They differ only in the demo audio,
        the example texts, the cfg_stride default and the callbacks.
        """
        with gr.Tab(title):
            with gr.Row():
                with gr.Column(scale=2):
                    input_audio = gr.Audio(value=f"{DEMO_PATH}/{audio_file}", label="Input Audio",
                                           type="filepath", interactive=True)
                    with gr.Group():
                        original_transcript = gr.Textbox(label="Original transcript", lines=5, value=original_text,
                                                         info="Use whisperx model to get the transcript.")
                        transcribe_btn = gr.Button(value="Transcribe")
                with gr.Column(scale=3):
                    with gr.Group():
                        transcript = gr.Textbox(label="Text", lines=7, value=target_text, interactive=True)
                        run_btn = gr.Button(value="Run")
                with gr.Column(scale=2):
                    output_audio = gr.Audio(label="Output Audio")
            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Number(label="seed", value=1234, precision=0, info="random seeds always works :)")
                    aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                        info="set to 1 to use classifer-free guidance, change if you don't like the results")
                    cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                         info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                    cfg_stride = gr.Number(label="cfg_stride", value=cfg_stride_value, info=cfg_stride_info)
                    prompt_length = gr.Number(label="prompt_length", value=3,
                                              info="used for tts prompt, will automatically cut the prompt audio to this length")
                    sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
            success_output = gr.HTML()

            transcribe_btn.click(fn=transcribe_fn,
                                 inputs=[input_audio],
                                 outputs=[original_transcript, gr.State(), gr.State(), success_output])
            run_inputs = [
                seed, sub_amount,
                aug_text, cfg_coef, cfg_stride, prompt_length,
                input_audio, transcript,
            ]
            run_outputs = [output_audio, success_output]
            run_btn.click(fn=run_fn, inputs=run_inputs, outputs=run_outputs)
            transcript.submit(fn=run_fn, inputs=run_inputs, outputs=run_outputs)

    english_audio = "84_121550_000074_000000.wav"
    english_original = "but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks."
    english_stride_info = "cfg stride, 5 is a good value for English, change if you don't like the results"
    mandarin_audio = "aishell3_test.wav"
    mandarin_original = "价格已基本都在三万到六万之间"
    mandarin_stride_info = "cfg stride, 1 is a good value for Mandarin, change if you don't like the results"

    # Gradio Blocks layout
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown("""
# SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer
Generate and edit speech from text. Adjust advanced settings for more control.
Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
""")
            build_tab("English Speech Editing", english_audio, english_original,
                      "but when I saw the mirage of the lake in the distance, which the sense deceives, lost not by distance any of its marks.",
                      5, english_stride_info, transcribe_en, run_edit_en)
            build_tab("English TTS", english_audio, english_original,
                      "I cannot believe that the same model can also do text to speech synthesis too!",
                      5, english_stride_info, transcribe_en, run_tts_en)
            build_tab("Mandarin Speech Editing", mandarin_audio, mandarin_original,
                      "价格已基本都在一万到两万之间",
                      1, mandarin_stride_info, transcribe_zh, run_edit_zh)
            build_tab("Mandarin TTS", mandarin_audio, mandarin_original,
                      "我简直不敢相信同一个模型也可以进行文本到语音的生成",
                      1, mandarin_stride_info, transcribe_zh, run_tts_zh)

    # Launch the Gradio demo, honouring the --share / --port flags parsed above
    # (they were previously ignored).
    demo.launch(share=args.share, server_port=args.port)