Dionyssos committed on
Commit 70184a3 · 1 Parent(s): 7d11d0a
Files changed (3)
  1. Modules/diffusion/modules.py +2 -277
  2. README.md +4 -4
  3. models.py +2 -4
Modules/diffusion/modules.py CHANGED
@@ -279,207 +279,6 @@ class StyleAttention(nn.Module):
         q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
         # Compute and return attention
         return self.attention(q, k, v)
-
-
-class Transformer1d(nn.Module):
-    def __init__(
-        self,
-        num_layers: int,
-        channels: int,
-        num_heads: int,
-        head_features: int,
-        multiplier: int,
-        use_context_time: bool = True,
-        use_rel_pos: bool = False,
-        context_features_multiplier: int = 1,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features: Optional[int] = None,
-        context_embedding_features: Optional[int] = None,
-        embedding_max_length: int = 512,
-    ):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    features=channels + context_embedding_features,
-                    head_features=head_features,
-                    num_heads=num_heads,
-                    multiplier=multiplier,
-                    use_rel_pos=use_rel_pos,
-                    rel_pos_num_buckets=rel_pos_num_buckets,
-                    rel_pos_max_distance=rel_pos_max_distance,
-                )
-                for i in range(num_layers)
-            ]
-        )
-
-        self.to_out = nn.Sequential(
-            Rearrange("b t c -> b c t"),
-            nn.Conv1d(
-                in_channels=channels + context_embedding_features,
-                out_channels=channels,
-                kernel_size=1,
-            ),
-        )
-
-        use_context_features = exists(context_features)
-        self.use_context_features = use_context_features
-        self.use_context_time = use_context_time
-
-        if use_context_time or use_context_features:
-            context_mapping_features = channels + context_embedding_features
-
-            self.to_mapping = nn.Sequential(
-                nn.Linear(context_mapping_features, context_mapping_features),
-                nn.GELU(),
-                nn.Linear(context_mapping_features, context_mapping_features),
-                nn.GELU(),
-            )
-
-        if use_context_time:
-            assert exists(context_mapping_features)
-            self.to_time = nn.Sequential(
-                TimePositionalEmbedding(
-                    dim=channels, out_features=context_mapping_features
-                ),
-                nn.GELU(),
-            )
-
-        if use_context_features:
-            assert exists(context_features) and exists(context_mapping_features)
-            self.to_features = nn.Sequential(
-                nn.Linear(
-                    in_features=context_features, out_features=context_mapping_features
-                ),
-                nn.GELU(),
-            )
-
-        self.fixed_embedding = FixedEmbedding(
-            max_length=embedding_max_length, features=context_embedding_features
-        )
-
-
-    def get_mapping(
-        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
-    ) -> Optional[Tensor]:
-        """Combines context time features and features into mapping"""
-        items, mapping = [], None
-        # Compute time features
-        if self.use_context_time:
-            assert_message = "use_context_time=True but no time features provided"
-            assert exists(time), assert_message
-            items += [self.to_time(time)]
-        # Compute features
-        if self.use_context_features:
-            assert_message = "context_features exists but no features provided"
-            assert exists(features), assert_message
-            items += [self.to_features(features)]
-
-        # Compute joint mapping
-        if self.use_context_time or self.use_context_features:
-            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
-            mapping = self.to_mapping(mapping)
-
-        return mapping
-
-    def run(self, x, time, embedding, features):
-
-        mapping = self.get_mapping(time, features)
-        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
-        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
-
-        for block in self.blocks:
-            x = x + mapping
-            x = block(x)
-
-        x = x.mean(axis=1).unsqueeze(1)
-        x = self.to_out(x)
-        x = x.transpose(-1, -2)
-
-        return x
-
-    def forward(self, x: Tensor,
-                time: Tensor,
-                embedding_mask_proba: float = 0.0,
-                embedding: Optional[Tensor] = None,
-                features: Optional[Tensor] = None,
-                embedding_scale: float = 1.0) -> Tensor:
-
-        b, device = embedding.shape[0], embedding.device
-        fixed_embedding = self.fixed_embedding(embedding)
-        if embedding_mask_proba > 0.0:
-            # Randomly mask embedding
-            batch_mask = rand_bool(
-                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
-            )
-            embedding = torch.where(batch_mask, fixed_embedding, embedding)
-
-        if embedding_scale != 1.0:
-            # Compute both normal and fixed embedding outputs
-            out = self.run(x, time, embedding=embedding, features=features)
-            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
-            # Scale conditional output using classifier-free guidance
-            return out_masked + (out - out_masked) * embedding_scale
-        else:
-            return self.run(x, time, embedding=embedding, features=features)
-
-        return x
-
-
-"""
-Attention Components
-"""
-
-
-class RelativePositionBias(nn.Module):
-    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
-        super().__init__()
-        self.num_buckets = num_buckets
-        self.max_distance = max_distance
-        self.num_heads = num_heads
-        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
-
-    @staticmethod
-    def _relative_position_bucket(
-        relative_position: Tensor, num_buckets: int, max_distance: int
-    ):
-        num_buckets //= 2
-        ret = (relative_position >= 0).to(torch.long) * num_buckets
-        n = torch.abs(relative_position)
-
-        max_exact = num_buckets // 2
-        is_small = n < max_exact
-
-        val_if_large = (
-            max_exact
-            + (
-                torch.log(n.float() / max_exact)
-                / log(max_distance / max_exact)
-                * (num_buckets - max_exact)
-            ).long()
-        )
-        val_if_large = torch.min(
-            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
-        )
-
-        ret += torch.where(is_small, n, val_if_large)
-        return ret
-
-    def forward(self, num_queries: int, num_keys: int) -> Tensor:
-        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
-        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
-        k_pos = torch.arange(j, dtype=torch.long, device=device)
-        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
-
-        relative_position_bucket = self._relative_position_bucket(
-            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
-        )
-
-        bias = self.relative_attention_bias(relative_position_bucket)
-        bias = rearrange(bias, "m n h -> 1 h m n")
-        return bias
-
 
 def FeedForward(features: int, multiplier: int) -> nn.Module:
     mid_features = features * multiplier
@@ -509,12 +308,8 @@ class AttentionBase(nn.Module):
         mid_features = head_features * num_heads
 
         if use_rel_pos:
-            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
-            self.rel_pos = RelativePositionBias(
-                num_buckets=rel_pos_num_buckets,
-                max_distance=rel_pos_max_distance,
-                num_heads=num_heads,
-            )
+            raise ValueError
+
         if out_features is None:
             out_features = features
 
@@ -584,76 +379,6 @@ class Attention(nn.Module):
         return self.attention(q, k, v)
 
 
-"""
-Transformer Blocks
-"""
-
-
-class TransformerBlock(nn.Module):
-    def __init__(
-        self,
-        features: int,
-        num_heads: int,
-        head_features: int,
-        multiplier: int,
-        use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.use_cross_attention = exists(context_features) and context_features > 0
-
-        self.attention = Attention(
-            features=features,
-            num_heads=num_heads,
-            head_features=head_features,
-            use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
-        )
-
-        if self.use_cross_attention:
-            self.cross_attention = Attention(
-                features=features,
-                num_heads=num_heads,
-                head_features=head_features,
-                context_features=context_features,
-                use_rel_pos=use_rel_pos,
-                rel_pos_num_buckets=rel_pos_num_buckets,
-                rel_pos_max_distance=rel_pos_max_distance,
-            )
-
-        self.feed_forward = FeedForward(features=features, multiplier=multiplier)
-
-    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
-        x = self.attention(x) + x
-        if self.use_cross_attention:
-            x = self.cross_attention(x, context=context) + x
-        x = self.feed_forward(x) + x
-        return x
-
-
-
-"""
-Time Embeddings
-"""
-
-
-class SinusoidalEmbedding(nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x: Tensor) -> Tensor:
-        device, half_dim = x.device, self.dim // 2
-        emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
-        return torch.cat((emb.sin(), emb.cos()), dim=-1)
-
-
 class LearnedPositionalEmbedding(nn.Module):
     """Used for continuous time"""
 
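The removed forward() above blends a conditional pass with a fixed-embedding (unconditional) pass via classifier-free guidance; the StyleTransformer1d path kept by this commit applies the same scaling. A minimal, self-contained sketch of just that blending step, with dummy tensors standing in for the two network outputs:

import torch

# Classifier-free-guidance blend from the removed forward():
#   out        - prediction conditioned on the real (BERT) embedding
#   out_masked - prediction conditioned on the learned fixed embedding
# embedding_scale == 1.0 reproduces the conditional output; larger values
# extrapolate further away from the unconditional one.
def guidance_blend(out: torch.Tensor, out_masked: torch.Tensor, embedding_scale: float) -> torch.Tensor:
    return out_masked + (out - out_masked) * embedding_scale

out = torch.randn(2, 1, 512)         # dummy conditional output
out_masked = torch.randn(2, 1, 512)  # dummy unconditional output
guided = guidance_blend(out, out_masked, embedding_scale=5.0)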
 
README.md CHANGED
@@ -26,7 +26,7 @@ Beta version of [SHIFT](https://shift-europe.eu/) TTS tool with [AudioGen sounds
 - [Analysis of emotion of SHIFT TTS](https://huggingface.co/dkounadis/artificial-styletts2/discussions/2)
 - [Listen Also foreign languages](https://huggingface.co/dkounadis/artificial-styletts2/discussions/4) synthesized via [MMS TTS](https://huggingface.co/facebook/mms-tts)
 
-## Listen to Available Voices!
+## Listen Voices
 
 
 <a href="https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#67854dcbd3e6beb1a78f7f20">Native English</a> / <a href="https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6">Non-native English: Accents</a> / <a href="https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv">Foreign languages</a>
@@ -39,10 +39,10 @@ Beta version of [SHIFT](https://shift-europe.eu/) TTS tool with [AudioGen sounds
 
 <details>
 <summary>
-
-Build virtualenv / run `api.py`
-
+Build virtualenv & run api.py
 </summary>
+Besides `demo.py`, which runs as a standalone script, we also provide `api.py`, which enables long-form TTS with soundscape
+without needing to load the TTS/AudioGen models again and again. The examples below use `api.py`.
 
 Clone
 
models.py CHANGED
@@ -16,7 +16,7 @@ from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
 
 from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
-from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
+from Modules.diffusion.modules import StyleTransformer1d
 from Modules.diffusion.diffusion import AudioDiffusionConditional
 
 
@@ -551,9 +551,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
                                     context_features=args.style_dim*2,
                                     **args.diffusion.transformer)
     else:
-        transformer = Transformer1d(channels=args.style_dim*2,
-                                    context_embedding_features=bert.config.hidden_size,
-                                    **args.diffusion.transformer)
+        raise NotImplementedError
 
     diffusion = AudioDiffusionConditional(
         in_channels=1,
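A hedged caller-side sketch, not part of the diff: `build_model` and its arguments are the ones shown above, everything else is illustrative only. Configurations that previously fell through to the removed Transformer1d branch now raise, so a caller can fail fast with a clear message instead of silently building a different transformer.

from models import build_model

# Hypothetical usage; args / text_aligner / pitch_extractor / bert are whatever
# the training or inference script already constructs.
try:
    nets = build_model(args, text_aligner, pitch_extractor, bert)
except NotImplementedError:
    # After commit 70184a3 only the StyleTransformer1d diffusion transformer is
    # built; update the config to select the style-conditioned branch instead of
    # the removed Transformer1d path.
    raise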