format
- Modules/diffusion/modules.py +2 -277
- README.md +4 -4
- models.py +2 -4
Modules/diffusion/modules.py
CHANGED
@@ -279,207 +279,6 @@ class StyleAttention(nn.Module):
         q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
         # Compute and return attention
         return self.attention(q, k, v)
-
-class Transformer1d(nn.Module):
-    def __init__(
-        self,
-        num_layers: int,
-        channels: int,
-        num_heads: int,
-        head_features: int,
-        multiplier: int,
-        use_context_time: bool = True,
-        use_rel_pos: bool = False,
-        context_features_multiplier: int = 1,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features: Optional[int] = None,
-        context_embedding_features: Optional[int] = None,
-        embedding_max_length: int = 512,
-    ):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    features=channels + context_embedding_features,
-                    head_features=head_features,
-                    num_heads=num_heads,
-                    multiplier=multiplier,
-                    use_rel_pos=use_rel_pos,
-                    rel_pos_num_buckets=rel_pos_num_buckets,
-                    rel_pos_max_distance=rel_pos_max_distance,
-                )
-                for i in range(num_layers)
-            ]
-        )
-
-        self.to_out = nn.Sequential(
-            Rearrange("b t c -> b c t"),
-            nn.Conv1d(
-                in_channels=channels + context_embedding_features,
-                out_channels=channels,
-                kernel_size=1,
-            ),
-        )
-
-        use_context_features = exists(context_features)
-        self.use_context_features = use_context_features
-        self.use_context_time = use_context_time
-
-        if use_context_time or use_context_features:
-            context_mapping_features = channels + context_embedding_features
-
-            self.to_mapping = nn.Sequential(
-                nn.Linear(context_mapping_features, context_mapping_features),
-                nn.GELU(),
-                nn.Linear(context_mapping_features, context_mapping_features),
-                nn.GELU(),
-            )
-
-        if use_context_time:
-            assert exists(context_mapping_features)
-            self.to_time = nn.Sequential(
-                TimePositionalEmbedding(
-                    dim=channels, out_features=context_mapping_features
-                ),
-                nn.GELU(),
-            )
-
-        if use_context_features:
-            assert exists(context_features) and exists(context_mapping_features)
-            self.to_features = nn.Sequential(
-                nn.Linear(
-                    in_features=context_features, out_features=context_mapping_features
-                ),
-                nn.GELU(),
-            )
-
-        self.fixed_embedding = FixedEmbedding(
-            max_length=embedding_max_length, features=context_embedding_features
-        )
-
-
-    def get_mapping(
-        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
-    ) -> Optional[Tensor]:
-        """Combines context time features and features into mapping"""
-        items, mapping = [], None
-        # Compute time features
-        if self.use_context_time:
-            assert_message = "use_context_time=True but no time features provided"
-            assert exists(time), assert_message
-            items += [self.to_time(time)]
-        # Compute features
-        if self.use_context_features:
-            assert_message = "context_features exists but no features provided"
-            assert exists(features), assert_message
-            items += [self.to_features(features)]
-
-        # Compute joint mapping
-        if self.use_context_time or self.use_context_features:
-            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
-            mapping = self.to_mapping(mapping)
-
-        return mapping
-
-    def run(self, x, time, embedding, features):
-
-        mapping = self.get_mapping(time, features)
-        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
-        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
-
-        for block in self.blocks:
-            x = x + mapping
-            x = block(x)
-
-        x = x.mean(axis=1).unsqueeze(1)
-        x = self.to_out(x)
-        x = x.transpose(-1, -2)
-
-        return x
-
-    def forward(self, x: Tensor,
-                time: Tensor,
-                embedding_mask_proba: float = 0.0,
-                embedding: Optional[Tensor] = None,
-                features: Optional[Tensor] = None,
-                embedding_scale: float = 1.0) -> Tensor:
-
-        b, device = embedding.shape[0], embedding.device
-        fixed_embedding = self.fixed_embedding(embedding)
-        if embedding_mask_proba > 0.0:
-            # Randomly mask embedding
-            batch_mask = rand_bool(
-                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
-            )
-            embedding = torch.where(batch_mask, fixed_embedding, embedding)
-
-        if embedding_scale != 1.0:
-            # Compute both normal and fixed embedding outputs
-            out = self.run(x, time, embedding=embedding, features=features)
-            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
-            # Scale conditional output using classifier-free guidance
-            return out_masked + (out - out_masked) * embedding_scale
-        else:
-            return self.run(x, time, embedding=embedding, features=features)
-
-        return x
-
-
-"""
-Attention Components
-"""
-
-
-class RelativePositionBias(nn.Module):
-    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
-        super().__init__()
-        self.num_buckets = num_buckets
-        self.max_distance = max_distance
-        self.num_heads = num_heads
-        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
-
-    @staticmethod
-    def _relative_position_bucket(
-        relative_position: Tensor, num_buckets: int, max_distance: int
-    ):
-        num_buckets //= 2
-        ret = (relative_position >= 0).to(torch.long) * num_buckets
-        n = torch.abs(relative_position)
-
-        max_exact = num_buckets // 2
-        is_small = n < max_exact
-
-        val_if_large = (
-            max_exact
-            + (
-                torch.log(n.float() / max_exact)
-                / log(max_distance / max_exact)
-                * (num_buckets - max_exact)
-            ).long()
-        )
-        val_if_large = torch.min(
-            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
-        )
-
-        ret += torch.where(is_small, n, val_if_large)
-        return ret
-
-    def forward(self, num_queries: int, num_keys: int) -> Tensor:
-        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
-        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
-        k_pos = torch.arange(j, dtype=torch.long, device=device)
-        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
-
-        relative_position_bucket = self._relative_position_bucket(
-            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
-        )
-
-        bias = self.relative_attention_bias(relative_position_bucket)
-        bias = rearrange(bias, "m n h -> 1 h m n")
-        return bias
-

 def FeedForward(features: int, multiplier: int) -> nn.Module:
     mid_features = features * multiplier
@@ -509,12 +308,8 @@ class AttentionBase(nn.Module):
         mid_features = head_features * num_heads

         if use_rel_pos:
-
-            self.rel_pos = RelativePositionBias(
-                num_buckets=rel_pos_num_buckets,
-                max_distance=rel_pos_max_distance,
-                num_heads=num_heads,
-            )
+            raise ValueError
+
         if out_features is None:
             out_features = features

@@ -584,76 +379,6 @@ class Attention(nn.Module):
         return self.attention(q, k, v)


-"""
-Transformer Blocks
-"""
-
-
-class TransformerBlock(nn.Module):
-    def __init__(
-        self,
-        features: int,
-        num_heads: int,
-        head_features: int,
-        multiplier: int,
-        use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.use_cross_attention = exists(context_features) and context_features > 0
-
-        self.attention = Attention(
-            features=features,
-            num_heads=num_heads,
-            head_features=head_features,
-            use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
-        )
-
-        if self.use_cross_attention:
-            self.cross_attention = Attention(
-                features=features,
-                num_heads=num_heads,
-                head_features=head_features,
-                context_features=context_features,
-                use_rel_pos=use_rel_pos,
-                rel_pos_num_buckets=rel_pos_num_buckets,
-                rel_pos_max_distance=rel_pos_max_distance,
-            )
-
-        self.feed_forward = FeedForward(features=features, multiplier=multiplier)
-
-    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
-        x = self.attention(x) + x
-        if self.use_cross_attention:
-            x = self.cross_attention(x, context=context) + x
-        x = self.feed_forward(x) + x
-        return x
-
-
-
-"""
-Time Embeddings
-"""
-
-
-class SinusoidalEmbedding(nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x: Tensor) -> Tensor:
-        device, half_dim = x.device, self.dim // 2
-        emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
-        return torch.cat((emb.sin(), emb.cos()), dim=-1)
-
-
 class LearnedPositionalEmbedding(nn.Module):
     """Used for continuous time"""

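For reference, the removed `Transformer1d.forward` above implements classifier-free guidance: when `embedding_scale != 1.0` it runs the network once with the real conditioning embedding and once with the fixed (unconditional) embedding, then blends the two outputs. Below is a minimal, self-contained sketch of just that blending step, written against plain tensors; the function name `cfg_mix` is illustrative and not part of this repository.

import torch

def cfg_mix(out_cond: torch.Tensor, out_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # Same arithmetic as `out_masked + (out - out_masked) * embedding_scale`:
    # scale == 1.0 reproduces the conditional output, scale > 1.0 extrapolates
    # away from the unconditional output to strengthen the conditioning.
    return out_uncond + (out_cond - out_uncond) * scale

# Dummy tensors standing in for the two network passes.
out_cond = torch.randn(2, 80, 512)
out_uncond = torch.randn(2, 80, 512)
guided = cfg_mix(out_cond, out_uncond, scale=5.0)
print(guided.shape)  # torch.Size([2, 80, 512])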
README.md
CHANGED
@@ -26,7 +26,7 @@ Beta version of [SHIFT](https://shift-europe.eu/) TTS tool with [AudioGen sounds
 - [Analysis of emotion of SHIFT TTS](https://huggingface.co/dkounadis/artificial-styletts2/discussions/2)
 - [Listen Also foreign languages](https://huggingface.co/dkounadis/artificial-styletts2/discussions/4) synthesized via [MMS TTS](https://huggingface.co/facebook/mms-tts)

-## Listen
+## Listen Voices


 <a href="https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#67854dcbd3e6beb1a78f7f20">Native English</a> / <a href="https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6">Non-native English: Accents</a> / <a href="https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv">Foreign languages</a>
@@ -39,10 +39,10 @@ Beta version of [SHIFT](https://shift-europe.eu/) TTS tool with [AudioGen sounds

 <details>
 <summary>
-
-Build virtualenv / run `api.py`
-
+Build virtualenv & run api.py
 </summary>
+Besides `demo.py` that runs as a standalone script. We also provide `api.py` that enables long-form TTS with soundscape
+w/o need to load the TTS/AudioGen model again & again. The examples below use api.py.

 Clone

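The added README lines describe `api.py` as a long-running service, so the TTS and AudioGen models load once and are then reused for long-form synthesis with a soundscape. As a purely hypothetical illustration of that workflow (the port, endpoint path, and field names below are assumptions for the sketch, not the repository's documented interface), a client could look like:

import requests

# Hypothetical client sketch: assumes api.py is already running locally and
# exposes an HTTP endpoint that accepts text plus an optional soundscape
# description and returns audio bytes. All names here are illustrative.
payload = {
    "text": "A long paragraph to be narrated as speech.",
    "soundscape": "rain and distant thunder",
}
resp = requests.post("http://localhost:5000/tts", json=payload, timeout=600)
with open("out.wav", "wb") as f:
    f.write(resp.content)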
models.py
CHANGED
@@ -16,7 +16,7 @@ from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet

 from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
-from Modules.diffusion.modules import
+from Modules.diffusion.modules import StyleTransformer1d
 from Modules.diffusion.diffusion import AudioDiffusionConditional


@@ -551,9 +551,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
                     context_features=args.style_dim*2,
                     **args.diffusion.transformer)
     else:
-
-                    context_embedding_features=bert.config.hidden_size,
-                    **args.diffusion.transformer)
+        raise NotImplementedError

     diffusion = AudioDiffusionConditional(
         in_channels=1,