"""Gradio demo for FireRedTTS: fetch checkpoints if missing and serve a TTS web UI."""
import gradio as gr
import numpy as np
import os
import requests
from fireredtts.fireredtts import FireRedTTS


def download_file(url, filename):
    """Download ``url`` and store the response body at ``filename``.

    The body is streamed to disk in chunks instead of being buffered in
    memory, the destination directory is created when it does not exist, and
    the data is first written to a ``.part`` file that is renamed only after
    the transfer finished, so an interrupted download never leaves a
    truncated file at ``filename``.

    Args:
        url: HTTP(S) address of the file to fetch.
        filename: Destination path on the local filesystem.

    Returns:
        None. The outcome is reported on stdout; a non-200 response is
        printed, not raised. Network errors raised by ``requests`` propagate.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = f"{filename}.part"
    # (connect, read) timeouts so a stalled connection cannot hang forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            print(f"Failed to download file: HTTP {response.status_code}")
            return
        with open(partial_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)

    # Publish the file under its final name only once it is complete.
    os.replace(partial_path, filename)
    print(f"File downloaded successfully: {filename}")


# Fetch every FireRedTTS checkpoint that is not on disk yet. Each file is
# checked individually so that a missing speaker or token2wav checkpoint is
# re-fetched even when the GPT checkpoint is already present.
_missing_checkpoints = [
    name
    for name in (
        'fireredtts_gpt.pt',
        'fireredtts_speaker.bin',
        'fireredtts_token2wav.pt',
    )
    if not os.path.exists(f'pretrained_models/{name}')
]
if _missing_checkpoints:
    print("Start to download checkpoints...")
    # The target directory must exist before the files can be written into it.
    os.makedirs('pretrained_models', exist_ok=True)
    for _checkpoint in _missing_checkpoints:
        download_file(
            f'https://huggingface.co/fireredteam/FireRedTTS/resolve/main/{_checkpoint}',
            f'pretrained_models/{_checkpoint}',
        )


# Sample rate returned to Gradio together with every generated waveform.
sampling_rate = 24_000

# Build the synthesizer once at import time; tts_inference reuses this instance.
tts = FireRedTTS(config_path="configs/config_24k.json", pretrained_path="pretrained_models")

def tts_inference(text, prompt_wav='examples/prompt_1.wav', lang='zh'):
    """Synthesize ``text`` with the module-level ``tts`` model.

    Args:
        text: Text to be spoken.
        prompt_wav: Path of the reference audio handed to ``tts.synthesize``.
            ``None`` (what Gradio passes when no audio was uploaded) falls
            back to the bundled example prompt.
        lang: Language code handed to ``tts.synthesize``. ``None`` (possible
            when nothing is selected in the dropdown) falls back to ``'zh'``.

    Returns:
        A ``(sampling_rate, samples)`` tuple where ``samples`` is an
        ``np.int16`` array, the form accepted by a ``gr.Audio`` output.
    """
    # Explicit None arguments bypass the keyword defaults, so restore them here.
    if prompt_wav is None:
        prompt_wav = 'examples/prompt_1.wav'
    if lang is None:
        lang = 'zh'

    syn_audio = tts.synthesize(
        prompt_wav=prompt_wav,
        text=text,
        lang=lang,
    )[0].detach().cpu().numpy()
    print(f'Generate waveform with the shape of {syn_audio.shape}')

    # Assumes float samples nominally in [-1, 1] -- TODO confirm against the model.
    # Clip before casting: a sample at or above +1.0 scales to >= 32768, which
    # int16 cannot represent and which would otherwise come out corrupted.
    syn_audio = np.clip(syn_audio * 32768, -32768, 32767).astype(np.int16)
    return sampling_rate, syn_audio


# Web UI: text box, reference-audio upload and language selector feed
# tts_inference; the synthesized waveform is played back in an audio widget.
# Only a title is configured, no description.
iface = gr.Interface(
    title="TTS Demo",
    fn=tts_inference,
    inputs=[
        gr.Textbox(label="Input text here"),
        gr.Audio(label="Upload reference audio", type="filepath"),
        gr.Dropdown(choices=["en", "zh"], label="Select language"),
    ],
    outputs=gr.Audio(label="Generated audio"),
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()