import io
import os
import random
import re
import uuid

import gradio as gr
import librosa
import nltk
import numpy as np
import opencc
import soundfile as sf
import spaces
import torch
import torchaudio
from num2words import num2words

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from edit_utils_en import parse_edit_en, parse_tts_en
from edit_utils_zh import parse_edit_zh, parse_tts_zh
from inference_scale import inference_one_sample
from models import ssr

nltk.download('punkt')

DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")

device = "cuda" if torch.cuda.is_available() else "cpu"
transcribe_model, align_model, ssrspeech_model = None, None, None

def get_random_string():
    return "".join(str(uuid.uuid4()).split("-"))

def traditional_to_simplified(segments):
    converter = opencc.OpenCC('t2s')
    seg_num = len(segments)
    for i in range(seg_num):
        words = segments[i]['words']
        for j in range(len(words)):
            segments[i]['words'][j]['word'] = converter.convert(segments[i]['words'][j]['word'])
        segments[i]['text'] = converter.convert(segments[i]['text'])
    return segments

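# Illustrative example with hypothetical segment data, using OpenCC's 't2s'
# (Traditional-to-Simplified) conversion on both word and segment text:
#   segments = [{'text': '語音編輯', 'words': [{'word': '語音'}, {'word': '編輯'}]}]
#   traditional_to_simplified(segments)[0]['text']  # -> '语音编辑'
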
def seed_everything(seed):
    if seed != -1:
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

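# Note: a seed of -1 (the UI default below) leaves all RNGs unseeded, so repeated
# runs sample different outputs; any other value, e.g. seed_everything(42), aims to
# make generation reproducible across Python, NumPy, and PyTorch.
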
def get_mask_interval(transcribe_state, word_span):
    print(transcribe_state)
    seg_num = len(transcribe_state['segments'])
    data = []
    for i in range(seg_num):
        words = transcribe_state['segments'][i]['words']
        for item in words:
            data.append([item['start'], item['end'], item['word']])

    s, e = word_span[0], word_span[1]
    assert s <= e, f"s:{s}, e:{e}"
    assert s >= 0, f"s:{s}"
    assert e <= len(data), f"e:{e}"
    if e == 0:  # start
        start = 0.
        end = float(data[0][0])
    elif s == len(data):  # end
        start = float(data[-1][1])
        end = float(data[-1][1])  # don't know the end yet
    elif s == e:  # insert
        start = float(data[s-1][1])
        end = float(data[s][0])
    else:
        start = float(data[s-1][1]) if s > 0 else float(data[s][0])
        end = float(data[e][0]) if e < len(data) else float(data[-1][1])
    return (start, end)

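# Illustrative example with hypothetical word timings: if the aligned words are
# [0.0-0.4 "hello", 0.5-0.9 "world"], then get_mask_interval(state, (1, 1)) treats
# the span as an insertion point between the two words and returns (0.4, 0.5),
# while (0, 1) covers the first word and returns (0.0, 0.5).
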
class WhisperxAlignModel:
    def __init__(self, language):
        from whisperx import load_align_model
        self.model, self.metadata = load_align_model(language_code=language, device=device)

    def align(self, segments, audio_path):
        from whisperx import align, load_audio
        audio = load_audio(audio_path)
        return align(segments, self.model, self.metadata, audio, device, return_char_alignments=False)["segments"]

class WhisperModel:
    def __init__(self, model_name, language):
        from whisper import load_model
        # whisper.load_model takes no language argument; language only affects the tokenizer below
        self.model = load_model(model_name, device)

        from whisper.tokenizer import get_tokenizer
        tokenizer = get_tokenizer(multilingual=False, language=language)
        # suppress tokens that decode to bare digits, so numbers are spelled out in text
        self.suppress_tokens = [-1] + [
            i
            for i in range(tokenizer.eot)
            if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" "))
        ]

    def transcribe(self, audio_path):
        return self.model.transcribe(audio_path, suppress_tokens=self.suppress_tokens, word_timestamps=True)["segments"]

class WhisperxModel:
    def __init__(self, model_name, align_model, language):
        from whisperx import load_model
        self.model = load_model(
            model_name, device,
            asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None},
            language=language,
        )
        self.align_model = align_model

    def transcribe(self, audio_path):
        segments = self.model.transcribe(audio_path, batch_size=8)["segments"]
        for segment in segments:
            segment['text'] = replace_numbers_with_words(segment['text'])
        return self.align_model.align(segments, audio_path)

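# Typical usage in this app (mirrors load_models below): build the aligner first,
# then the ASR model, so every transcription comes back with word-level timestamps:
#   align_model = WhisperxAlignModel("en")
#   transcribe_model = WhisperxModel("base.en", align_model, "en")
#   segments = transcribe_model.transcribe("demo/sample.wav")  # hypothetical path
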
def load_models(ssrspeech_model_name):
    global transcribe_model, align_model, ssrspeech_model

    alignment_model_name = "whisperX"
    whisper_backend_name = "whisperX"
    if ssrspeech_model_name == "English":
        text_tokenizer = TextTokenizer(backend="espeak")
        language = "en"
        transcribe_model_name = "base.en"
    elif ssrspeech_model_name == "Mandarin":
        text_tokenizer = TextTokenizer(backend="espeak", language='cmn')
        language = "zh"
        transcribe_model_name = "base"

    align_model = WhisperxAlignModel(language)
    transcribe_model = WhisperxModel(transcribe_model_name, align_model, language)

    ssrspeech_fn = f"{MODELS_PATH}/{ssrspeech_model_name}.pth"
    if not os.path.exists(ssrspeech_fn):
        os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-{ssrspeech_model_name}/resolve/main/{ssrspeech_model_name}.pth -O {ssrspeech_fn}")
    print(transcribe_model, align_model)

    ckpt = torch.load(ssrspeech_fn)
    model = ssr.SSR_Speech(ckpt["config"])
    model.load_state_dict(ckpt["model"])
    config = model.args
    phn2num = ckpt["phn2num"]
    model.to(device)

    encodec_fn = f"{MODELS_PATH}/wmencodec.th"
    if not os.path.exists(encodec_fn):
        os.system(f"wget https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th -O {encodec_fn}")

    ssrspeech_model = {
        "config": config,
        "phn2num": phn2num,
        "model": model,
        "text_tokenizer": text_tokenizer,
        "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
    }

    success_message = "<span style='color:green;'>Success: Models loaded successfully!</span>"
    return [
        gr.Accordion(),
        success_message
    ]

def get_transcribe_state(segments):
    transcript = " ".join([segment["text"] for segment in segments])
    transcript = transcript.lstrip(" ")
    return {
        "segments": segments,
        "transcript": transcript,
    }

def transcribe(audio_path):
    global transcribe_model
    if transcribe_model is None:
        raise gr.Error("Transcription model not loaded")
    segments = transcribe_model.transcribe(audio_path)
    state = get_transcribe_state(segments)
    success_message = "<span style='color:green;'>Success: Transcription completed successfully!</span>"

    # three return values, matching the [transcript, segments, _] unpacking in run()
    # and the three Gradio outputs wired up in get_app()
    return [
        state["transcript"], state["segments"],
        success_message
    ]

def align(segments, audio_path):
    global align_model
    if align_model is None:
        raise gr.Error("Align model not loaded")
    segments = align_model.align(segments, audio_path)
    state = get_transcribe_state(segments)
    return state

def get_output_audio(audio_tensors, codec_audio_sr):
    result = torch.cat(audio_tensors, 1)
    buffer = io.BytesIO()
    torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
    buffer.seek(0)
    return buffer.read()

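# The raw WAV bytes returned here are what the "Output Audio" component receives
# from run() below; concatenation along dim 1 assumes each tensor in audio_tensors
# is shaped [channels, samples] at the same sample rate.
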
def replace_numbers_with_words(sentence):
    sentence = re.sub(r'(\d+)', r' \1 ', sentence)  # add spaces around numbers

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(int(num))  # convert the digit string to words
        except Exception:
            return num  # in case num2words fails (unlikely with digits, but just to be safe)

    return re.sub(r'\b\d+\b', replace_with_words, sentence)  # substitute every standalone number

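# Example: replace_numbers_with_words("Chapter 3 has 12 pages")
#   -> "Chapter  three  has  twelve  pages"
# (numbers become words; the extra spaces added around them are collapsed later
# during text normalization in run()).
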
def run(seed, sub_amount, ssrspeech_model_choice, codec_audio_sr, codec_sr, top_k, top_p, temperature,
        stop_repetition, kvcache, silence_tokens, aug_text, cfg_coef, prompt_length,
        audio_path, original_transcript, transcript, mode):
    global transcribe_model, align_model, ssrspeech_model

    aug_text = True if aug_text == 1 else False
    if ssrspeech_model is None:
        raise gr.Error("ssrspeech model not loaded")

    seed_everything(seed)
    if ssrspeech_model_choice == "English":
        language = "en"
    elif ssrspeech_model_choice == "Mandarin":
        language = "zh"

    # resample audio
    audio, _ = librosa.load(audio_path, sr=16000)
    sf.write(audio_path, audio, 16000)

    # text normalization: spell out numbers, collapse double spaces, drop newlines
    target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")

    [orig_transcript, segments, _] = transcribe(audio_path)
    if language == 'zh':
        converter = opencc.OpenCC('t2s')
        orig_transcript = converter.convert(orig_transcript)
        transcribe_state = align(traditional_to_simplified(segments), audio_path)
        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    elif language == 'en':
        orig_transcript = orig_transcript.lower()
        target_transcript = target_transcript.lower()
        transcribe_state = align(segments, audio_path)
    print(orig_transcript)
    print(target_transcript)

    if mode == "TTS":
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        cut_length = duration
        # cut long audio down to a prompt of roughly prompt_length seconds
        if duration > prompt_length:
            seg_num = len(transcribe_state['segments'])
            for i in range(seg_num):
                words = transcribe_state['segments'][i]['words']
                for item in words:
                    if item['end'] >= prompt_length:
                        cut_length = min(item['end'], cut_length)

        audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
        sf.write(audio_path, audio, 16000)
        [orig_transcript, segments, _] = transcribe(audio_path)
        if language == 'zh':
            converter = opencc.OpenCC('t2s')
            orig_transcript = converter.convert(orig_transcript)
            transcribe_state = align(traditional_to_simplified(segments), audio_path)
            transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
        elif language == 'en':
            orig_transcript = orig_transcript.lower()
            target_transcript = target_transcript.lower()
            transcribe_state = align(segments, audio_path)
        print(orig_transcript)

        target_transcript_copy = target_transcript  # for the tts cut-out later
        if language == 'en':
            target_transcript_copy = target_transcript_copy.split(' ')[0]
        elif language == 'zh':
            target_transcript_copy = target_transcript_copy[0]
        target_transcript = orig_transcript + ' ' + target_transcript if language == 'en' else orig_transcript + target_transcript
        print(target_transcript)

    if mode == "Edit":
        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript) if language == 'en' else parse_edit_zh(orig_transcript, target_transcript)
        print(operations)
        print("orig_spans: ", orig_spans)

        if len(orig_spans) > 3:
            raise gr.Error("The current model supports at most 3 edits")

        starting_intervals = []
        ending_intervals = []
        for orig_span in orig_spans:
            start, end = get_mask_interval(transcribe_state, orig_span)
            starting_intervals.append(start)
            ending_intervals.append(end)

        print("intervals: ", starting_intervals, ending_intervals)

        info = torchaudio.info(audio_path)
        audio_dur = info.num_frames / info.sample_rate

        def combine_spans(spans, threshold=0.2):
            # merge spans that overlap or sit closer than `threshold` seconds apart
            spans.sort(key=lambda x: x[0])
            combined_spans = []
            current_span = spans[0]
            for i in range(1, len(spans)):
                next_span = spans[i]
                if current_span[1] >= next_span[0] - threshold:
                    current_span[1] = max(current_span[1], next_span[1])
                else:
                    combined_spans.append(current_span)
                    current_span = next_span
            combined_spans.append(current_span)
            return combined_spans

        # pad each interval by sub_amount on both sides, then merge near-duplicates
        morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                        for start, end in zip(starting_intervals, ending_intervals)]  # in seconds
        morphed_span = combine_spans(morphed_span, threshold=0.2)
        print("morphed_spans: ", morphed_span)
        mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
        mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
    else:
        info = torchaudio.info(audio_path)
        audio_dur = info.num_frames / info.sample_rate

        morphed_span = [(audio_dur, audio_dur)]  # in seconds
        mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
        mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
    print("mask_interval: ", mask_interval)

    decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}

    tts = True if mode == "TTS" else False
    new_audio = inference_one_sample(
        ssrspeech_model["model"],
        ssrspeech_model["config"],
        ssrspeech_model["phn2num"],
        ssrspeech_model["text_tokenizer"],
        ssrspeech_model["audio_tokenizer"],
        audio_path, orig_transcript, target_transcript, mask_interval,
        cfg_coef, aug_text, False, True, tts,
        device, decode_config
    )
    audio_tensors = []
    # save the generated segment for comparison
    new_audio = new_audio[0].cpu()
    torchaudio.save(audio_path, new_audio, int(codec_audio_sr))

    if tts:  # remove the prompt portion from the start of the generated audio
        [new_transcript, new_segments, _] = transcribe(audio_path)
        if language == 'zh':
            transcribe_state = align(traditional_to_simplified(new_segments), audio_path)
            transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
            tmp1 = transcribe_state['segments'][0]['words'][0]['word']
            tmp2 = target_transcript_copy
        elif language == 'en':
            transcribe_state = align(new_segments, audio_path)
            tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
            tmp2 = target_transcript_copy.lower()
        if tmp1 == tmp2:
            offset = transcribe_state['segments'][0]['words'][0]['start']
        else:
            offset = transcribe_state['segments'][0]['words'][1]['start']
        new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
    audio_tensors.append(new_audio)
    output_audio = get_output_audio(audio_tensors, codec_audio_sr)

    success_message = "<span style='color:green;'>Success: Inference completed successfully!</span>"
    return output_audio, success_message

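# Note on the TTS path above: the prompt audio is trimmed off by re-transcribing
# the generated file and cutting at the first word of the newly synthesized text
# (target_transcript_copy), so only the continuation is returned to the UI.
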
demo_original_transcript = "Gwynplaine had, besides, for his work and for his feats of strength, round his neck and over his shoulders, an esclavine of leather."

demo_text = {
    "TTS": {
        "regular": "Gwynplaine had, besides, for his work and for his feats of strength, I cannot believe that the same model can also do text to speech synthesis too!"
    },
    "Edit": {
        "regular": "Gwynplaine had, besides, for his work and for his feats of strength, take over the stage for half an hour, an esclavine of leather."
    },
}

def get_app():
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column(scale=2):
                load_models_btn = gr.Button(value="Load models")
            with gr.Column(scale=5):
                with gr.Accordion("Select models", open=False) as models_selector:
                    with gr.Row():
                        ssrspeech_model_choice = gr.Radio(label="ssrspeech model", value="English",
                                                          choices=["English", "Mandarin"])

        with gr.Row():
            with gr.Column(scale=2):
                input_audio = gr.Audio(value=f"{DEMO_PATH}/5895_34622_000026_000002.wav", label="Input Audio", type="filepath", interactive=True)
                with gr.Group():
                    original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript,
                                                     info="Use the whisperx model to get the transcript.")
                    transcribe_btn = gr.Button(value="Transcribe")

            with gr.Column(scale=3):
                with gr.Group():
                    transcript = gr.Textbox(label="Text", lines=7, value=demo_text["Edit"]["regular"])
                with gr.Row():
                    mode = gr.Radio(label="Mode", choices=["Edit", "TTS"], value="Edit")
                run_btn = gr.Button(value="Run")

            with gr.Column(scale=2):
                output_audio = gr.Audio(label="Output Audio")

        with gr.Row():
            with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
                stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3, 4], value=2,
                                           info="if there are long silences in the generated audio, reduce stop_repetition to 2 or 1. -1 = disabled")
                seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
                                   info="set to 0 to use less VRAM, at the cost of slower inference")
                aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                    info="set to 1 to use cfg")
                cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                     info="cfg guidance scale, 1.5 is a good value; change it if you don't like the results")
                prompt_length = gr.Number(label="prompt_length", value=3,
                                          info="used for the tts prompt; the prompt audio is automatically cut to this length")
                sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment; change it if you don't like the results")
                top_p = gr.Number(label="top_p", value=0.8, info="0.9 is a good value, 0.8 is also good")
                temperature = gr.Number(label="temperature", value=1, info="haven't tried other values, do not change")
                top_k = gr.Number(label="top_k", value=0, info="0 means top-k sampling is disabled, because top-p sampling is used")
                codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000, info='encodec specific, do not change')
                codec_sr = gr.Number(label="codec_sr", value=50, info='encodec specific, do not change')
                silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]", info="encodec specific, do not change")

        success_output = gr.HTML()

        load_models_btn.click(fn=load_models,
                              inputs=[ssrspeech_model_choice],
                              outputs=[models_selector, success_output])

        segments = gr.State()  # not used
        transcribe_btn.click(fn=transcribe,
                             inputs=[input_audio],
                             outputs=[original_transcript, segments, success_output])

        run_btn.click(fn=run,
                      inputs=[
                          seed, sub_amount, ssrspeech_model_choice,
                          codec_audio_sr, codec_sr,
                          top_k, top_p, temperature, stop_repetition, kvcache, silence_tokens,
                          aug_text, cfg_coef, prompt_length,
                          input_audio, original_transcript, transcript,
                          mode
                      ],
                      outputs=[output_audio, success_output])

    return app

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="SSR-Speech gradio app.")
    parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
    parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
    parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory")
    parser.add_argument("--port", default=7860, type=int, help="App port")
    parser.add_argument("--share", action="store_true", help="Launch with public url")

    os.environ["USER"] = os.getenv("USER", "user")
    args = parser.parse_args()
    DEMO_PATH = args.demo_path
    TMP_PATH = args.tmp_path
    MODELS_PATH = args.models_path

    app = get_app()
    app.queue().launch(share=args.share, server_port=args.port)
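
# Hypothetical launch examples (the script name is an assumption; adjust to your setup):
#   python app.py --models-path ./pretrained_models --port 7860
#   python app.py --share   # also exposes a public Gradio link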