del fixed embedding voice diffussion
Browse files- Modules/diffusion/modules.py +30 -72
- Modules/diffusion/sampler.py +5 -47
- api.py +3 -3
- models.py +1 -0
- msinference.py +0 -2
Modules/diffusion/modules.py
CHANGED
@@ -146,6 +146,7 @@ class StyleTransformer1d(nn.Module):
|
|
146 |
return mapping
|
147 |
|
148 |
def run(self, x, time, embedding, features):
|
|
|
149 |
|
150 |
mapping = self.get_mapping(time, features)
|
151 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
@@ -161,37 +162,22 @@ class StyleTransformer1d(nn.Module):
|
|
161 |
|
162 |
return x
|
163 |
|
164 |
-
def forward(self,
|
165 |
-
|
166 |
-
|
167 |
embedding= None,
|
168 |
-
features = None
|
169 |
-
embedding_scale: float = 1.0) -> Tensor:
|
170 |
-
|
171 |
-
b, device = embedding.shape[0], embedding.device
|
172 |
-
fixed_embedding = self.fixed_embedding(embedding)
|
173 |
-
if embedding_mask_proba > 0.0:
|
174 |
-
# Randomly mask embedding
|
175 |
-
batch_mask = rand_bool(
|
176 |
-
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
177 |
-
)
|
178 |
-
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
179 |
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
# raise ValueError
|
191 |
-
return self.run(x, time, embedding=embedding, features=features)
|
192 |
-
|
193 |
-
|
194 |
-
return x
|
195 |
|
196 |
|
197 |
class StyleTransformerBlock(nn.Module):
|
@@ -216,24 +202,11 @@ class StyleTransformerBlock(nn.Module):
|
|
216 |
features=features,
|
217 |
style_dim=style_dim,
|
218 |
num_heads=num_heads,
|
219 |
-
head_features=head_features
|
220 |
-
use_rel_pos=use_rel_pos,
|
221 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
222 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
223 |
)
|
224 |
|
225 |
if self.use_cross_attention:
|
226 |
raise ValueError
|
227 |
-
# self.cross_attention = StyleAttention(
|
228 |
-
# features=features,
|
229 |
-
# style_dim=style_dim,
|
230 |
-
# num_heads=num_heads,
|
231 |
-
# head_features=head_features,
|
232 |
-
# context_features=context_features,
|
233 |
-
# use_rel_pos=use_rel_pos,
|
234 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
235 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
236 |
-
# )
|
237 |
|
238 |
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
239 |
|
@@ -254,7 +227,7 @@ class StyleAttention(nn.Module):
|
|
254 |
head_features: int,
|
255 |
num_heads: int,
|
256 |
context_features = None,
|
257 |
-
use_rel_pos: bool,
|
258 |
# rel_pos_num_buckets: Optional[int] = None,
|
259 |
# rel_pos_max_distance: Optional[int] = None,
|
260 |
):
|
@@ -274,23 +247,20 @@ class StyleAttention(nn.Module):
|
|
274 |
self.attention = AttentionBase(
|
275 |
features,
|
276 |
num_heads=num_heads,
|
277 |
-
head_features=head_features
|
278 |
-
use_rel_pos=use_rel_pos,
|
279 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
280 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
281 |
)
|
282 |
|
283 |
-
def forward(self, x
|
284 |
|
285 |
-
|
286 |
-
|
287 |
context = default(context, x)
|
288 |
-
|
289 |
-
|
290 |
x, context = self.norm(x, s), self.norm_context(context, s)
|
291 |
|
292 |
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
293 |
-
|
294 |
return self.attention(q, k, v)
|
295 |
|
296 |
|
@@ -310,25 +280,13 @@ class AttentionBase(nn.Module):
|
|
310 |
features,
|
311 |
*,
|
312 |
head_features,
|
313 |
-
num_heads
|
314 |
-
use_rel_pos,
|
315 |
-
out_features = None,
|
316 |
-
# rel_pos_num_buckets: Optional[int] = None,
|
317 |
-
# rel_pos_max_distance: Optional[int] = None,
|
318 |
-
):
|
319 |
super().__init__()
|
320 |
self.scale = head_features ** -0.5
|
321 |
self.num_heads = num_heads
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
if use_rel_pos:
|
326 |
-
raise ValueError
|
327 |
-
|
328 |
-
if out_features is None:
|
329 |
-
out_features = features
|
330 |
-
|
331 |
-
self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
|
332 |
|
333 |
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
334 |
# Split heads
|
@@ -358,7 +316,7 @@ class Attention(nn.Module):
|
|
358 |
num_heads,
|
359 |
out_features=None,
|
360 |
context_features=None,
|
361 |
-
use_rel_pos,
|
362 |
# rel_pos_num_buckets: Optional[int] = None,
|
363 |
# rel_pos_max_distance: Optional[int] = None,
|
364 |
):
|
@@ -381,7 +339,7 @@ class Attention(nn.Module):
|
|
381 |
out_features=out_features,
|
382 |
num_heads=num_heads,
|
383 |
head_features=head_features,
|
384 |
-
use_rel_pos=use_rel_pos,
|
385 |
# rel_pos_num_buckets=rel_pos_num_buckets,
|
386 |
# rel_pos_max_distance=rel_pos_max_distance,
|
387 |
)
|
|
|
146 |
return mapping
|
147 |
|
148 |
def run(self, x, time, embedding, features):
|
149 |
+
# called by forward()
|
150 |
|
151 |
mapping = self.get_mapping(time, features)
|
152 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
|
|
162 |
|
163 |
return x
|
164 |
|
165 |
+
def forward(self,
|
166 |
+
x,
|
167 |
+
time,
|
168 |
embedding= None,
|
169 |
+
features = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
+
b, device = embedding.shape[0], embedding.device
|
172 |
+
# if
|
173 |
+
# embedding_mask_proba: float = 0.0, > 0
|
174 |
+
# fixed_embedding = self.fixed_embedding(embedding)
|
175 |
+
# embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
176 |
+
return self.run(x,
|
177 |
+
time,
|
178 |
+
embedding=embedding,
|
179 |
+
# embedding=self.fixed_embedding(embedding), # fixedemb has noisy beginnings on chapters.wav
|
180 |
+
features=features)
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
|
183 |
class StyleTransformerBlock(nn.Module):
|
|
|
202 |
features=features,
|
203 |
style_dim=style_dim,
|
204 |
num_heads=num_heads,
|
205 |
+
head_features=head_features
|
|
|
|
|
|
|
206 |
)
|
207 |
|
208 |
if self.use_cross_attention:
|
209 |
raise ValueError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
212 |
|
|
|
227 |
head_features: int,
|
228 |
num_heads: int,
|
229 |
context_features = None,
|
230 |
+
# use_rel_pos: bool,
|
231 |
# rel_pos_num_buckets: Optional[int] = None,
|
232 |
# rel_pos_max_distance: Optional[int] = None,
|
233 |
):
|
|
|
247 |
self.attention = AttentionBase(
|
248 |
features,
|
249 |
num_heads=num_heads,
|
250 |
+
head_features=head_features
|
|
|
|
|
|
|
251 |
)
|
252 |
|
253 |
+
def forward(self, x, s, *, context = None):
|
254 |
|
255 |
+
if context is not None:
|
256 |
+
raise ValueError
|
257 |
context = default(context, x)
|
258 |
+
|
259 |
+
|
260 |
x, context = self.norm(x, s), self.norm_context(context, s)
|
261 |
|
262 |
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
263 |
+
|
264 |
return self.attention(q, k, v)
|
265 |
|
266 |
|
|
|
280 |
features,
|
281 |
*,
|
282 |
head_features,
|
283 |
+
num_heads):
|
|
|
|
|
|
|
|
|
|
|
284 |
super().__init__()
|
285 |
self.scale = head_features ** -0.5
|
286 |
self.num_heads = num_heads
|
287 |
+
mid_features = head_features * num_heads
|
288 |
+
self.to_out = nn.Linear(in_features=mid_features,
|
289 |
+
out_features=features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
|
291 |
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
292 |
# Split heads
|
|
|
316 |
num_heads,
|
317 |
out_features=None,
|
318 |
context_features=None,
|
319 |
+
# use_rel_pos,
|
320 |
# rel_pos_num_buckets: Optional[int] = None,
|
321 |
# rel_pos_max_distance: Optional[int] = None,
|
322 |
):
|
|
|
339 |
out_features=out_features,
|
340 |
num_heads=num_heads,
|
341 |
head_features=head_features,
|
342 |
+
# use_rel_pos=use_rel_pos,
|
343 |
# rel_pos_num_buckets=rel_pos_num_buckets,
|
344 |
# rel_pos_max_distance=rel_pos_max_distance,
|
345 |
)
|
Modules/diffusion/sampler.py
CHANGED
@@ -1,61 +1,18 @@
|
|
1 |
-
from math import
|
2 |
import torch.nn as nn
|
3 |
from einops import rearrange
|
4 |
from torch import Tensor
|
5 |
-
|
6 |
from functools import reduce
|
7 |
-
from inspect import isfunction
|
8 |
-
from math import ceil, floor, log2, pi
|
9 |
-
# from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
-
from einops import rearrange
|
13 |
-
from torch import Generator, Tensor
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
# def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
|
21 |
-
# return isinstance(obj, list) or isinstance(obj, tuple)
|
22 |
-
|
23 |
|
24 |
def default(val, d):
|
25 |
if val is not None: #exists(val):
|
26 |
return val
|
27 |
return d #d() if isfunction(d) else d
|
28 |
|
29 |
-
|
30 |
-
# def to_list(val: Union[T, Sequence[T]]) -> List[T]:
|
31 |
-
# if isinstance(val, tuple):
|
32 |
-
# return list(val)
|
33 |
-
# if isinstance(val, list):
|
34 |
-
# return val
|
35 |
-
# return [val] # type: ignore
|
36 |
-
|
37 |
-
|
38 |
-
# def prod(vals: Sequence[int]) -> int:
|
39 |
-
# return reduce(lambda x, y: x * y, vals)
|
40 |
-
|
41 |
-
|
42 |
-
def closest_power_2(x: float) -> int:
|
43 |
-
exponent = log2(x)
|
44 |
-
distance_fn = lambda z: abs(x - 2 ** z) # noqa
|
45 |
-
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
|
46 |
-
return 2 ** int(exponent_closest)
|
47 |
-
|
48 |
-
def rand_bool(shape, proba, device = None):
|
49 |
-
if proba == 1:
|
50 |
-
return torch.ones(shape, device=device, dtype=torch.bool)
|
51 |
-
elif proba == 0:
|
52 |
-
return torch.zeros(shape, device=device, dtype=torch.bool)
|
53 |
-
else:
|
54 |
-
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
|
55 |
-
|
56 |
-
# ============================= END functions from diffusION.utils
|
57 |
-
|
58 |
-
|
59 |
class LogNormalDistribution():
|
60 |
def __init__(self, mean: float, std: float):
|
61 |
self.mean = mean
|
@@ -238,7 +195,8 @@ class DiffusionSampler(nn.Module):
|
|
238 |
|
239 |
# Compute sigmas using schedule
|
240 |
sigmas = self.sigma_schedule(num_steps, device)
|
241 |
-
|
|
|
242 |
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
243 |
# Sample using sampler
|
244 |
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
|
|
1 |
+
from math import sqrt
|
2 |
import torch.nn as nn
|
3 |
from einops import rearrange
|
4 |
from torch import Tensor
|
|
|
5 |
from functools import reduce
|
6 |
+
# from inspect import isfunction
|
7 |
+
# from math import ceil, floor, log2, pi
|
|
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def default(val, d):
|
12 |
if val is not None: #exists(val):
|
13 |
return val
|
14 |
return d #d() if isfunction(d) else d
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
class LogNormalDistribution():
|
17 |
def __init__(self, mean: float, std: float):
|
18 |
self.mean = mean
|
|
|
195 |
|
196 |
# Compute sigmas using schedule
|
197 |
sigmas = self.sigma_schedule(num_steps, device)
|
198 |
+
|
199 |
+
# L242 KWARGS dict_keys(['embedding', 'features'])
|
200 |
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
201 |
# Sample using sampler
|
202 |
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
api.py
CHANGED
@@ -171,8 +171,8 @@ def tts_multi_sentence(precomputed_style_vector=None,
|
|
171 |
precomputed_style_vector,
|
172 |
alpha=0.3,
|
173 |
beta=0.7,
|
174 |
-
diffusion_steps=diffusion_steps
|
175 |
-
|
176 |
x = np.concatenate(x)
|
177 |
|
178 |
# Fallback - MMS TTS - Non-English
|
@@ -530,7 +530,7 @@ def serve_wav():
|
|
530 |
|
531 |
# audios = [msinference.inference(text,
|
532 |
# msinference.compute_style(f'voices/{voice}.wav'),
|
533 |
-
# alpha=0.3, beta=0.7, diffusion_steps=7
|
534 |
# # for t in [text]:
|
535 |
# output_buffer = io.BytesIO()
|
536 |
# write(output_buffer, 24000, np.concatenate(audios))
|
|
|
171 |
precomputed_style_vector,
|
172 |
alpha=0.3,
|
173 |
beta=0.7,
|
174 |
+
diffusion_steps=diffusion_steps)
|
175 |
+
)
|
176 |
x = np.concatenate(x)
|
177 |
|
178 |
# Fallback - MMS TTS - Non-English
|
|
|
530 |
|
531 |
# audios = [msinference.inference(text,
|
532 |
# msinference.compute_style(f'voices/{voice}.wav'),
|
533 |
+
# alpha=0.3, beta=0.7, diffusion_steps=7)]
|
534 |
# # for t in [text]:
|
535 |
# output_buffer = io.BytesIO()
|
536 |
# write(output_buffer, 24000, np.concatenate(audios))
|
models.py
CHANGED
@@ -69,6 +69,7 @@ class AudioDiffusionConditional(nn.Module):
|
|
69 |
|
70 |
def forward(self, *args, **kwargs):
|
71 |
default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
|
|
|
72 |
return self.diffusion(*args, **{**default_kwargs, **kwargs})
|
73 |
|
74 |
# def sample(self, *args, **kwargs):
|
|
|
69 |
|
70 |
def forward(self, *args, **kwargs):
|
71 |
default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
|
72 |
+
# here embedding_scale = 1.0 is passed to DiffusionSampler() - del no-op if scale = 1.0
|
73 |
return self.diffusion(*args, **{**default_kwargs, **kwargs})
|
74 |
|
75 |
# def sample(self, *args, **kwargs):
|
msinference.py
CHANGED
@@ -174,7 +174,6 @@ def inference(text,
|
|
174 |
alpha = 0.3,
|
175 |
beta = 0.7,
|
176 |
diffusion_steps=7, # 7 if voice is native English else 5 for non-native
|
177 |
-
embedding_scale=1,
|
178 |
use_gruut=False):
|
179 |
text = text.strip()
|
180 |
ps = global_phonemizer.phonemize([text])
|
@@ -213,7 +212,6 @@ def inference(text,
|
|
213 |
|
214 |
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
215 |
embedding=bert_dur,
|
216 |
-
embedding_scale=embedding_scale,
|
217 |
features=ref_s, # reference from the same speaker as the embedding
|
218 |
num_steps=diffusion_steps).squeeze(1)
|
219 |
|
|
|
174 |
alpha = 0.3,
|
175 |
beta = 0.7,
|
176 |
diffusion_steps=7, # 7 if voice is native English else 5 for non-native
|
|
|
177 |
use_gruut=False):
|
178 |
text = text.strip()
|
179 |
ps = global_phonemizer.phonemize([text])
|
|
|
212 |
|
213 |
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
214 |
embedding=bert_dur,
|
|
|
215 |
features=ref_s, # reference from the same speaker as the embedding
|
216 |
num_steps=diffusion_steps).squeeze(1)
|
217 |
|