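"""Inference wrapper for FireRedTTS: loads the text tokenizer, speaker-embedding
extractor, GPT acoustic model, and token-to-waveform (token2wav) module, and
exposes a single `synthesize` entry point that generates speech in the voice of
a prompt recording."""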
import os
import json
import time

import torch

from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio


class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        with open(config_path) as f:
            self.config = json.load(f)
        self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
        self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
        self.speaker_extractor_path = os.path.join(
            pretrained_path, "fireredtts_speaker.bin"
        )
        assert os.path.exists(self.token2wav_path)
        assert os.path.exists(self.gpt_path)
        assert os.path.exists(self.speaker_extractor_path)

        # text tokenizer
        self.text_tokenizer = VoiceBpeTokenizer()

        # speaker embedding extractor
        self.speaker_extractor = SpeakerEmbedddingExtractor(
            ckpt_path=self.speaker_extractor_path, device=device
        )

        # load gpt model
        self.gpt = GPT(
            start_text_token=self.config["gpt"]["gpt_start_text_token"],
            stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
            layers=self.config["gpt"]["gpt_layers"],
            model_dim=self.config["gpt"]["gpt_n_model_channels"],
            heads=self.config["gpt"]["gpt_n_heads"],
            max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
            max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
            max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
            code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
            number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
            num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
            start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
            stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
        )

        sd = torch.load(self.gpt_path, map_location=device)["model"]
        self.gpt.load_state_dict(sd, strict=True)
        self.gpt = self.gpt.to(device=device)
        self.gpt.eval()
        self.gpt.init_gpt_for_inference(kv_cache=True)

        # mel-spectrogram extractor
        self.mel_extractor = MelSpectrogramExtractor()

        # load token2wav model
        self.token2wav = Token2Wav.init_from_config(self.config)
        sd = torch.load(self.token2wav_path, map_location="cpu")
        self.token2wav.load_state_dict(sd, strict=True)
        self.token2wav.generator.remove_weight_norm()
        self.token2wav.eval()
        self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract speaker embeddings from a prompt audio file.

        Args:
            prompt_wav (str): path to the prompt audio file

        Returns:
            torch.Tensor: speaker embeddings, shape [1, 512]
        """
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)

        # speaker embeddings [1, 512]
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device=self.device)
        ).unsqueeze(0)

        return spk_embeddings

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Generate audio token sequences with the GPT model.

        Args:
            spk_gpt (torch.Tensor): speaker embedding projected into the GPT latent space
            text_tokens (torch.Tensor): encoded text tokens, shape [1, num_tokens]

        Returns:
            torch.Tensor: selected audio token sequence, shape [1, len]
        """
        with torch.no_grad():
            gpt_codes = self.gpt.generate(
                cond_latents=spk_gpt,
                text_inputs=text_tokens,
                input_tokens=None,
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=9,
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=2.0,
                output_attentions=False,
            )

        # Truncate each candidate at its first stop-audio (EOS) token, if any;
        # a candidate may reach the length limit without emitting EOS.
        seqs = []
        EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            eos_positions = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                seq = seq[: eos_positions[0]]
            seqs.append(seq)

        # Sort candidates by length and keep the third-shortest of the sampled
        # sequences (a simple length-based selection heuristic).
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]

        return gpt_codes

    def synthesize(self, prompt_wav, text, lang="auto"):
        """Synthesize speech in the voice of the prompt speaker.

        Args:
            prompt_wav (str): path to the prompt (reference) audio file
            text (str): text to synthesize
            lang (str): language of the text ("zh", "en", or "auto")

        Returns:
            torch.Tensor: generated waveform, shape [1, num_samples]
        """
        # Currently only Chinese and English are supported.
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)

        # text to tokens
        text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
        assert text_tokens.shape[-1] < 400  # guard against overly long inputs

        # extract speaker embedding
        spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
        with torch.no_grad():
            spk_gpt = self.gpt.reference_embedding(spk_embeddings)

        # gpt inference
        gpt_start_time = time.time()
        gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
        gpt_end_time = time.time()
        gpt_dur = gpt_end_time - gpt_start_time

        # prompt mel-spectrogram compute
        prompt_mel = (
            self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
        )
        # convert token to waveform (b=1, t)
        voc_start_time = time.time()
        rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
        voc_end_time = time.time()
        voc_dur = voc_end_time - voc_start_time
        all_dur = voc_end_time - gpt_start_time

        # Optional RTF (real-time factor) profiling at the 24 kHz output rate:
        # audio_dur = rec_wavs.shape[-1] / 24000
        # rtf_gpt = gpt_dur / audio_dur
        # rtf_voc = voc_dur / audio_dur
        # rtf_all = all_dur / audio_dur

        return rec_wavs
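

# Example usage (a minimal sketch; the config, checkpoint, and prompt paths
# below are placeholders, and saving with torchaudio assumes `rec_wavs` is a
# [1, num_samples] float tensor at 24 kHz, matching the RTF comment above):
if __name__ == "__main__":
    import torchaudio

    tts = FireRedTTS(
        config_path="configs/config_24k.json",  # placeholder path
        pretrained_path="pretrained_models",  # placeholder path
    )
    rec_wavs = tts.synthesize(
        prompt_wav="examples/prompt.wav",  # placeholder path
        text="Hello, this is a quick synthesis test.",
        lang="en",
    )
    torchaudio.save("output.wav", rec_wavs.detach().cpu(), 24000)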