Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
---
base_model:
- proxectonos/Nos_TTS-celtia-vits-graphemes
---
ONNX-converted model of https://huggingface.co/proxectonos/Nos_TTS-celtia-vits-graphemes

For inference:
```python
# minimal onnx inference extracted from coqui-tts
import json
import re
from typing import Callable, List

import numpy as np
import onnxruntime as ort
# Import the submodule explicitly: `import scipy` alone does not guarantee
# `scipy.io.wavfile` is loaded (older SciPy raises AttributeError; newer
# versions only lazy-load it). This still binds the top-level `scipy` name.
import scipy.io.wavfile

# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
class Graphemes:
    """Grapheme vocabulary: a bidirectional character <-> integer-ID mapping.

    The vocabulary is assembled as ``[pad] + punctuations + characters + [blank]``
    and is rebuilt automatically whenever any of those groups is reassigned.
    """

    def __init__(
        self,
        characters: str = None,
        punctuations: str = None,
        pad: str = None,
        eos: str = None,
        bos: str = None,
        blank: str = "<BLNK>",
        is_unique: bool = False,
        is_sorted: bool = True,
    ) -> None:
        self._characters = characters
        self._punctuations = punctuations
        self._pad = pad
        self._eos = eos
        self._bos = bos
        self._blank = blank
        self.is_unique = is_unique
        self.is_sorted = is_sorted
        self._create_vocab()

    def _create_vocab(self):
        # Assigning through the `vocab` setter keeps both lookup tables in sync.
        self.vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]

    @property
    def vocab(self):
        """The full symbol list backing the ID mappings."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab):
        self._vocab = vocab
        self._char_to_id = {}
        self._id_to_char = {}
        for token_id, symbol in enumerate(self._vocab):
            self._char_to_id[symbol] = token_id
            self._id_to_char[token_id] = symbol

    @property
    def num_chars(self):
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    @property
    def pad_id(self) -> int:
        # Falls back to one-past-the-end when the symbol is undefined.
        if self.pad:
            return self.char_to_id(self.pad)
        return len(self.vocab)

    @property
    def blank_id(self) -> int:
        if self.blank:
            return self.char_to_id(self.blank)
        return len(self.vocab)

    @property
    def eos_id(self) -> int:
        if self.eos:
            return self.char_to_id(self.eos)
        return len(self.vocab)

    @property
    def bos_id(self) -> int:
        if self.bos:
            return self.char_to_id(self.bos)
        return len(self.vocab)

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, value):
        self._characters = value
        self._create_vocab()

    @property
    def punctuations(self):
        return self._punctuations

    @punctuations.setter
    def punctuations(self, value):
        self._punctuations = value
        self._create_vocab()

    @property
    def pad(self):
        return self._pad

    @pad.setter
    def pad(self, value):
        self._pad = value
        self._create_vocab()

    @property
    def eos(self):
        return self._eos

    @eos.setter
    def eos(self, value):
        self._eos = value
        self._create_vocab()

    @property
    def bos(self):
        return self._bos

    @bos.setter
    def bos(self, value):
        self._bos = value
        self._create_vocab()

    @property
    def blank(self):
        return self._blank

    @blank.setter
    def blank(self, value):
        self._blank = value
        self._create_vocab()

    def char_to_id(self, char: str) -> int:
        """Map a single character to its integer ID; raises KeyError if unknown."""
        try:
            return self._char_to_id[char]
        except KeyError as err:
            raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from err

    def id_to_char(self, idx: int) -> str:
        """Map an integer ID back to its character."""
        return self._id_to_char[idx]
class TTSTokenizer:
    """🐸TTS tokenizer to convert input characters to token IDs and back.

    Token IDs for OOV chars are discarded but those are stored in
    `self.not_found_characters` for later inspection.

    Args:
        text_cleaner (callable):
            A function to pre-process the text before tokenization and phonemization. Defaults to None.

        characters (Graphemes):
            A Graphemes object to use for character-to-ID and ID-to-character mappings.

        add_blank (bool):
            If True, intersperse the blank token between characters. Defaults to False.

        use_eos_bos (bool):
            If True, wrap each encoded sequence in BOS/EOS IDs. Defaults to False.
    """

    def __init__(
        self,
        text_cleaner: Callable = None,
        characters: Graphemes = None,
        add_blank: bool = False,
        use_eos_bos=False,
    ):
        self.text_cleaner = text_cleaner
        self.add_blank = add_blank
        self.use_eos_bos = use_eos_bos
        # Property setter below also caches pad_id / blank_id.
        self.characters = characters
        self.not_found_characters = []

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, new_characters):
        self._characters = new_characters
        self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
        self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None

    def encode(self, text: str) -> List[int]:
        """Encodes a string of text as a sequence of IDs."""
        token_ids = []
        for char in text:
            try:
                idx = self.characters.char_to_id(char)
                token_ids.append(idx)
            except KeyError:
                # Discard OOV characters, but remember (and warn about) each
                # distinct one only once. (Removed a stray debug `print(text)`
                # that dumped the whole input on every new OOV character.)
                if char not in self.not_found_characters:
                    self.not_found_characters.append(char)
                    print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
        return token_ids

    def text_to_ids(self, text: str) -> List[int]:
        """Converts a string of text to a sequence of token IDs.

        Args:
            text(str):
                The text to convert to token IDs.

        1. Text normalization
        2. Text to token IDs
        3. Add blank char between characters
        4. Add BOS and EOS characters
        """
        if self.text_cleaner is not None:
            text = self.text_cleaner(text)
        text = self.encode(text)
        if self.add_blank:
            text = self.intersperse_blank_char(text, True)
        if self.use_eos_bos:
            text = self.pad_with_bos_eos(text)
        return text

    def pad_with_bos_eos(self, char_sequence: List[str]):
        """Pads a sequence with the special BOS and EOS characters."""
        return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]

    def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
        """Intersperses the blank character between characters in a sequence.

        Use the ```blank``` character if defined else use the ```pad``` character.
        """
        # NOTE(review): the fallback branch inserts the pad *character*, not its
        # ID — this mirrors the upstream coqui-tts implementation; confirm if
        # ever called with use_blank_char=False.
        char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad
        result = [char_to_use] * (len(char_sequence) * 2 + 1)
        result[1::2] = char_sequence
        return result
class VitsOnnxInference:
    """Minimal ONNX Runtime wrapper for a VITS TTS model exported from coqui-tts.

    Loads the exported graph, rebuilds the grapheme tokenizer from the model's
    JSON config (falling back to the hard-coded Galician defaults below), and
    exposes text -> waveform -> .wav-file helpers.
    """

    def __init__(self, onnx_model_path: str, config_path: str, cuda: bool = False):
        # Config is optional: an empty dict makes every `.get(...)` below fall
        # back to its hard-coded default.
        self.config = {}
        if config_path:
            with open(config_path) as f:
                self.config = json.load(f)
        providers = [
            "CPUExecutionProvider"
            if cuda is False
            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
        ]
        sess_options = ort.SessionOptions()
        self.onnx_sess = ort.InferenceSession(
            onnx_model_path,
            sess_options=sess_options,
            providers=providers,
        )

        # Character-set defaults mirror the Nos_TTS Galician vocabulary; any
        # "characters" section in the config overrides them.
        _pad = self.config.get("characters", {}).get("pad", "_")
        _punctuations = self.config.get("characters", {}).get("punctuations", "!\"(),-.:;?\u00a1\u00bf ")
        _letters = self.config.get("characters", {}).get("characters",
                                                         "ABCDEFGHIJKLMNOPQRSTUVXYZabcdefghijklmnopqrstuvwxyz\u00c1\u00c9\u00cd\u00d3\u00da\u00e1\u00e9\u00ed\u00f1\u00f3\u00fa\u00fc")

        vocab = Graphemes(characters=_letters,
                          punctuations=_punctuations,
                          pad=_pad)

        self.tokenizer = TTSTokenizer(
            text_cleaner=self.normalize_text,
            characters=vocab,
            add_blank=self.config.get("add_blank", True),
            use_eos_bos=False,
        )

    @staticmethod
    def normalize_text(text: str) -> str:
        """Basic pipeline that lowercases and collapses whitespace without transliteration."""
        text = text.lower()
        # Map pause punctuation the vocabulary may not cover onto covered forms.
        text = text.replace(";", ",")
        text = text.replace("-", " ")
        text = text.replace(":", ",")
        # Strip bracketing characters entirely, then collapse runs of whitespace.
        text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
        text = re.sub(_whitespace_re, " ", text).strip()
        return text

    def inference_onnx(self, text: str):
        """ONNX inference: tokenize `text`, run the graph, return audio[0][0].

        Returns the first batch item of the "output" tensor; its exact shape
        depends on the exported graph — presumably (1, n_samples), since
        `synth` indexes it once more. TODO confirm against the export.
        """
        # Token IDs as a (1, seq_len) int64 batch.
        x = np.asarray(
            self.tokenizer.text_to_ids(text),
            dtype=np.int64,
        )[None, :]

        x_lengths = np.array([x.shape[1]], dtype=np.int64)

        # VITS sampling knobs: [noise_scale, length_scale, noise_scale_dp].
        scales = np.array(
            [self.config.get("inference_noise_scale", 0.667),
             self.config.get("length_scale", 1.0),
             self.config.get("inference_noise_scale_dp", 1.0), ],
            dtype=np.float32,
        )
        input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}

        audio = self.onnx_sess.run(
            ["output"],
            input_params,
        )
        return audio[0][0]

    @staticmethod
    def save_wav(wav: np.ndarray, path: str, sample_rate: int = 16000) -> None:
        """Save float waveform to a file using Scipy.

        Args:
            wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
            path (str): Path to a output file.
            sample_rate (int): Output sampling rate. Defaults to 16000.
        """
        # Peak-normalize to int16 full scale; the 0.01 floor avoids dividing by
        # ~zero on silent output.
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        wav_norm = wav_norm.astype(np.int16)
        scipy.io.wavfile.write(path, sample_rate, wav_norm)

    def synth(self, text: str, path: str):
        """End-to-end synthesis: text -> waveform -> 16-bit PCM .wav at `path`."""
        wavs = self.inference_onnx(text)
        # NOTE(review): inference_onnx already took batch item [0][0]; the
        # extra wavs[0] here assumes one more leading dim in the ONNX output —
        # confirm against the exported model.
        self.save_wav(wavs[0], path, self.config.get("sample_rate", 16000))
```
|