File size: 2,871 Bytes
3fb186a
 
 
 
334cc2e
 
3fb186a
334cc2e
 
3fb186a
 
 
 
 
cb1ad55
 
3fb186a
cb1ad55
 
 
 
 
 
 
 
 
334cc2e
 
cb1ad55
334cc2e
cb1ad55
334cc2e
cb1ad55
 
 
 
334cc2e
cb1ad55
 
3fb186a
334cc2e
 
 
 
 
 
 
cb1ad55
 
3fb186a
334cc2e
 
 
3fb186a
 
 
334cc2e
 
 
 
 
 
3fb186a
334cc2e
 
 
cb1ad55
 
 
 
 
 
334cc2e
cb1ad55
 
 
 
 
 
 
 
 
 
 
 
 
334cc2e
3fb186a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import tempfile
from scipy.io.wavfile import write
import gradio as gr
from transformers import VitsTokenizer, VitsModel, set_seed, pipeline
import torch
from datetime import datetime


# Punctuation treated as a sentence boundary; when extra pauses are requested,
# tts_function replaces each of these characters with three spaces.
new_sentence = "!.?"
# Punctuation inside a sentence; when extra pauses are requested, tts_function
# replaces each of these with two spaces. '.' also appears in new_sentence,
# which tts_function processes first, so '.' always gets the three-space pause.
in_sentence = ",-.:;"

# Hugging Face Hub checkpoint of the Lezgian VITS text-to-speech model.
model_name = "leks-forever/vits_lez_tts"
model = VitsModel.from_pretrained(model_name)
tokenizer = VitsTokenizer.from_pretrained(model_name)

# NOTE(review): tts_pipeline is not referenced anywhere else in this file, and
# building it from the hub id loads a second copy of the model weights at
# startup. Consider removing it, or building it from the already-loaded
# objects (model=model, tokenizer=tokenizer) after confirming that works.
tts_pipeline = pipeline("text-to-speech", model=model_name)


def canonize_lez(text):
    """Normalise palochka look-alikes that follow an abruptive consonant letter.

    Users often type a digit one, a Latin ``l``/``i``/``I``, a pipe, or another
    palochka variant in place of the palochka ``Ӏ``. Any of those characters
    that directly follows one of к, п, т, ц, ч (either case) is replaced with
    ``Ӏ``. All other characters are returned unchanged.

    Args:
        text: Input string.

    Returns:
        The normalised string.
    """
    abruptive_letters = frozenset('кКпПтТцЦчЧ')
    lookalike_marks = frozenset(('1', 'l', 'i', 'I', '|', 'ӏ', 'Ӏ', 'ӏ'))

    pieces = []
    previous = None
    for current in text:
        # The check uses the ORIGINAL previous character. A replaced mark is
        # never itself an abruptive letter, so no new pairs can form.
        if previous in abruptive_letters and current in lookalike_marks:
            pieces.append('Ӏ')
        else:
            pieces.append(current)
        previous = current
    return ''.join(pieces)


def _output_wav_path(input_text):
    """Create an empty, uniquely named ``.wav`` file in the temp dir; return its path.

    The file name starts with a filesystem-safe version of the first 20
    characters of *input_text* and an ``HHMMSS`` timestamp, so downloaded files
    stay recognisable. A random component added by ``tempfile`` keeps names
    unique.
    """
    # Map everything except letters, digits, '-' and '_' (spaces, '/', '\\',
    # '?', ':', newlines, ...) to '_' so user text can never leave the temp
    # directory or yield an invalid file name.
    safe_part = ''.join(
        ch if ch.isalnum() or ch in '-_' else '_' for ch in input_text[:20]
    )
    timestamp = datetime.now().strftime("%H%M%S")
    # delete=False: the file must outlive this call so Gradio can serve it.
    # The handle is closed on return, so the path can be reopened for writing
    # (required on Windows).
    with tempfile.NamedTemporaryFile(
        delete=False, prefix=f"{safe_part}_{timestamp}_", suffix=".wav"
    ) as tmpfile:
        return tmpfile.name


def tts_function(input_text, speaking_rate, noise_scale, add_pauses):
    """Synthesise speech for *input_text* with the VITS model and save it as WAV.

    Args:
        input_text: Text typed by the user.
        speaking_rate: Assigned to ``model.speaking_rate`` before inference.
        noise_scale: Assigned to ``model.noise_scale`` before inference.
        add_pauses: If true, punctuation is replaced with runs of spaces to
            lengthen pauses: three spaces for ``new_sentence`` characters,
            two for ``in_sentence`` characters.

    Returns:
        Path of the written ``.wav`` file in the system temp directory, or
        ``None`` when *input_text* is empty or whitespace-only (nothing to
        synthesise; the Gradio audio output is simply cleared).
    """
    if not input_text or not input_text.strip():
        return None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    fixed_text = canonize_lez(input_text)

    if add_pauses:
        # '.' is in both sets. Sentence-ending marks are replaced first, so
        # '.' gets the longer (three-space) pause.
        for symb in new_sentence:
            fixed_text = fixed_text.replace(symb, '   ')
        for symb in in_sentence:
            fixed_text = fixed_text.replace(symb, '  ')

    inputs = tokenizer(text=fixed_text, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    model.to(device)
    model.eval()

    # Re-seed before every request so the sampling noise is reproducible.
    set_seed(900)

    model.speaking_rate = speaking_rate
    model.noise_scale = noise_scale

    with torch.no_grad():
        waveform = model(**inputs).waveform[0]

    waveform = waveform.detach().cpu().float().numpy()

    output_path = _output_wav_path(input_text)
    write(output_path, rate=model.config.sampling_rate, data=waveform)
    return output_path


# Gradio UI: text input and synthesis controls in the left column, generated
# audio in the right column. Clicking the button calls tts_function.
with gr.Blocks() as interface:
    gr.Markdown("### Text-to-speech Лезги ЧIалал")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Введите текст на лезгинском", lines=4)
            add_pauses = gr.Checkbox(label="Добавить больше пауз у знаков препинания", value=False)
            # minimum is 0.1, not 0: transformers' VitsModel computes its
            # length scale as 1 / speaking_rate, so a rate of 0 raises
            # ZeroDivisionError during synthesis.
            speaking_rate = gr.Slider(label="Скорость речи (speaking_rate)", minimum=0.1, maximum=2, step=0.1, value=0.9)
            noise_scale = gr.Slider(label="Шум (noise_scale)", minimum=0, maximum=5, step=0.1, value=0)
            submit_button = gr.Button("Сгенерировать")

        with gr.Column():
            output_audio = gr.Audio(label="Аудио")

    # Argument order must match tts_function's parameter order.
    submit_button.click(
        fn=tts_function,
        inputs=[input_text, speaking_rate, noise_scale, add_pauses],
        outputs=output_audio,
    )

interface.launch()