Gregniuki commited on
Commit
58187c4
·
1 Parent(s): 38b2021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -0
app.py CHANGED
@@ -7,3 +7,348 @@ app = FastAPI()
7
  @app.get("/")
8
  def read_root():
9
  return {"message": "Hello, World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @app.get("/")
8
  def read_root():
9
  return {"message": "Hello, World!"}
10
+ def detect_onnx_models(path):
11
+ onnx_models = glob.glob(path + '/*.onnx')
12
+ if len(onnx_models) > 1:
13
+ return onnx_models
14
+ elif len(onnx_models) == 1:
15
+ return onnx_models[0]
16
+ else:
17
+ return None
18
+
19
+
20
+ def main():
21
+ """Main entry point"""
22
+ models_path = "/content/piper/src/python"
23
+ logging.basicConfig(level=logging.DEBUG)
24
+ providers = [
25
+ "CPUExecutionProvider"
26
+ if use_gpu is False
27
+ else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
28
+ ]
29
+ sess_options = onnxruntime.SessionOptions()
30
+ model = None
31
+ onnx_models = detect_onnx_models(models_path)
32
+ speaker_selection = widgets.Dropdown(
33
+ options=[],
34
+ description=f'{lan.translate(lang, "Select speaker")}:',
35
+ layout={'visibility': 'hidden'}
36
+ )
37
+ if onnx_models is None:
38
+ if enhanced_accessibility:
39
+ playaudio("novoices")
40
+ raise Exception(lan.translate(lang, "No downloaded voice packages!"))
41
+ elif isinstance(onnx_models, str):
42
+ onnx_model = onnx_models
43
+ model, config = load_onnx(onnx_model, sess_options, providers)
44
+ if config["num_speakers"] > 1:
45
+ speaker_selection.options = config["speaker_id_map"].values()
46
+ speaker_selection.layout.visibility = 'visible'
47
+ preview_sid = 0
48
+ if enhanced_accessibility:
49
+ playaudio("multispeaker")
50
+ else:
51
+ speaker_selection.layout.visibility = 'hidden'
52
+ preview_sid = None
53
+
54
+ if enhanced_accessibility:
55
+ inferencing(
56
+ model,
57
+ config,
58
+ preview_sid,
59
+ lan.translate(
60
+ config["espeak"]["voice"][:2],
61
+ "Interface openned. Write your texts, configure the different synthesis options or download all the voices you want. Enjoy!"
62
+ )
63
+ )
64
+ else:
65
+ voice_model_names = []
66
+ for current in onnx_models:
67
+ voice_struct = current.split("/")[5]
68
+ voice_model_names.append(voice_struct)
69
+ if enhanced_accessibility:
70
+ playaudio("selectmodel")
71
+ selection = widgets.Dropdown(
72
+ options=voice_model_names,
73
+ description=f'{lan.translate(lang, "Select voice package")}:',
74
+ )
75
+ load_btn = widgets.Button(
76
+ description=lan.translate(lang, "Load it!")
77
+ )
78
+ config = None
79
+ def load_model(button):
80
+ nonlocal config
81
+ global onnx_model
82
+ nonlocal model
83
+ nonlocal models_path
84
+ selected_voice = selection.value
85
+ onnx_model = f"{models_path}/{selected_voice}"
86
+ model, config = load_onnx(onnx_model, sess_options, providers)
87
+ if enhanced_accessibility:
88
+ playaudio("loaded")
89
+ if config["num_speakers"] > 1:
90
+ speaker_selection.options = config["speaker_id_map"].values()
91
+ speaker_selection.layout.visibility = 'visible'
92
+ if enhanced_accessibility:
93
+ playaudio("multispeaker")
94
+ else:
95
+ speaker_selection.layout.visibility = 'hidden'
96
+
97
+ load_btn.on_click(load_model)
98
+ display(selection, load_btn)
99
+ display(speaker_selection)
100
+ speed_slider = widgets.FloatSlider(
101
+ value=1,
102
+ min=0.25,
103
+ max=4,
104
+ step=0.1,
105
+ description=lan.translate(lang, "Rate scale"),
106
+ orientation='horizontal',
107
+ )
108
+ noise_scale_slider = widgets.FloatSlider(
109
+ value=0.667,
110
+ min=0.25,
111
+ max=4,
112
+ step=0.1,
113
+ description=lan.translate(lang, "Phoneme noise scale"),
114
+ orientation='horizontal',
115
+ )
116
+ noise_scale_w_slider = widgets.FloatSlider(
117
+ value=1,
118
+ min=0.25,
119
+ max=4,
120
+ step=0.1,
121
+ description=lan.translate(lang, "Phoneme stressing scale"),
122
+ orientation='horizontal',
123
+ )
124
+ play = widgets.Checkbox(
125
+ value=True,
126
+ description=lan.translate(lang, "Auto-play"),
127
+ disabled=False
128
+ )
129
+ text_input = widgets.Text(
130
+ value='',
131
+ placeholder=f'{lan.translate(lang, "Enter your text here")}:',
132
+ description=lan.translate(lang, "Text to synthesize"),
133
+ layout=widgets.Layout(width='80%')
134
+ )
135
+ synthesize_button = widgets.Button(
136
+ description=lan.translate(lang, "Synthesize"),
137
+ button_style='success', # 'success', 'info', 'warning', 'danger' or ''
138
+ tooltip=lan.translate(lang, "Click here to synthesize the text."),
139
+ icon='check'
140
+ )
141
+ close_button = widgets.Button(
142
+ description=lan.translate(lang, "Exit"),
143
+ tooltip=lan.translate(lang, "Closes this GUI."),
144
+ icon='check'
145
+ )
146
+
147
+ def on_synthesize_button_clicked(b):
148
+ if model is None:
149
+ if enhanced_accessibility:
150
+ playaudio("nomodel")
151
+ raise Exception(lan.translate(lang, "You have not loaded any model from the list!"))
152
+ text = text_input.value
153
+ if config["num_speakers"] > 1:
154
+ sid = speaker_selection.value
155
+ else:
156
+ sid = None
157
+ rate = speed_slider.value
158
+ noise_scale = noise_scale_slider.value
159
+ noise_scale_w = noise_scale_w_slider.value
160
+ auto_play = play.value
161
+ inferencing(model, config, sid, text, rate, noise_scale, noise_scale_w, auto_play)
162
+
163
+ def on_close_button_clicked(b):
164
+ clear_output()
165
+ if enhanced_accessibility:
166
+ playaudio("exit")
167
+
168
+ synthesize_button.on_click(on_synthesize_button_clicked)
169
+ close_button.on_click(on_close_button_clicked)
170
+ display(text_input)
171
+ display(speed_slider)
172
+ display(noise_scale_slider)
173
+ display(noise_scale_w_slider)
174
+ display(play)
175
+ display(synthesize_button)
176
+ display(close_button)
177
+
178
+ def load_onnx(model, sess_options, providers = ["CPUExecutionProvider"]):
179
+ _LOGGER.debug("Loading model from %s", model)
180
+ config = load_config(model)
181
+ model = onnxruntime.InferenceSession(
182
+ str(model),
183
+ sess_options=sess_options,
184
+ providers= providers
185
+ )
186
+ _LOGGER.info("Loaded model from %s", model)
187
+ return model, config
188
+
189
+ def load_config(model):
190
+ with open(f"{model}.json", "r") as file:
191
+ config = json.load(file)
192
+ return config
193
+ PAD = "_" # padding (0)
194
+ BOS = "^" # beginning of sentence
195
+ EOS = "$" # end of sentence
196
+
197
+ class PhonemeType(str, Enum):
198
+ ESPEAK = "espeak"
199
+ TEXT = "text"
200
+
201
+ def phonemize(config, text: str) -> List[List[str]]:
202
+ """Text to phonemes grouped by sentence."""
203
+ if config["phoneme_type"] == PhonemeType.ESPEAK:
204
+ if config["espeak"]["voice"] == "ar":
205
+ # Arabic diacritization
206
+ # https://github.com/mush42/libtashkeel/
207
+ text = tashkeel_run(text)
208
+ return phonemize_espeak(text, config["espeak"]["voice"])
209
+ if config["phoneme_type"] == PhonemeType.TEXT:
210
+ return phonemize_codepoints(text)
211
+ raise ValueError(f'Unexpected phoneme type: {config["phoneme_type"]}')
212
+
213
+ def phonemes_to_ids(config, phonemes: List[str]) -> List[int]:
214
+ """Phonemes to ids."""
215
+ id_map = config["phoneme_id_map"]
216
+ ids: List[int] = list(id_map[BOS])
217
+ for phoneme in phonemes:
218
+ if phoneme not in id_map:
219
+ print("Missing phoneme from id map: %s", phoneme)
220
+ continue
221
+ ids.extend(id_map[phoneme])
222
+ ids.extend(id_map[PAD])
223
+ ids.extend(id_map[EOS])
224
+ return ids
225
+
226
+ def inferencing(model, config, sid, line, length_scale = 1, noise_scale = 0.667, noise_scale_w = 0.8, auto_play=True):
227
+ audios = []
228
+ if config["phoneme_type"] == "PhonemeType.ESPEAK":
229
+ config["phoneme_type"] = "espeak"
230
+ text = phonemize(config, line)
231
+ for phonemes in text:
232
+ phoneme_ids = phonemes_to_ids(config, phonemes)
233
+ num_speakers = config["num_speakers"]
234
+ if num_speakers == 1:
235
+ speaker_id = None # for now
236
+ else:
237
+ speaker_id = sid
238
+ text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
239
+ text_lengths = np.array([text.shape[1]], dtype=np.int64)
240
+ scales = np.array(
241
+ [noise_scale, length_scale, noise_scale_w],
242
+ dtype=np.float32,
243
+ )
244
+ sid = None
245
+ if speaker_id is not None:
246
+ sid = np.array([speaker_id], dtype=np.int64)
247
+ audio = model.run(
248
+ None,
249
+ {
250
+ "input": text,
251
+ "input_lengths": text_lengths,
252
+ "scales": scales,
253
+ "sid": sid,
254
+ },
255
+ )[0].squeeze((0, 1))
256
+ audio = audio_float_to_int16(audio.squeeze())
257
+ audios.append(audio)
258
+ merged_audio = np.concatenate(audios)
259
+ sample_rate = config["audio"]["sample_rate"]
260
+ display(Markdown(f"{line}"))
261
+ display(Audio(merged_audio, rate=sample_rate, autoplay=auto_play))
262
+
263
+ def denoise(
264
+ audio: np.ndarray, bias_spec: np.ndarray, denoiser_strength: float
265
+ ) -> np.ndarray:
266
+ audio_spec, audio_angles = transform(audio)
267
+
268
+ a = bias_spec.shape[-1]
269
+ b = audio_spec.shape[-1]
270
+ repeats = max(1, math.ceil(b / a))
271
+ bias_spec_repeat = np.repeat(bias_spec, repeats, axis=-1)[..., :b]
272
+
273
+ audio_spec_denoised = audio_spec - (bias_spec_repeat * denoiser_strength)
274
+ audio_spec_denoised = np.clip(audio_spec_denoised, a_min=0.0, a_max=None)
275
+ audio_denoised = inverse(audio_spec_denoised, audio_angles)
276
+
277
+ return audio_denoised
278
+
279
+
280
+ def stft(x, fft_size, hopsamp):
281
+ """Compute and return the STFT of the supplied time domain signal x.
282
+ Args:
283
+ x (1-dim Numpy array): A time domain signal.
284
+ fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.
285
+ hopsamp (int):
286
+ Returns:
287
+ The STFT. The rows are the time slices and columns are the frequency bins.
288
+ """
289
+ window = np.hanning(fft_size)
290
+ fft_size = int(fft_size)
291
+ hopsamp = int(hopsamp)
292
+ return np.array(
293
+ [
294
+ np.fft.rfft(window * x[i : i + fft_size])
295
+ for i in range(0, len(x) - fft_size, hopsamp)
296
+ ]
297
+ )
298
+
299
+
300
+ def istft(X, fft_size, hopsamp):
301
+ """Invert a STFT into a time domain signal.
302
+ Args:
303
+ X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.
304
+ fft_size (int):
305
+ hopsamp (int): The hop size, in samples.
306
+ Returns:
307
+ The inverse STFT.
308
+ """
309
+ fft_size = int(fft_size)
310
+ hopsamp = int(hopsamp)
311
+ window = np.hanning(fft_size)
312
+ time_slices = X.shape[0]
313
+ len_samples = int(time_slices * hopsamp + fft_size)
314
+ x = np.zeros(len_samples)
315
+ for n, i in enumerate(range(0, len(x) - fft_size, hopsamp)):
316
+ x[i : i + fft_size] += window * np.real(np.fft.irfft(X[n]))
317
+ return x
318
+
319
+
320
+ def inverse(magnitude, phase):
321
+ recombine_magnitude_phase = np.concatenate(
322
+ [magnitude * np.cos(phase), magnitude * np.sin(phase)], axis=1
323
+ )
324
+
325
+ x_org = recombine_magnitude_phase
326
+ n_b, n_f, n_t = x_org.shape # pylint: disable=unpacking-non-sequence
327
+ x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)
328
+ x.real = x_org[:, : n_f // 2]
329
+ x.imag = x_org[:, n_f // 2 :]
330
+ inverse_transform = []
331
+ for y in x:
332
+ y_ = istft(y.T, fft_size=1024, hopsamp=256)
333
+ inverse_transform.append(y_[None, :])
334
+
335
+ inverse_transform = np.concatenate(inverse_transform, 0)
336
+
337
+ return inverse_transform
338
+
339
+
340
+ def transform(input_data):
341
+ x = input_data
342
+ real_part = []
343
+ imag_part = []
344
+ for y in x:
345
+ y_ = stft(y, fft_size=1024, hopsamp=256).T
346
+ real_part.append(y_.real[None, :, :]) # pylint: disable=unsubscriptable-object
347
+ imag_part.append(y_.imag[None, :, :]) # pylint: disable=unsubscriptable-object
348
+ real_part = np.concatenate(real_part, 0)
349
+ imag_part = np.concatenate(imag_part, 0)
350
+
351
+ magnitude = np.sqrt(real_part**2 + imag_part**2)
352
+ phase = np.arctan2(imag_part.data, real_part.data)
353
+
354
+ return magnitude, phase