Dionyssos committed
Commit 6a0b5fd · 1 Parent(s): a032fce

del fixed embedding voice diffusion

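This commit strips the classifier-free-guidance plumbing from the voice diffusion path: the fixed (masked) embedding, the embedding_scale blend, and the rand_bool masking helper are deleted, and StyleTransformer1d.forward() now calls run() with the conditioning embedding directly. Since the inference code only ever passed embedding_scale=1, the deleted blend was a no-op; a minimal sketch with stand-in scalars (not repo code) shows why:

    # stand-in floats for the two denoiser outputs; the real code used tensors
    out, out_masked = 2.0, 5.0
    embedding_scale = 1.0
    # the deleted guidance blend collapses to the conditional output at scale 1.0
    assert out_masked + (out - out_masked) * embedding_scale == out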
Modules/diffusion/modules.py CHANGED
@@ -146,6 +146,7 @@ class StyleTransformer1d(nn.Module):
         return mapping
 
     def run(self, x, time, embedding, features):
+        # called by forward()
 
         mapping = self.get_mapping(time, features)
         x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
@@ -161,37 +162,22 @@ class StyleTransformer1d(nn.Module):
 
         return x
 
-    def forward(self, x: Tensor,
-                time: Tensor,
-                embedding_mask_proba: float = 0.0,
+    def forward(self,
+                x,
+                time,
                 embedding= None,
-                features = None,
-                embedding_scale: float = 1.0) -> Tensor:
-
-        b, device = embedding.shape[0], embedding.device
-        fixed_embedding = self.fixed_embedding(embedding)
-        if embedding_mask_proba > 0.0:
-            # Randomly mask embedding
-            batch_mask = rand_bool(
-                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
-            )
-            embedding = torch.where(batch_mask, fixed_embedding, embedding)
+                features = None):
 
-        if embedding_scale != 1.0:
-
-
-            out = self.run(x, time, embedding=embedding, features=features)
-            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
-
-            raise ValueError
-            return out_masked + (out - out_masked) * embedding_scale
-
-        else:
-            # raise ValueError
-            return self.run(x, time, embedding=embedding, features=features)
-
-
-        return x
+        b, device = embedding.shape[0], embedding.device
+        # if
+        # embedding_mask_proba: float = 0.0, > 0
+        # fixed_embedding = self.fixed_embedding(embedding)
+        # embedding = torch.where(batch_mask, fixed_embedding, embedding)
+        return self.run(x,
+                        time,
+                        embedding=embedding,
+                        # embedding=self.fixed_embedding(embedding), # fixedemb has noisy beginnings on chapters.wav
+                        features=features)
 
 
 class StyleTransformerBlock(nn.Module):
@@ -216,24 +202,11 @@ class StyleTransformerBlock(nn.Module):
             features=features,
             style_dim=style_dim,
             num_heads=num_heads,
-            head_features=head_features,
-            use_rel_pos=use_rel_pos,
-            # rel_pos_num_buckets=rel_pos_num_buckets,
-            # rel_pos_max_distance=rel_pos_max_distance,
+            head_features=head_features
         )
 
         if self.use_cross_attention:
             raise ValueError
-        # self.cross_attention = StyleAttention(
-        #     features=features,
-        #     style_dim=style_dim,
-        #     num_heads=num_heads,
-        #     head_features=head_features,
-        #     context_features=context_features,
-        #     use_rel_pos=use_rel_pos,
-        #     rel_pos_num_buckets=rel_pos_num_buckets,
-        #     rel_pos_max_distance=rel_pos_max_distance,
-        # )
 
         self.feed_forward = FeedForward(features=features, multiplier=multiplier)
 
@@ -254,7 +227,7 @@ class StyleAttention(nn.Module):
         head_features: int,
         num_heads: int,
         context_features = None,
-        use_rel_pos: bool,
+        # use_rel_pos: bool,
         # rel_pos_num_buckets: Optional[int] = None,
         # rel_pos_max_distance: Optional[int] = None,
     ):
@@ -274,23 +247,20 @@ class StyleAttention(nn.Module):
         self.attention = AttentionBase(
             features,
             num_heads=num_heads,
-            head_features=head_features,
-            use_rel_pos=use_rel_pos,
-            # rel_pos_num_buckets=rel_pos_num_buckets,
-            # rel_pos_max_distance=rel_pos_max_distance,
+            head_features=head_features
         )
 
-    def forward(self, x: Tensor, s: Tensor, *, context = None):
+    def forward(self, x, s, *, context = None):
 
-        # raise ValueError
-        # Use context if provided
+        if context is not None:
+            raise ValueError
         context = default(context, x)
-        # print(context.shape,'ppppppppppppppppppppppppppppppppppppppppppp') # bs, time, 1024
-        # Normalize then compute q from input and k,v from context
+
+
         x, context = self.norm(x, s), self.norm_context(context, s)
 
         q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
-        # Compute and return attention
+
         return self.attention(q, k, v)
 
 
@@ -310,25 +280,13 @@ class AttentionBase(nn.Module):
         features,
         *,
         head_features,
-        num_heads,
-        use_rel_pos,
-        out_features = None,
-        # rel_pos_num_buckets: Optional[int] = None,
-        # rel_pos_max_distance: Optional[int] = None,
-    ):
+        num_heads):
         super().__init__()
         self.scale = head_features ** -0.5
         self.num_heads = num_heads
-        self.use_rel_pos = use_rel_pos
-        mid_features = head_features * num_heads
-
-        if use_rel_pos:
-            raise ValueError
-
-        if out_features is None:
-            out_features = features
-
-        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
+        mid_features = head_features * num_heads
+        self.to_out = nn.Linear(in_features=mid_features,
+                                out_features=features)
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         # Split heads
@@ -358,7 +316,7 @@ class Attention(nn.Module):
         num_heads,
         out_features=None,
         context_features=None,
-        use_rel_pos,
+        # use_rel_pos,
         # rel_pos_num_buckets: Optional[int] = None,
         # rel_pos_max_distance: Optional[int] = None,
     ):
@@ -381,7 +339,7 @@ class Attention(nn.Module):
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
-           use_rel_pos=use_rel_pos,
+           # use_rel_pos=use_rel_pos,
            # rel_pos_num_buckets=rel_pos_num_buckets,
            # rel_pos_max_distance=rel_pos_max_distance,
            )
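For reference, the forward() removed above ran the denoiser twice — once with the conditioning embedding, once with the fixed embedding — and blended the two outputs by embedding_scale, with an optional training-time dropout of the condition. A sketch reassembled from the deleted lines (the method name and the omission of the stray raise ValueError debug statements are mine):

    def forward_with_guidance(self, x, time, embedding, features,
                              embedding_mask_proba=0.0, embedding_scale=1.0):
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # randomly swap the condition for the fixed embedding (CFG-style dropout)
            batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)
        if embedding_scale != 1.0:
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            return out_masked + (out - out_masked) * embedding_scale  # guidance blend
        return self.run(x, time, embedding=embedding, features=features)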
Modules/diffusion/sampler.py CHANGED
@@ -1,61 +1,18 @@
-from math import atan, cos, pi, sin, sqrt
+from math import sqrt
 import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
-
 from functools import reduce
-from inspect import isfunction
-from math import ceil, floor, log2, pi
-# from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+# from inspect import isfunction
+# from math import ceil, floor, log2, pi
 import torch
 import torch.nn.functional as F
-from einops import rearrange
-from torch import Generator, Tensor
-
-
-
-
-
-
-# def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
-#     return isinstance(obj, list) or isinstance(obj, tuple)
-
 
 def default(val, d):
     if val is not None: #exists(val):
         return val
     return d #d() if isfunction(d) else d
 
-
-# def to_list(val: Union[T, Sequence[T]]) -> List[T]:
-#     if isinstance(val, tuple):
-#         return list(val)
-#     if isinstance(val, list):
-#         return val
-#     return [val]  # type: ignore
-
-
-# def prod(vals: Sequence[int]) -> int:
-#     return reduce(lambda x, y: x * y, vals)
-
-
-def closest_power_2(x: float) -> int:
-    exponent = log2(x)
-    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
-    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
-    return 2 ** int(exponent_closest)
-
-def rand_bool(shape, proba, device = None):
-    if proba == 1:
-        return torch.ones(shape, device=device, dtype=torch.bool)
-    elif proba == 0:
-        return torch.zeros(shape, device=device, dtype=torch.bool)
-    else:
-        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
-
-# ============================= END functions from diffusION.utils
-
-
 class LogNormalDistribution():
     def __init__(self, mean: float, std: float):
         self.mean = mean
@@ -238,7 +195,8 @@ class DiffusionSampler(nn.Module):
 
         # Compute sigmas using schedule
         sigmas = self.sigma_schedule(num_steps, device)
-        # Append additional kwargs to denoise function (used e.g. for conditional unet)
+
+        # L242 KWARGS dict_keys(['embedding', 'features'])
         fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
         # Sample using sampler
         x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
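The deleted rand_bool helper was only consumed by the old guidance path above; it drew a per-batch-entry Bernoulli mask deciding which samples had their embedding replaced by the fixed one. Copied from the removed lines for context:

    def rand_bool(shape, proba, device=None):
        # True with probability `proba`; the 0/1 cases skip the Bernoulli draw
        if proba == 1:
            return torch.ones(shape, device=device, dtype=torch.bool)
        elif proba == 0:
            return torch.zeros(shape, device=device, dtype=torch.bool)
        else:
            return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)

The kwargs forwarded through the fn lambda in DiffusionSampler are the conditioning tensors noted in the new comment: embedding and features.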
api.py CHANGED
@@ -171,8 +171,8 @@ def tts_multi_sentence(precomputed_style_vector=None,
                               precomputed_style_vector,
                               alpha=0.3,
                               beta=0.7,
-                              diffusion_steps=diffusion_steps,
-                              embedding_scale=1))
+                              diffusion_steps=diffusion_steps)
+                  )
     x = np.concatenate(x)
 
     # Fallback - MMS TTS - Non-English
@@ -530,7 +530,7 @@ def serve_wav():
 
     # audios = [msinference.inference(text,
     #                                 msinference.compute_style(f'voices/{voice}.wav'),
-    #                                 alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1)]
+    #                                 alpha=0.3, beta=0.7, diffusion_steps=7)]
     # # for t in [text]:
     # output_buffer = io.BytesIO()
     # write(output_buffer, 24000, np.concatenate(audios))
models.py CHANGED
@@ -69,6 +69,7 @@ class AudioDiffusionConditional(nn.Module):
 
     def forward(self, *args, **kwargs):
         default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+        # here embedding_scale = 1.0 is passed to DiffusionSampler() - del no-op if scale = 1.0
         return self.diffusion(*args, **{**default_kwargs, **kwargs})
 
     # def sample(self, *args, **kwargs):
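The new comment leans on how {**default_kwargs, **kwargs} merges: caller-supplied keys override the module defaults, so before this commit the embedding_scale=1.0 forwarded this way reached DiffusionSampler and, at 1.0, did nothing — hence the "no-op" note. A minimal sketch of that plain-Python behaviour, with invented values:

    default_kwargs = dict(embedding_mask_proba=0.1)
    kwargs = dict(embedding_mask_proba=0.0, num_steps=7)  # values invented for the example
    merged = {**default_kwargs, **kwargs}
    print(merged)  # {'embedding_mask_proba': 0.0, 'num_steps': 7} -- caller wins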
msinference.py CHANGED
@@ -174,7 +174,6 @@ def inference(text,
               alpha = 0.3,
               beta = 0.7,
               diffusion_steps=7, # 7 if voice is native English else 5 for non-native
-              embedding_scale=1,
               use_gruut=False):
     text = text.strip()
     ps = global_phonemizer.phonemize([text])
@@ -213,7 +212,6 @@
 
         s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                          embedding=bert_dur,
-                         embedding_scale=embedding_scale,
                          features=ref_s, # reference from the same speaker as the embedding
                          num_steps=diffusion_steps).squeeze(1)
 
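With embedding_scale gone from the inference() signature, callers (see the api.py hunk above) pass only the remaining arguments. A minimal usage sketch; the voice file path is a placeholder:

    import msinference

    style = msinference.compute_style('voices/en_US.wav')  # placeholder reference wav
    wav = msinference.inference('Hello there.',
                                style,
                                alpha=0.3,
                                beta=0.7,
                                diffusion_steps=7)  # embedding_scale is no longer accepted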