import gradio as gr
import torch

from transformers import VitsModel, VitsTokenizer, set_seed


title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
        <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
            VITS TTS Demo
        </h1>
    </div>
</div>
"""

description = """
VITS is an end-to-end speech synthesis model that predicts a speech waveform conditional on an input text sequence. It is a conditional variational autoencoder (VAE) comprising a posterior encoder, a decoder, and a conditional prior.

This demo showcases the official VITS checkpoints, trained on the [LJSpeech](https://huggingface.co/kakao-enterprise/vits-ljs) and [VCTK](https://huggingface.co/kakao-enterprise/vits-vctk) datasets.
"""

article = "Model by Jaehyeon Kim et al. from Kakao Enterprise. Code and demo by 🤗 Hugging Face."

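# Load the official single-speaker (LJ Speech) and multi-speaker (VCTK) VITS
# checkpoints together with their matching tokenizers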
ljs_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
ljs_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-ljs")

vctk_model = VitsModel.from_pretrained("kakao-enterprise/vits-vctk")
vctk_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-vctk")

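# Run inference on a GPU when one is available, otherwise fall back to CPU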
device = "cuda" if torch.cuda.is_available() else "cpu"
ljs_model.to(device)
vctk_model.to(device)

def ljs_forward(text, speaking_rate=1.0):
    # Tokenize the prompt and move the tensors onto the model's device
    inputs = ljs_tokenizer(text, return_tensors="pt").to(device)

    # speaking_rate scales the predicted phoneme durations; a fixed seed keeps the
    # stochastic duration predictor and VAE sampling reproducible across runs
    ljs_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        outputs = ljs_model(**inputs)[0]

    # The model outputs a waveform sampled at 22,050 Hz
    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))


def vctk_forward(text, speaking_rate=1.0, speaker_id=1):
    # Tokenize the prompt and move the tensors onto the model's device
    inputs = vctk_tokenizer(text, return_tensors="pt").to(device)

    vctk_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        # The UI slider is 1-based while the model expects 0-based speaker ids
        outputs = vctk_model(**inputs, speaker_id=speaker_id - 1)[0]

    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))


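# Tab for the single-speaker LJ Speech checkpoint: a text prompt and a speaking-rate slider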
ljs_inference = gr.Interface(
    fn=ljs_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
    ],
    outputs=gr.Audio(),
)

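# Tab for the multi-speaker VCTK checkpoint: adds a speaker-id slider on top of the
# text prompt and speaking-rate controls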
vctk_inference = gr.Interface(
    fn=vctk_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
        gr.Slider(
            1,
            vctk_model.config.num_speakers,
            value=1,
            step=1,
            label="Speaker id",
            info=f"The VCTK model is trained on {vctk_model.config.num_speakers} speakers. You can prompt the model using one of these speaker ids.",
        ),
    ],
    outputs=gr.Audio(),
)

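# Assemble the page: title, description, one tab per checkpoint, and the attribution line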
demo = gr.Blocks()

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([ljs_inference, vctk_inference], ["LJ Speech", "VCTK"])
    gr.Markdown(article)

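# Queue incoming requests (at most 10 waiting) before launching the app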
demo.queue(max_size=10)
demo.launch()