File size: 2,142 Bytes
8aa105f
10b5125
8cc3802
6322243
7b98f12
 
8cc3802
19b52d8
 
6322243
 
 
 
 
 
 
 
 
 
a78f1e0
5bad11f
a78f1e0
 
 
 
 
 
8cc3802
 
ffbecfe
8cc3802
 
 
 
 
7b98f12
8cc3802
9c31cdf
8cc3802
 
 
 
450be06
9c31cdf
 
 
 
 
83ffe8e
9c31cdf
ffbecfe
8cc3802
19b52d8
 
8cc3802
 
10b5125
f425faa
10b5125
8cc3802
6db40e9
9c31cdf
1ed7dcd
8aa105f
 
 
19b52d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import numpy as np
import os
import requests
import spaces

from fireredtts.fireredtts import FireRedTTS


def download_file(url, filename):
    """Download ``url`` and save the response body to ``filename``.

    The body is streamed to disk in chunks, so large checkpoint files are
    never held fully in memory. Data is first written to ``<filename>.part``
    and only moved to ``filename`` once the transfer has finished, so an
    interrupted download cannot leave a truncated file at the final path.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path. Missing parent directories are created.

    Returns:
        None. On a non-200 response nothing is written and a failure message
        is printed instead (best-effort, same as before).

    Raises:
        requests.exceptions.RequestException: On connection errors or when
            the server stops responding within the timeout.
        OSError: If the destination cannot be created or written.
    """
    # (connect, read) timeouts in seconds. The read timeout applies to each
    # socket read, not to the whole transfer, so slow-but-alive downloads
    # still complete while a stalled connection no longer hangs forever.
    with requests.get(url, stream=True, timeout=(10, 300)) as response:
        if response.status_code != 200:
            print(f"Failed to download file: HTTP {response.status_code}")
            return

        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        partial_path = f"{filename}.part"
        with open(partial_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)

        # Publish the file under its final name only after a complete write.
        os.replace(partial_path, filename)

    print(f"File downloaded successfully: {filename}")


# Fetch every model checkpoint that is not on disk yet. Each file is checked
# on its own so that a partially populated directory (for example after an
# interrupted first start) is completed instead of being treated as ready.
_CHECKPOINT_DIR = 'pretrained_models'
_CHECKPOINT_BASE_URL = 'https://huggingface.co/fireredteam/FireRedTTS/resolve/main'
_CHECKPOINT_FILES = (
    'fireredtts_gpt.pt',
    'fireredtts_speaker.bin',
    'fireredtts_token2wav.pt',
)

_missing_checkpoints = [
    name for name in _CHECKPOINT_FILES
    if not os.path.exists(f'{_CHECKPOINT_DIR}/{name}')
]

if _missing_checkpoints:
    print("Start to download checkpoints...")
    # The destination directory may not exist on a fresh checkout.
    os.makedirs(_CHECKPOINT_DIR, exist_ok=True)
    for _name in _missing_checkpoints:
        download_file(f'{_CHECKPOINT_BASE_URL}/{_name}',
                      f'{_CHECKPOINT_DIR}/{_name}')


# Sample rate (Hz) that tts_inference returns to Gradio with each waveform.
# NOTE(review): presumably must match the 24k config loaded below - confirm.
sampling_rate = 24_000

# Build the synthesizer once at import time; tts_inference reuses this
# module-level instance for every request.
tts = FireRedTTS(config_path="configs/config_24k.json", pretrained_path='pretrained_models')

@spaces.GPU
def tts_inference(text, prompt_wav='examples/prompt_1.wav', lang='zh'):
    """Synthesize speech for ``text`` and return it as 16-bit PCM.

    Args:
        text: Text to synthesize.
        prompt_wav: Path to the reference audio passed to the model. ``None``
            falls back to the default prompt.
        lang: Language code passed to the model. ``None`` falls back to
            ``'zh'``.

    Returns:
        A ``(sampling_rate, samples)`` tuple where ``samples`` is an
        ``np.int16`` array, the form returned to the ``gr.Audio`` output.
    """
    # The Gradio interface passes every input explicitly, so the signature
    # defaults are never applied on their own; an input left empty in the UI
    # arrives as None. Restore the declared defaults in that case.
    if prompt_wav is None:
        prompt_wav = 'examples/prompt_1.wav'
    if lang is None:
        lang = 'zh'

    # Model inference
    syn_audio = tts.synthesize(
        prompt_wav=prompt_wav,
        text=text,
        lang=lang,
    )[0].detach().cpu().numpy()

    # Normalize volume to a peak of 0.9. Skip this for empty or all-silent
    # output: np.max raises on an empty array, and dividing by a zero peak
    # would fill the signal with NaNs before the integer conversion.
    peak = np.max(np.abs(syn_audio)) if syn_audio.size else 0.0
    if peak > 0:
        syn_audio = syn_audio / peak * 0.9

    # Convert audio data type
    syn_audio = (syn_audio * 32768).astype(np.int16)

    return sampling_rate, syn_audio


# Gradio front end. The three input components are passed to tts_inference
# positionally, in order: text, reference audio file path, language code.
iface = gr.Interface(
    title="FireRedTTS: A Foundation Text-To-Speech Framework for Industry-Level Generative Speech Applications",
    fn=tts_inference,
    inputs=[
        gr.Textbox(label="Input text here"),
        gr.Audio(type="filepath", label="Upload reference audio"),
        gr.Dropdown(["en", "zh"], label="Select language"),
    ],
    outputs=gr.Audio(label="Generated audio"),
)

# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()