# Source: HuggingFace Space by tjysdsg ("Add more description", commit 17c8086).
import os
import gradio as gr
import torchaudio
from typing import Tuple, Optional
import soundfile as sf
from s2st_inference import s2st_inference
from utils import download_model
# Markdown shown at the top of the Gradio demo page.
DESCRIPTION = r"""
**Speech-to-Speech Translation from Spanish to English**
- Paper: Direct Speech-to-Speech Translation With Discrete Units
- Dataset: CVSS-C
- Toolkit: [ESPnet](https://github.com/espnet/espnet)
- Pretrained Speech-to-Unit translation model: https://huggingface.co/espnet/jiyang_tang_cvss-c_es-en_discrete_unit
- Pretrained WaveGAN vocoder: https://huggingface.co/espnet/cvss-c_en_wavegan_hubert_vocoder
Part of a CMU MIIS capstone project with [@realzza](https://github.com/realzza)
and [@sophia1488](https://github.com/sophia1488)
"""
# Sample rate (Hz) the input is resampled to and the output wav is written at.
SAMPLE_RATE = 16000
# Inputs longer than this are truncated before inference.
MAX_INPUT_LENGTH = 60  # seconds
# HuggingFace tags and local download directories for the two pretrained models.
S2UT_TAG = 'espnet/jiyang_tang_cvss-c_es-en_discrete_unit'
S2UT_DIR = 'model'
VOCODER_TAG = 'espnet/cvss-c_en_wavegan_hubert_vocoder'
VOCODER_DIR = 'vocoder'
# Inference settings: CPU only (ngpu=0), greedy decoding (beam size 1).
NGPU = 0
BEAM_SIZE = 1
class App:
    """Gradio app backend.

    Downloads the speech-to-unit (S2UT) translation model and the WaveGAN
    vocoder once at startup, then serves Spanish speech -> English speech
    translation via :meth:`s2st`.
    """

    def __init__(self):
        # Download pretrained models into local directories (idempotent).
        os.makedirs(S2UT_DIR, exist_ok=True)
        os.makedirs(VOCODER_DIR, exist_ok=True)
        self.s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
        self.vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)

    def s2st(
            self,
            input_audio: Optional[str],
    ) -> str:
        """Translate a Spanish speech recording into English speech.

        Args:
            input_audio: Path of the uploaded/recorded audio file, or None
                when the user submitted without providing audio.

        Returns:
            Path of the generated English wav file ('output.wav').

        Raises:
            gr.Error: If no input audio was provided.
        """
        if input_audio is None:
            # Gradio passes None when the audio component is empty.
            raise gr.Error("Please provide an input audio file.")

        # Load and resample to the model's expected sample rate.
        orig_wav, orig_sr = torchaudio.load(input_audio)
        wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)

        # Truncate overly long inputs and warn the user.
        max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
        if wav.shape[1] > max_length:
            wav = wav[:, :max_length]
            gr.Warning(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
        wav = wav[0]  # keep only the first channel (mono)

        exp_dir = os.path.join(
            self.s2ut_path, 'exp', 's2st_train_s2st_discrete_unit_raw_fbank_es_en'
        )

        # Temporarily chdir into the model dir (presumably the config uses
        # relative paths); restore the cwd even if inference raises.
        cwd = os.getcwd()
        os.chdir(self.s2ut_path)
        try:
            out_wav = s2st_inference(
                wav,
                train_config=os.path.join(exp_dir, 'config.yaml'),
                model_file=os.path.join(exp_dir, '500epoch.pth'),
                vocoder_file=os.path.join(self.vocoder_path, 'checkpoint-450000steps.pkl'),
                vocoder_config=os.path.join(self.vocoder_path, 'config.yml'),
                ngpu=NGPU,
                beam_size=BEAM_SIZE,
            )
        finally:
            os.chdir(cwd)

        # Save the result as 16-bit PCM at the model sample rate
        # (was hard-coded 16000; now tied to SAMPLE_RATE).
        output_path = 'output.wav'
        sf.write(output_path, out_wav, SAMPLE_RATE, "PCM_16")
        return output_path
def main():
    """Build and launch the Gradio demo UI."""
    app = App()

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Group():
            input_audio = gr.Audio(
                label="Input speech",
                type="filepath",
                sources=["upload", "microphone"],
                format='wav',
                streaming=False,
                visible=True,
            )
            btn = gr.Button("Translate")
            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="numpy",
            )

        # Hidden placeholders so the example table can display transcripts.
        source_text = gr.Text(label='Source Text', visible=False)
        target_text = gr.Text(label='Target Text', visible=False)

        # Example inputs with reference transcripts/translations.
        # (Fixed mojibake: transcripts were UTF-8 text mis-decoded as GBK.)
        with gr.Row():
            gr.Examples(
                examples=[
                    ["examples/example1.wav",
                     "fue enterrada en el cementerio forest lawn memorial park de hollywood hills",
                     "she was buried at the forest lawn memorial park of hollywood hills"],
                    ["examples/example2.wav",
                     "diversos otros músicos han interpretado esta canción en conciertos en vivo",
                     "many other musicians have played this song in live concerts"],
                    ["examples/example3.wav",
                     "es gómez-moreno el primero en situar su origen en guadalajara, hoy ampliamente aceptado",
                     "gomez moreno was the first person to place its origin in guadalajara which is now broadly accepted"],
                ],
                inputs=[input_audio, source_text, target_text],
                outputs=[output_audio],
            )

        btn.click(
            fn=app.s2st,
            inputs=[input_audio],
            outputs=[output_audio],
            api_name="run",
        )

    # Queue requests so concurrent users don't overload the CPU-only model.
    demo.queue(max_size=50).launch()


if __name__ == '__main__':
    main()