File size: 4,744 Bytes
06d33a3
c6fbbbc
5f7d4f2
666d1ff
5f7d4f2
df2963e
5f7d4f2
 
666d1ff
5f7d4f2
c6fbbbc
af468c1
666d1ff
 
 
 
 
 
 
 
 
 
2ad7bbc
3e36194
666d1ff
3e36194
 
 
 
 
 
666d1ff
3e36194
 
 
 
 
 
 
 
666d1ff
 
2ad7bbc
5f7d4f2
2ad7bbc
3e36194
666d1ff
 
 
2ad7bbc
0350865
3e36194
5f7d4f2
3e36194
666d1ff
 
 
 
 
 
 
 
 
 
2ad7bbc
666d1ff
 
 
 
2c2732a
666d1ff
2c2732a
666d1ff
 
 
 
3e36194
666d1ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6fbbbc
5f7d4f2
666d1ff
c6fbbbc
5f7d4f2
ed08c96
1722436
 
 
 
c43cb5d
1722436
c43cb5d
1722436
c43cb5d
1722436
 
96f8e84
c43cb5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import os
import sys
import tempfile
from scipy.io.wavfile import write
import numpy as np
from tqdm import tqdm
from underthesea import sent_tokenize

try:
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts
except ImportError:
    os.system("git clone https://github.com/hellcatmon/XTTSv2-Finetuning-for-New-Languages.git")
    if os.path.exists("XTTSv2-Finetuning-for-New-Languages/TTS"):
        os.system("mv XTTSv2-Finetuning-for-New-Languages/TTS ./")
    sys.path.append("./TTS")
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

# Шляхі да файлаў (цяпер як радкі)
repo_id = "archivartaunik/BE_XTTS_V2_60epoch3Dataset"
model_dir = "./model"  # Дырэкторыя для захавання мадэлі
os.makedirs(model_dir, exist_ok=True) # Ствараем дырэкторыю, калі яе няма
checkpoint_file = os.path.join(model_dir, "model.pth")
config_file = os.path.join(model_dir, "config.json")
vocab_file = os.path.join(model_dir, "vocab.json")
default_voice_file = os.path.join(model_dir, "voice.wav")

if not os.path.exists(checkpoint_file):
    hf_hub_download(repo_id, filename="model.pth", local_dir=model_dir)
if not os.path.exists(config_file):
    hf_hub_download(repo_id, filename="config.json", local_dir=model_dir)
if not os.path.exists(vocab_file):
    hf_hub_download(repo_id, filename="vocab.json", local_dir=model_dir)
if not os.path.exists(default_voice_file):
    hf_hub_download(repo_id, filename="voice.wav", local_dir=model_dir)

# Загрузка канфігурацыі і мадэлі адзін раз
config = XttsConfig()
config.load_json(config_file)
XTTS_MODEL = Xtts.init_from_config(config)
XTTS_MODEL.load_checkpoint(config, checkpoint_path=checkpoint_file, vocab_path=vocab_file, use_deepspeed=False) # Тут выпраўленне
device = "cuda:0" if torch.cuda.is_available() else "cpu"
XTTS_MODEL.to(device)
sampling_rate = XTTS_MODEL.config.audio["sample_rate"]

@spaces.GPU(duration=60)
def text_to_speech(belarusian_story, speaker_audio_file=None):
    if not speaker_audio_file or (not isinstance(speaker_audio_file, str) and speaker_audio_file.name == ""):
        speaker_audio_file = default_voice_file

    try:
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
            max_ref_length=XTTS_MODEL.config.max_ref_len,
            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
        )
    except Exception as e:
        return f"Error getting conditioning latents: {e}"

    try:
        tts_texts = sent_tokenize(belarusian_story)
    except Exception as e:
        return f"Error tokenizing text: {e}"

    all_wavs = []
    for text in tqdm(tts_texts):
        try:
            with torch.no_grad():
                wav_chunk = XTTS_MODEL.inference(
                    text=text,
                    language="be",
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    temperature=0.1,
                    length_penalty=1.0,
                    repetition_penalty=10.0,
                    top_k=10,
                    top_p=0.3,
                )
            all_wavs.append(wav_chunk["wav"])
        except Exception as e:
            return f"Error generating audio: {e}"

    try:
        out_wav = np.concatenate(all_wavs)
    except ValueError:
        return "Немагчыма згенерыраваць аўдыё. Праверце ўваходны тэкст і аўдыёфайл."
    except Exception as e:
        return f"Error concatenating audio: {e}"

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    write(temp_file.name, sampling_rate, out_wav)

    return temp_file.name

demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        gr.Textbox(lines=5, label="Тэкст на беларускай мове"),
        gr.Audio(type="filepath", label="Запішыце або загрузіце файл голасу (без іншых гукаў) не карацей 7 секунд", interactive=True),
    ],
    outputs="audio",
    title="XTTS Belarusian TTS Demo",
    description="Увядзіце тэкст, і мадэль пераўтворыць яго ў аўдыя. Вы можаце выкарыстоўваць голас па змаўчанні, загрузіць уласны файл або запісаць аўдыё.",
)

if __name__ == "__main__":
    demo.launch()