|
|
|
|
|
|
|
import math |
|
from inspect import isfunction |
|
from math import ceil, floor, log, pi, log2 |
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
|
from packaging import version |
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange, reduce, repeat |
|
from einops.layers.torch import Rearrange |
|
from einops_exts import rearrange_many |
|
from torch import Tensor, einsum |
|
from torch.backends.cuda import sdp_kernel |
|
from torch.nn import functional as F |
|
from dac.nn.layers import Snake1d |
|
|
""" |
|
Utils |
|
""" |
|
|
|
|
|
class ConditionedSequential(nn.Module): |
|
def __init__(self, *modules): |
|
super().__init__() |
|
self.module_list = nn.ModuleList(*modules) |
|
|
|
def forward(self, x: Tensor, mapping: Optional[Tensor] = None): |
|
for module in self.module_list: |
|
x = module(x, mapping) |
|
return x |
|
|
|
T = TypeVar("T") |
|
|
|
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: |
|
if exists(val): |
|
return val |
|
return d() if isfunction(d) else d |
|
|
|
def exists(val: Optional[T]) -> T: |
|
return val is not None |
|
|
|
def closest_power_2(x: float) -> int: |
|
exponent = log2(x) |
|
distance_fn = lambda z: abs(x - 2 ** z) |
|
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) |
|
return 2 ** int(exponent_closest) |
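# Illustrative usage (values checked by hand): picks the power of two nearest to the
# input; the STFT decoder below uses this to recover a round waveform length.
#
#   closest_power_2(5.0)   # -> 4
#   closest_power_2(7.0)   # -> 8
#   closest_power_2(1000)  # -> 1024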
|
|
|
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: |
|
return_dicts: Tuple[Dict, Dict] = ({}, {}) |
|
for key in d.keys(): |
|
no_prefix = int(not key.startswith(prefix)) |
|
return_dicts[no_prefix][key] = d[key] |
|
return return_dicts |
|
|
|
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: |
|
kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) |
|
if keep_prefix: |
|
return kwargs_with_prefix, kwargs |
|
kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} |
|
return kwargs_no_prefix, kwargs |
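# Illustrative usage: `groupby` splits a kwargs dict by key prefix. UNet1d below uses it
# to route "attention_*" arguments to its transformer blocks and "stft_*" arguments to STFT.
#
#   attn_kwargs, rest = groupby("attention_", {"attention_heads": 8, "factor": 2})
#   # attn_kwargs == {"heads": 8}, rest == {"factor": 2}
#   attn_kwargs, rest = groupby("attention_", {"attention_heads": 8}, keep_prefix=True)
#   # attn_kwargs == {"attention_heads": 8}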
|
|
|
""" |
|
Convolutional Blocks |
|
""" |
|
import typing as tp |
|
|
|
|
|
|
|
|
|
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, |
|
padding_total: int = 0) -> int: |
|
"""See `pad_for_conv1d`.""" |
|
length = x.shape[-1] |
|
n_frames = (length - kernel_size + padding_total) / stride + 1 |
|
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) |
|
return ideal_length - length |
|
|
|
|
|
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): |
|
"""Pad for a convolution to make sure that the last window is full. |
|
Extra padding is added at the end. This is required to ensure that we can rebuild |
|
an output of the same length, as otherwise, even with padding, some time steps |
|
might get removed. |
|
For instance, with total padding = 4, kernel size = 4, stride = 2: |
|
0 0 1 2 3 4 5 0 0 # (0s are padding) |
|
1 2 3 # (output frames of a convolution, last 0 is never used) |
|
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) |
|
1 2 3 4 # once you removed padding, we are missing one time step ! |
|
""" |
|
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) |
|
return F.pad(x, (0, extra_padding)) |
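# Worked example of the padding arithmetic above (kernel_size=4, stride=2, padding_total=4,
# input length 5): n_frames = (5 - 4 + 4) / 2 + 1 = 3.5, so the ideal length is
# (ceil(3.5) - 1) * 2 + (4 - 4) = 6 and one extra zero is appended on the right.
#
#   x = torch.randn(1, 1, 5)
#   pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=4).shape  # -> (1, 1, 6)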
|
|
|
|
|
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): |
|
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input. |
|
If this is the case, we insert extra 0 padding to the right before the reflection happen. |
|
""" |
|
length = x.shape[-1] |
|
padding_left, padding_right = paddings |
|
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) |
|
if mode == 'reflect': |
|
max_pad = max(padding_left, padding_right) |
|
extra_pad = 0 |
|
if length <= max_pad: |
|
extra_pad = max_pad - length + 1 |
|
x = F.pad(x, (0, extra_pad)) |
|
padded = F.pad(x, paddings, mode, value) |
|
end = padded.shape[-1] - extra_pad |
|
return padded[..., :end] |
|
else: |
|
return F.pad(x, paddings, mode, value) |
|
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): |
|
"""Remove padding from x, handling properly zero padding. Only for 1d!""" |
|
padding_left, padding_right = paddings |
|
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) |
|
assert (padding_left + padding_right) <= x.shape[-1] |
|
end = x.shape[-1] - padding_right |
|
return x[..., padding_left: end] |
|
|
|
|
|
class Conv1d(nn.Conv1d): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, x: Tensor, causal=False) -> Tensor: |
|
kernel_size = self.kernel_size[0] |
|
stride = self.stride[0] |
|
dilation = self.dilation[0] |
|
kernel_size = (kernel_size - 1) * dilation + 1 |
|
padding_total = kernel_size - stride |
|
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) |
|
if causal: |
|
|
|
x = pad1d(x, (padding_total, extra_padding)) |
|
else: |
|
|
|
padding_right = padding_total // 2 |
|
padding_left = padding_total - padding_right |
|
x = pad1d(x, (padding_left, padding_right + extra_padding)) |
|
return super().forward(x) |
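# Illustrative sketch: this Conv1d pads inside `forward`, so the output length is
# ceil(input_length / stride) regardless of kernel size. With causal=True the asymmetric
# padding goes on the left (plus any extra needed on the right to complete the last
# window), matching the causal convention used throughout this module.
#
#   conv = Conv1d(in_channels=8, out_channels=8, kernel_size=3)
#   conv(torch.randn(1, 8, 100)).shape               # -> (1, 8, 100)
#   conv(torch.randn(1, 8, 100), causal=True).shape  # -> (1, 8, 100)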
|
|
|
class ConvTranspose1d(nn.ConvTranspose1d): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, x: Tensor, causal=False) -> Tensor: |
|
kernel_size = self.kernel_size[0] |
|
stride = self.stride[0] |
|
padding_total = kernel_size - stride |
|
|
|
y = super().forward(x) |
|
|
|
|
|
|
|
|
|
|
|
if causal: |
|
padding_right = ceil(padding_total) |
|
padding_left = padding_total - padding_right |
|
y = unpad1d(y, (padding_left, padding_right)) |
|
else: |
|
|
|
padding_right = padding_total // 2 |
|
padding_left = padding_total - padding_right |
|
y = unpad1d(y, (padding_left, padding_right)) |
|
return y |
|
|
|
|
|
def Downsample1d( |
|
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 |
|
) -> nn.Module: |
|
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" |
|
|
|
return Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=factor * kernel_multiplier + 1, |
|
stride=factor |
|
) |
|
|
|
|
|
def Upsample1d( |
|
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False |
|
) -> nn.Module: |
|
|
|
if factor == 1: |
|
return Conv1d( |
|
in_channels=in_channels, out_channels=out_channels, kernel_size=3 |
|
) |
|
|
|
if use_nearest: |
|
return nn.Sequential( |
|
nn.Upsample(scale_factor=factor, mode="nearest"), |
|
Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=3 |
|
), |
|
) |
|
else: |
|
return ConvTranspose1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=factor * 2, |
|
stride=factor |
|
) |
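# Illustrative sketch of the resampling factories: Downsample1d reduces the time axis by
# `factor` with a strided Conv1d, and Upsample1d restores it with a transposed convolution
# (or nearest-neighbour upsampling followed by a Conv1d when use_nearest=True).
#
#   down = Downsample1d(in_channels=16, out_channels=32, factor=4)
#   up = Upsample1d(in_channels=32, out_channels=16, factor=4)
#   x = torch.randn(1, 16, 1024)
#   h = down(x)  # -> (1, 32, 256)
#   y = up(h)    # -> (1, 16, 1024)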
|
|
|
|
|
class ConvBlock1d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
*, |
|
kernel_size: int = 3, |
|
stride: int = 1, |
|
dilation: int = 1, |
|
num_groups: int = 8, |
|
use_norm: bool = True, |
|
use_snake: bool = False |
|
) -> None: |
|
super().__init__() |
|
|
|
self.groupnorm = ( |
|
nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) |
|
if use_norm |
|
else nn.Identity() |
|
) |
|
|
|
if use_snake: |
|
self.activation = Snake1d(in_channels) |
|
else: |
|
self.activation = nn.SiLU() |
|
|
|
self.project = Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
dilation=dilation, |
|
) |
|
|
|
def forward( |
|
self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False |
|
) -> Tensor: |
|
x = self.groupnorm(x) |
|
if exists(scale_shift): |
|
scale, shift = scale_shift |
|
x = x * (scale + 1) + shift |
|
x = self.activation(x) |
|
return self.project(x, causal=causal) |
|
|
|
|
|
class MappingToScaleShift(nn.Module): |
|
def __init__( |
|
self, |
|
features: int, |
|
channels: int, |
|
): |
|
super().__init__() |
|
|
|
self.to_scale_shift = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(in_features=features, out_features=channels * 2), |
|
) |
|
|
|
def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: |
|
scale_shift = self.to_scale_shift(mapping) |
|
scale_shift = rearrange(scale_shift, "b c -> b c 1") |
|
scale, shift = scale_shift.chunk(2, dim=1) |
|
return scale, shift |
|
|
|
|
|
class ResnetBlock1d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
*, |
|
kernel_size: int = 3, |
|
stride: int = 1, |
|
dilation: int = 1, |
|
use_norm: bool = True, |
|
use_snake: bool = False, |
|
num_groups: int = 8, |
|
context_mapping_features: Optional[int] = None, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.use_mapping = exists(context_mapping_features) |
|
|
|
self.block1 = ConvBlock1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
dilation=dilation, |
|
use_norm=use_norm, |
|
num_groups=num_groups, |
|
use_snake=use_snake |
|
) |
|
|
|
if self.use_mapping: |
|
assert exists(context_mapping_features) |
|
self.to_scale_shift = MappingToScaleShift( |
|
features=context_mapping_features, channels=out_channels |
|
) |
|
|
|
self.block2 = ConvBlock1d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
use_norm=use_norm, |
|
num_groups=num_groups, |
|
use_snake=use_snake |
|
) |
|
|
|
self.to_out = ( |
|
Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) |
|
if in_channels != out_channels |
|
else nn.Identity() |
|
) |
|
|
|
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: |
|
assert_message = "context mapping required if context_mapping_features > 0" |
|
assert not (self.use_mapping ^ exists(mapping)), assert_message |
|
|
|
h = self.block1(x, causal=causal) |
|
|
|
scale_shift = None |
|
if self.use_mapping: |
|
scale_shift = self.to_scale_shift(mapping) |
|
|
|
h = self.block2(h, scale_shift=scale_shift, causal=causal) |
|
|
|
return h + self.to_out(x) |
|
|
|
|
|
class Patcher(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
patch_size: int, |
|
context_mapping_features: Optional[int] = None, |
|
use_snake: bool = False, |
|
): |
|
super().__init__() |
|
assert_message = f"out_channels must be divisible by patch_size ({patch_size})" |
|
assert out_channels % patch_size == 0, assert_message |
|
self.patch_size = patch_size |
|
|
|
self.block = ResnetBlock1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels // patch_size, |
|
num_groups=1, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: |
|
x = self.block(x, mapping, causal=causal) |
|
x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) |
|
return x |
|
|
|
|
|
class Unpatcher(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
patch_size: int, |
|
context_mapping_features: Optional[int] = None, |
|
use_snake: bool = False |
|
): |
|
super().__init__() |
|
assert_message = f"in_channels must be divisible by patch_size ({patch_size})" |
|
assert in_channels % patch_size == 0, assert_message |
|
self.patch_size = patch_size |
|
|
|
self.block = ResnetBlock1d( |
|
in_channels=in_channels // patch_size, |
|
out_channels=out_channels, |
|
num_groups=1, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: |
|
x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) |
|
x = self.block(x, mapping, causal=causal) |
|
return x |
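# Illustrative sketch: Patcher trades time resolution for channels (a pixel-shuffle style
# reshape after a ResnetBlock1d) and Unpatcher inverts the reshape before its own
# ResnetBlock1d. With patch_size=2:
#
#   patch = Patcher(in_channels=2, out_channels=64, patch_size=2)
#   unpatch = Unpatcher(in_channels=64, out_channels=2, patch_size=2)
#   x = torch.randn(1, 2, 2048)
#   h = patch(x)    # -> (1, 64, 1024)
#   y = unpatch(h)  # -> (1, 2, 2048)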
|
|
|
|
|
""" |
|
Attention Components |
|
""" |
|
def FeedForward(features: int, multiplier: int) -> nn.Module: |
|
mid_features = features * multiplier |
|
return nn.Sequential( |
|
nn.Linear(in_features=features, out_features=mid_features), |
|
nn.GELU(), |
|
nn.Linear(in_features=mid_features, out_features=features), |
|
) |
|
|
|
def add_mask(sim: Tensor, mask: Tensor) -> Tensor: |
|
b, ndim = sim.shape[0], mask.ndim |
|
if ndim == 3: |
|
mask = rearrange(mask, "b n m -> b 1 n m") |
|
if ndim == 2: |
|
mask = repeat(mask, "n m -> b 1 n m", b=b) |
|
max_neg_value = -torch.finfo(sim.dtype).max |
|
sim = sim.masked_fill(~mask, max_neg_value) |
|
return sim |
|
|
|
def causal_mask(q: Tensor, k: Tensor) -> Tensor: |
|
b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device |
|
mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) |
|
mask = repeat(mask, "n m -> b n m", b=b) |
|
return mask |
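# Illustrative sketch: `causal_mask` builds a lower-triangular boolean mask (True = keep)
# aligned to the end of the key sequence, so it still behaves sensibly when k is longer
# than q.
#
#   q, k = torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8)
#   causal_mask(q, k)[0].int()
#   # tensor([[1, 0, 0],
#   #         [1, 1, 0],
#   #         [1, 1, 1]], dtype=torch.int32)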
|
|
|
class AttentionBase(nn.Module): |
|
def __init__( |
|
self, |
|
features: int, |
|
*, |
|
head_features: int, |
|
num_heads: int, |
|
out_features: Optional[int] = None, |
|
): |
|
super().__init__() |
|
self.scale = head_features**-0.5 |
|
self.num_heads = num_heads |
|
mid_features = head_features * num_heads |
|
out_features = default(out_features, features) |
|
|
|
self.to_out = nn.Linear( |
|
in_features=mid_features, out_features=out_features |
|
) |
|
|
|
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') |
|
|
|
if not self.use_flash: |
|
return |
|
|
|
device_properties = torch.cuda.get_device_properties(torch.device('cuda')) |
|
|
|
if device_properties.major == 8 and device_properties.minor == 0: |
|
|
|
self.sdp_kernel_config = (True, False, False) |
|
else: |
|
|
|
self.sdp_kernel_config = (False, True, True) |
|
|
|
def forward( |
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False |
|
) -> Tensor: |
|
|
|
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) |
|
|
|
if not self.use_flash: |
|
            if is_causal and not exists(mask):
|
|
|
mask = causal_mask(q, k) |
|
|
|
|
|
sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale |
|
sim = add_mask(sim, mask) if exists(mask) else sim |
|
|
|
|
|
attn = sim.softmax(dim=-1, dtype=torch.float32) |
|
|
|
|
|
out = einsum("... n m, ... m d -> ... n d", attn, v) |
|
else: |
|
with sdp_kernel(*self.sdp_kernel_config): |
|
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) |
|
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
return self.to_out(out) |
|
|
|
class Attention(nn.Module): |
|
def __init__( |
|
self, |
|
features: int, |
|
*, |
|
head_features: int, |
|
num_heads: int, |
|
out_features: Optional[int] = None, |
|
context_features: Optional[int] = None, |
|
causal: bool = False, |
|
): |
|
super().__init__() |
|
self.context_features = context_features |
|
self.causal = causal |
|
mid_features = head_features * num_heads |
|
context_features = default(context_features, features) |
|
|
|
self.norm = nn.LayerNorm(features) |
|
self.norm_context = nn.LayerNorm(context_features) |
|
self.to_q = nn.Linear( |
|
in_features=features, out_features=mid_features, bias=False |
|
) |
|
self.to_kv = nn.Linear( |
|
in_features=context_features, out_features=mid_features * 2, bias=False |
|
) |
|
self.attention = AttentionBase( |
|
features, |
|
num_heads=num_heads, |
|
head_features=head_features, |
|
out_features=out_features, |
|
) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
context: Optional[Tensor] = None, |
|
context_mask: Optional[Tensor] = None, |
|
causal: Optional[bool] = False, |
|
) -> Tensor: |
|
assert_message = "You must provide a context when using context_features" |
|
assert not self.context_features or exists(context), assert_message |
|
|
|
context = default(context, x) |
|
|
|
x, context = self.norm(x), self.norm_context(context) |
|
|
|
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) |
|
|
|
if exists(context_mask): |
|
|
|
mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) |
|
k, v = k * mask, v * mask |
|
|
|
|
|
return self.attention(q, k, v, is_causal=self.causal or causal) |
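# Illustrative usage (shapes only): self-attention over (batch, tokens, features), and
# cross-attention against a separate context sequence when context_features is set.
#
#   attn = Attention(features=512, head_features=64, num_heads=8)
#   attn(torch.randn(2, 128, 512)).shape  # -> (2, 128, 512)
#
#   xattn = Attention(features=512, head_features=64, num_heads=8, context_features=768)
#   x, ctx = torch.randn(2, 128, 512), torch.randn(2, 77, 768)
#   xattn(x, context=ctx).shape  # -> (2, 128, 512)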
|
|
|
|
|
|
|
|
""" |
|
Transformer Blocks |
|
""" |
|
|
|
|
|
class TransformerBlock(nn.Module): |
|
def __init__( |
|
self, |
|
features: int, |
|
num_heads: int, |
|
head_features: int, |
|
multiplier: int, |
|
context_features: Optional[int] = None, |
|
): |
|
super().__init__() |
|
|
|
self.use_cross_attention = exists(context_features) and context_features > 0 |
|
|
|
self.attention = Attention( |
|
features=features, |
|
num_heads=num_heads, |
|
head_features=head_features |
|
) |
|
|
|
if self.use_cross_attention: |
|
self.cross_attention = Attention( |
|
features=features, |
|
num_heads=num_heads, |
|
head_features=head_features, |
|
context_features=context_features |
|
) |
|
|
|
self.feed_forward = FeedForward(features=features, multiplier=multiplier) |
|
|
|
def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: |
|
x = self.attention(x, causal=causal) + x |
|
if self.use_cross_attention: |
|
x = self.cross_attention(x, context=context, context_mask=context_mask) + x |
|
x = self.feed_forward(x) + x |
|
return x |
|
|
|
|
|
""" |
|
Transformers |
|
""" |
|
|
|
|
|
class Transformer1d(nn.Module): |
|
def __init__( |
|
self, |
|
num_layers: int, |
|
channels: int, |
|
num_heads: int, |
|
head_features: int, |
|
multiplier: int, |
|
context_features: Optional[int] = None, |
|
): |
|
super().__init__() |
|
|
|
self.to_in = nn.Sequential( |
|
nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), |
|
Conv1d( |
|
in_channels=channels, |
|
out_channels=channels, |
|
kernel_size=1, |
|
), |
|
Rearrange("b c t -> b t c"), |
|
) |
|
|
|
self.blocks = nn.ModuleList( |
|
[ |
|
TransformerBlock( |
|
features=channels, |
|
head_features=head_features, |
|
num_heads=num_heads, |
|
multiplier=multiplier, |
|
context_features=context_features, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.to_out = nn.Sequential( |
|
Rearrange("b t c -> b c t"), |
|
Conv1d( |
|
in_channels=channels, |
|
out_channels=channels, |
|
kernel_size=1, |
|
), |
|
) |
|
|
|
def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: |
|
x = self.to_in(x) |
|
for block in self.blocks: |
|
x = block(x, context=context, context_mask=context_mask, causal=causal) |
|
x = self.to_out(x) |
|
return x |
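# Illustrative usage (shapes only): Transformer1d operates on (batch, channels, time)
# feature maps; `channels` must be divisible by 32 because of the GroupNorm in `to_in`.
#
#   t1d = Transformer1d(num_layers=2, channels=256, num_heads=8, head_features=32, multiplier=2)
#   t1d(torch.randn(1, 256, 128)).shape  # -> (1, 256, 128)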
|
|
|
|
|
""" |
|
Time Embeddings |
|
""" |
|
|
|
|
|
class SinusoidalEmbedding(nn.Module): |
|
def __init__(self, dim: int): |
|
super().__init__() |
|
self.dim = dim |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
device, half_dim = x.device, self.dim // 2 |
|
emb = torch.tensor(log(10000) / (half_dim - 1), device=device) |
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb) |
|
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") |
|
return torch.cat((emb.sin(), emb.cos()), dim=-1) |
|
|
|
|
|
class LearnedPositionalEmbedding(nn.Module): |
|
"""Used for continuous time""" |
|
|
|
def __init__(self, dim: int): |
|
super().__init__() |
|
assert (dim % 2) == 0 |
|
half_dim = dim // 2 |
|
self.weights = nn.Parameter(torch.randn(half_dim)) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
x = rearrange(x, "b -> b 1") |
|
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi |
|
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) |
|
fouriered = torch.cat((x, fouriered), dim=-1) |
|
return fouriered |
|
|
|
|
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: |
|
return nn.Sequential( |
|
LearnedPositionalEmbedding(dim), |
|
nn.Linear(in_features=dim + 1, out_features=out_features), |
|
) |
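# Illustrative usage: embeds a batch of scalar (e.g. diffusion) timesteps into a feature
# vector; the Linear input size is dim + 1 because the raw value is concatenated to the
# Fourier features in LearnedPositionalEmbedding.
#
#   to_time = TimePositionalEmbedding(dim=256, out_features=768)
#   to_time(torch.rand(4)).shape  # -> (4, 768)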
|
|
|
|
|
""" |
|
Encoder/Decoder Components |
|
""" |
|
|
|
|
|
class DownsampleBlock1d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
*, |
|
factor: int, |
|
num_groups: int, |
|
num_layers: int, |
|
kernel_multiplier: int = 2, |
|
use_pre_downsample: bool = True, |
|
use_skip: bool = False, |
|
use_snake: bool = False, |
|
extract_channels: int = 0, |
|
context_channels: int = 0, |
|
num_transformer_blocks: int = 0, |
|
attention_heads: Optional[int] = None, |
|
attention_features: Optional[int] = None, |
|
attention_multiplier: Optional[int] = None, |
|
context_mapping_features: Optional[int] = None, |
|
context_embedding_features: Optional[int] = None, |
|
): |
|
super().__init__() |
|
self.use_pre_downsample = use_pre_downsample |
|
self.use_skip = use_skip |
|
self.use_transformer = num_transformer_blocks > 0 |
|
self.use_extract = extract_channels > 0 |
|
self.use_context = context_channels > 0 |
|
|
|
channels = out_channels if use_pre_downsample else in_channels |
|
|
|
self.downsample = Downsample1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
factor=factor, |
|
kernel_multiplier=kernel_multiplier, |
|
) |
|
|
|
self.blocks = nn.ModuleList( |
|
[ |
|
ResnetBlock1d( |
|
in_channels=channels + context_channels if i == 0 else channels, |
|
out_channels=channels, |
|
num_groups=num_groups, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
if self.use_transformer: |
|
assert ( |
|
(exists(attention_heads) or exists(attention_features)) |
|
and exists(attention_multiplier) |
|
) |
|
|
|
if attention_features is None and attention_heads is not None: |
|
attention_features = channels // attention_heads |
|
|
|
if attention_heads is None and attention_features is not None: |
|
attention_heads = channels // attention_features |
|
|
|
self.transformer = Transformer1d( |
|
num_layers=num_transformer_blocks, |
|
channels=channels, |
|
num_heads=attention_heads, |
|
head_features=attention_features, |
|
multiplier=attention_multiplier, |
|
context_features=context_embedding_features |
|
) |
|
|
|
if self.use_extract: |
|
num_extract_groups = min(num_groups, extract_channels) |
|
self.to_extracted = ResnetBlock1d( |
|
in_channels=out_channels, |
|
out_channels=extract_channels, |
|
num_groups=num_extract_groups, |
|
use_snake=use_snake |
|
) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
*, |
|
mapping: Optional[Tensor] = None, |
|
channels: Optional[Tensor] = None, |
|
embedding: Optional[Tensor] = None, |
|
embedding_mask: Optional[Tensor] = None, |
|
causal: Optional[bool] = False |
|
) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: |
|
|
|
if self.use_pre_downsample: |
|
x = self.downsample(x) |
|
|
|
if self.use_context and exists(channels): |
|
x = torch.cat([x, channels], dim=1) |
|
|
|
skips = [] |
|
for block in self.blocks: |
|
x = block(x, mapping=mapping, causal=causal) |
|
skips += [x] if self.use_skip else [] |
|
|
|
if self.use_transformer: |
|
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) |
|
skips += [x] if self.use_skip else [] |
|
|
|
if not self.use_pre_downsample: |
|
x = self.downsample(x) |
|
|
|
if self.use_extract: |
|
extracted = self.to_extracted(x) |
|
return x, extracted |
|
|
|
return (x, skips) if self.use_skip else x |
|
|
|
|
|
class UpsampleBlock1d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
*, |
|
factor: int, |
|
num_layers: int, |
|
num_groups: int, |
|
use_nearest: bool = False, |
|
use_pre_upsample: bool = False, |
|
use_skip: bool = False, |
|
use_snake: bool = False, |
|
skip_channels: int = 0, |
|
use_skip_scale: bool = False, |
|
extract_channels: int = 0, |
|
num_transformer_blocks: int = 0, |
|
attention_heads: Optional[int] = None, |
|
attention_features: Optional[int] = None, |
|
attention_multiplier: Optional[int] = None, |
|
context_mapping_features: Optional[int] = None, |
|
context_embedding_features: Optional[int] = None, |
|
): |
|
super().__init__() |
|
|
|
self.use_extract = extract_channels > 0 |
|
self.use_pre_upsample = use_pre_upsample |
|
self.use_transformer = num_transformer_blocks > 0 |
|
self.use_skip = use_skip |
|
self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 |
|
|
|
channels = out_channels if use_pre_upsample else in_channels |
|
|
|
self.blocks = nn.ModuleList( |
|
[ |
|
ResnetBlock1d( |
|
in_channels=channels + skip_channels, |
|
out_channels=channels, |
|
num_groups=num_groups, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
|
|
if self.use_transformer: |
|
assert ( |
|
(exists(attention_heads) or exists(attention_features)) |
|
and exists(attention_multiplier) |
|
) |
|
|
|
if attention_features is None and attention_heads is not None: |
|
attention_features = channels // attention_heads |
|
|
|
if attention_heads is None and attention_features is not None: |
|
attention_heads = channels // attention_features |
|
|
|
self.transformer = Transformer1d( |
|
num_layers=num_transformer_blocks, |
|
channels=channels, |
|
num_heads=attention_heads, |
|
head_features=attention_features, |
|
multiplier=attention_multiplier, |
|
context_features=context_embedding_features, |
|
) |
|
|
|
self.upsample = Upsample1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
factor=factor, |
|
use_nearest=use_nearest, |
|
) |
|
|
|
if self.use_extract: |
|
num_extract_groups = min(num_groups, extract_channels) |
|
self.to_extracted = ResnetBlock1d( |
|
in_channels=out_channels, |
|
out_channels=extract_channels, |
|
num_groups=num_extract_groups, |
|
use_snake=use_snake |
|
) |
|
|
|
def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: |
|
return torch.cat([x, skip * self.skip_scale], dim=1) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
*, |
|
skips: Optional[List[Tensor]] = None, |
|
mapping: Optional[Tensor] = None, |
|
embedding: Optional[Tensor] = None, |
|
embedding_mask: Optional[Tensor] = None, |
|
causal: Optional[bool] = False |
|
) -> Union[Tuple[Tensor, Tensor], Tensor]: |
|
|
|
if self.use_pre_upsample: |
|
x = self.upsample(x) |
|
|
|
for block in self.blocks: |
|
x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x |
|
x = block(x, mapping=mapping, causal=causal) |
|
|
|
if self.use_transformer: |
|
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) |
|
|
|
if not self.use_pre_upsample: |
|
x = self.upsample(x) |
|
|
|
if self.use_extract: |
|
extracted = self.to_extracted(x) |
|
return x, extracted |
|
|
|
return x |
|
|
|
|
|
class BottleneckBlock1d(nn.Module): |
|
def __init__( |
|
self, |
|
channels: int, |
|
*, |
|
num_groups: int, |
|
num_transformer_blocks: int = 0, |
|
attention_heads: Optional[int] = None, |
|
attention_features: Optional[int] = None, |
|
attention_multiplier: Optional[int] = None, |
|
context_mapping_features: Optional[int] = None, |
|
context_embedding_features: Optional[int] = None, |
|
use_snake: bool = False, |
|
): |
|
super().__init__() |
|
self.use_transformer = num_transformer_blocks > 0 |
|
|
|
self.pre_block = ResnetBlock1d( |
|
in_channels=channels, |
|
out_channels=channels, |
|
num_groups=num_groups, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
if self.use_transformer: |
|
assert ( |
|
(exists(attention_heads) or exists(attention_features)) |
|
and exists(attention_multiplier) |
|
) |
|
|
|
if attention_features is None and attention_heads is not None: |
|
attention_features = channels // attention_heads |
|
|
|
if attention_heads is None and attention_features is not None: |
|
attention_heads = channels // attention_features |
|
|
|
self.transformer = Transformer1d( |
|
num_layers=num_transformer_blocks, |
|
channels=channels, |
|
num_heads=attention_heads, |
|
head_features=attention_features, |
|
multiplier=attention_multiplier, |
|
context_features=context_embedding_features, |
|
) |
|
|
|
self.post_block = ResnetBlock1d( |
|
in_channels=channels, |
|
out_channels=channels, |
|
num_groups=num_groups, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
*, |
|
mapping: Optional[Tensor] = None, |
|
embedding: Optional[Tensor] = None, |
|
embedding_mask: Optional[Tensor] = None, |
|
causal: Optional[bool] = False |
|
) -> Tensor: |
|
x = self.pre_block(x, mapping=mapping, causal=causal) |
|
if self.use_transformer: |
|
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) |
|
x = self.post_block(x, mapping=mapping, causal=causal) |
|
return x |
|
|
|
|
|
""" |
|
UNet |
|
""" |
|
|
|
|
|
class UNet1d(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
channels: int, |
|
multipliers: Sequence[int], |
|
factors: Sequence[int], |
|
num_blocks: Sequence[int], |
|
attentions: Sequence[int], |
|
patch_size: int = 1, |
|
resnet_groups: int = 8, |
|
use_context_time: bool = True, |
|
kernel_multiplier_downsample: int = 2, |
|
use_nearest_upsample: bool = False, |
|
use_skip_scale: bool = True, |
|
use_snake: bool = False, |
|
use_stft: bool = False, |
|
use_stft_context: bool = False, |
|
out_channels: Optional[int] = None, |
|
context_features: Optional[int] = None, |
|
context_features_multiplier: int = 4, |
|
context_channels: Optional[Sequence[int]] = None, |
|
context_embedding_features: Optional[int] = None, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
out_channels = default(out_channels, in_channels) |
|
context_channels = list(default(context_channels, [])) |
|
num_layers = len(multipliers) - 1 |
|
use_context_features = exists(context_features) |
|
use_context_channels = len(context_channels) > 0 |
|
context_mapping_features = None |
|
|
|
attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) |
|
|
|
self.num_layers = num_layers |
|
self.use_context_time = use_context_time |
|
self.use_context_features = use_context_features |
|
self.use_context_channels = use_context_channels |
|
self.use_stft = use_stft |
|
self.use_stft_context = use_stft_context |
|
|
|
self.context_features = context_features |
|
context_channels_pad_length = num_layers + 1 - len(context_channels) |
|
context_channels = context_channels + [0] * context_channels_pad_length |
|
self.context_channels = context_channels |
|
self.context_embedding_features = context_embedding_features |
|
|
|
if use_context_channels: |
|
has_context = [c > 0 for c in context_channels] |
|
self.has_context = has_context |
|
self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] |
|
|
|
assert ( |
|
len(factors) == num_layers |
|
and len(attentions) >= num_layers |
|
and len(num_blocks) == num_layers |
|
) |
|
|
|
if use_context_time or use_context_features: |
|
context_mapping_features = channels * context_features_multiplier |
|
|
|
self.to_mapping = nn.Sequential( |
|
nn.Linear(context_mapping_features, context_mapping_features), |
|
nn.GELU(), |
|
nn.Linear(context_mapping_features, context_mapping_features), |
|
nn.GELU(), |
|
) |
|
|
|
if use_context_time: |
|
assert exists(context_mapping_features) |
|
self.to_time = nn.Sequential( |
|
TimePositionalEmbedding( |
|
dim=channels, out_features=context_mapping_features |
|
), |
|
nn.GELU(), |
|
) |
|
|
|
if use_context_features: |
|
assert exists(context_features) and exists(context_mapping_features) |
|
self.to_features = nn.Sequential( |
|
nn.Linear( |
|
in_features=context_features, out_features=context_mapping_features |
|
), |
|
nn.GELU(), |
|
) |
|
|
|
if use_stft: |
|
stft_kwargs, kwargs = groupby("stft_", kwargs) |
|
assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" |
|
stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 |
|
in_channels *= stft_channels |
|
out_channels *= stft_channels |
|
context_channels[0] *= stft_channels if use_stft_context else 1 |
|
assert exists(in_channels) and exists(out_channels) |
|
self.stft = STFT(**stft_kwargs) |
|
|
|
assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" |
|
|
|
self.to_in = Patcher( |
|
in_channels=in_channels + context_channels[0], |
|
out_channels=channels * multipliers[0], |
|
patch_size=patch_size, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
self.downsamples = nn.ModuleList( |
|
[ |
|
DownsampleBlock1d( |
|
in_channels=channels * multipliers[i], |
|
out_channels=channels * multipliers[i + 1], |
|
context_mapping_features=context_mapping_features, |
|
context_channels=context_channels[i + 1], |
|
context_embedding_features=context_embedding_features, |
|
num_layers=num_blocks[i], |
|
factor=factors[i], |
|
kernel_multiplier=kernel_multiplier_downsample, |
|
num_groups=resnet_groups, |
|
use_pre_downsample=True, |
|
use_skip=True, |
|
use_snake=use_snake, |
|
num_transformer_blocks=attentions[i], |
|
**attention_kwargs, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.bottleneck = BottleneckBlock1d( |
|
channels=channels * multipliers[-1], |
|
context_mapping_features=context_mapping_features, |
|
context_embedding_features=context_embedding_features, |
|
num_groups=resnet_groups, |
|
num_transformer_blocks=attentions[-1], |
|
use_snake=use_snake, |
|
**attention_kwargs, |
|
) |
|
|
|
self.upsamples = nn.ModuleList( |
|
[ |
|
UpsampleBlock1d( |
|
in_channels=channels * multipliers[i + 1], |
|
out_channels=channels * multipliers[i], |
|
context_mapping_features=context_mapping_features, |
|
context_embedding_features=context_embedding_features, |
|
num_layers=num_blocks[i] + (1 if attentions[i] else 0), |
|
factor=factors[i], |
|
use_nearest=use_nearest_upsample, |
|
num_groups=resnet_groups, |
|
use_skip_scale=use_skip_scale, |
|
use_pre_upsample=False, |
|
use_skip=True, |
|
use_snake=use_snake, |
|
skip_channels=channels * multipliers[i + 1], |
|
num_transformer_blocks=attentions[i], |
|
**attention_kwargs, |
|
) |
|
for i in reversed(range(num_layers)) |
|
] |
|
) |
|
|
|
self.to_out = Unpatcher( |
|
in_channels=channels * multipliers[0], |
|
out_channels=out_channels, |
|
patch_size=patch_size, |
|
context_mapping_features=context_mapping_features, |
|
use_snake=use_snake |
|
) |
|
|
|
def get_channels( |
|
self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 |
|
) -> Optional[Tensor]: |
|
"""Gets context channels at `layer` and checks that shape is correct""" |
|
use_context_channels = self.use_context_channels and self.has_context[layer] |
|
if not use_context_channels: |
|
return None |
|
assert exists(channels_list), "Missing context" |
|
|
|
channels_id = self.channels_ids[layer] |
|
|
|
channels = channels_list[channels_id] |
|
message = f"Missing context for layer {layer} at index {channels_id}" |
|
assert exists(channels), message |
|
|
|
num_channels = self.context_channels[layer] |
|
message = f"Expected context with {num_channels} channels at idx {channels_id}" |
|
assert channels.shape[1] == num_channels, message |
|
|
|
channels = self.stft.encode1d(channels) if self.use_stft_context else channels |
|
return channels |
|
|
|
def get_mapping( |
|
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None |
|
) -> Optional[Tensor]: |
|
"""Combines context time features and features into mapping""" |
|
items, mapping = [], None |
|
|
|
if self.use_context_time: |
|
assert_message = "use_context_time=True but no time features provided" |
|
assert exists(time), assert_message |
|
items += [self.to_time(time)] |
|
|
|
if self.use_context_features: |
|
assert_message = "context_features exists but no features provided" |
|
assert exists(features), assert_message |
|
items += [self.to_features(features)] |
|
|
|
if self.use_context_time or self.use_context_features: |
|
mapping = reduce(torch.stack(items), "n b m -> b m", "sum") |
|
mapping = self.to_mapping(mapping) |
|
return mapping |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
time: Optional[Tensor] = None, |
|
*, |
|
features: Optional[Tensor] = None, |
|
channels_list: Optional[Sequence[Tensor]] = None, |
|
embedding: Optional[Tensor] = None, |
|
embedding_mask: Optional[Tensor] = None, |
|
causal: Optional[bool] = False, |
|
) -> Tensor: |
|
channels = self.get_channels(channels_list, layer=0) |
|
|
|
|
x = self.stft.encode1d(x) if self.use_stft else x |
|
|
|
|
x = torch.cat([x, channels], dim=1) if exists(channels) else x |
|
|
|
|
mapping = self.get_mapping(time, features) |
|
x = self.to_in(x, mapping, causal=causal) |
|
|
skips_list = [x] |
|
|
|
for i, downsample in enumerate(self.downsamples): |
|
channels = self.get_channels(channels_list, layer=i + 1) |
|
x, skips = downsample( |
|
x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal |
|
) |
|
skips_list += [skips] |
|
|
|
x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) |
|
for i, upsample in enumerate(self.upsamples): |
|
skips = skips_list.pop() |
|
x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) |
|
|
|
x += skips_list.pop() |
|
x = self.to_out(x, mapping, causal=causal) |
|
x = self.stft.decode1d(x) if self.use_stft else x |
|
|
|
return x |
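# Illustrative configuration sketch (hypothetical hyperparameters, shapes only): a
# three-level UNet1d over stereo audio with cross-attention disabled and transformer
# blocks only in the deepest level and bottleneck. The "attention_*" kwargs are routed
# through `groupby` above.
#
#   unet = UNet1d(
#       in_channels=2, channels=128,
#       multipliers=[1, 2, 4, 4], factors=[4, 4, 4], num_blocks=[2, 2, 2],
#       attentions=[0, 0, 1, 1],
#       attention_heads=8, attention_features=64, attention_multiplier=2,
#   )
#   x, t = torch.randn(1, 2, 2**16), torch.rand(1)
#   unet(x, t).shape  # -> (1, 2, 65536)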
|
|
|
|
|
""" Conditioning Modules """ |
|
|
|
|
|
class FixedEmbedding(nn.Module): |
|
def __init__(self, max_length: int, features: int): |
|
super().__init__() |
|
self.max_length = max_length |
|
self.embedding = nn.Embedding(max_length, features) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
batch_size, length, device = *x.shape[0:2], x.device |
|
assert_message = "Input sequence length must be <= max_length" |
|
assert length <= self.max_length, assert_message |
|
position = torch.arange(length, device=device) |
|
fixed_embedding = self.embedding(position) |
|
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) |
|
return fixed_embedding |
|
|
|
|
|
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: |
|
if proba == 1: |
|
return torch.ones(shape, device=device, dtype=torch.bool) |
|
elif proba == 0: |
|
return torch.zeros(shape, device=device, dtype=torch.bool) |
|
else: |
|
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) |
|
|
|
|
|
class UNetCFG1d(UNet1d): |
|
|
|
"""UNet1d with Classifier-Free Guidance""" |
|
|
|
def __init__( |
|
self, |
|
context_embedding_max_length: int, |
|
context_embedding_features: int, |
|
use_xattn_time: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__( |
|
context_embedding_features=context_embedding_features, **kwargs |
|
) |
|
|
|
self.use_xattn_time = use_xattn_time |
|
|
|
if use_xattn_time: |
|
assert exists(context_embedding_features) |
|
self.to_time_embedding = nn.Sequential( |
|
TimePositionalEmbedding( |
|
dim=kwargs["channels"], out_features=context_embedding_features |
|
), |
|
nn.GELU(), |
|
) |
|
|
|
context_embedding_max_length += 1 |
|
|
|
self.fixed_embedding = FixedEmbedding( |
|
max_length=context_embedding_max_length, features=context_embedding_features |
|
) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
time: Tensor, |
|
*, |
|
embedding: Tensor, |
|
embedding_mask: Optional[Tensor] = None, |
|
embedding_scale: float = 1.0, |
|
embedding_mask_proba: float = 0.0, |
|
batch_cfg: bool = False, |
|
rescale_cfg: bool = False, |
|
scale_phi: float = 0.4, |
|
negative_embedding: Optional[Tensor] = None, |
|
negative_embedding_mask: Optional[Tensor] = None, |
|
**kwargs, |
|
) -> Tensor: |
|
b, device = embedding.shape[0], embedding.device |
|
|
|
if self.use_xattn_time: |
|
embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) |
|
|
|
if embedding_mask is not None: |
|
embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) |
|
|
|
fixed_embedding = self.fixed_embedding(embedding) |
|
|
|
if embedding_mask_proba > 0.0: |
|
|
|
batch_mask = rand_bool( |
|
shape=(b, 1, 1), proba=embedding_mask_proba, device=device |
|
) |
|
embedding = torch.where(batch_mask, fixed_embedding, embedding) |
|
|
|
if embedding_scale != 1.0: |
|
if batch_cfg: |
|
batch_x = torch.cat([x, x], dim=0) |
|
batch_time = torch.cat([time, time], dim=0) |
|
|
|
if negative_embedding is not None: |
|
if negative_embedding_mask is not None: |
|
negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) |
|
|
|
negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) |
|
|
|
batch_embed = torch.cat([embedding, negative_embedding], dim=0) |
|
|
|
else: |
|
batch_embed = torch.cat([embedding, fixed_embedding], dim=0) |
|
|
|
batch_mask = None |
|
if embedding_mask is not None: |
|
batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) |
|
|
|
batch_features = None |
|
features = kwargs.pop("features", None) |
|
if self.use_context_features: |
|
batch_features = torch.cat([features, features], dim=0) |
|
|
|
batch_channels = None |
|
channels_list = kwargs.pop("channels_list", None) |
|
if self.use_context_channels: |
|
batch_channels = [] |
|
for channels in channels_list: |
|
batch_channels += [torch.cat([channels, channels], dim=0)] |
|
|
|
|
|
batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) |
|
out, out_masked = batch_out.chunk(2, dim=0) |
|
|
|
else: |
|
|
|
out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) |
|
out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) |
|
|
|
out_cfg = out_masked + (out - out_masked) * embedding_scale |
|
|
|
if rescale_cfg: |
|
|
|
out_std = out.std(dim=1, keepdim=True) |
|
out_cfg_std = out_cfg.std(dim=1, keepdim=True) |
|
|
|
return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg |
|
|
|
else: |
|
|
|
return out_cfg |
|
|
|
else: |
|
return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) |
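# Illustrative sketch of the guidance arithmetic above: with embedding_scale = w, the
# conditional branch (out) and the fixed-embedding branch (out_masked) are combined as
# out_masked + (out - out_masked) * w; rescale_cfg additionally rescales the guided
# output's channel-wise std toward that of the conditional branch, blended by scale_phi.
#
#   # Hypothetical usage, assuming a (batch, 77, 768) text embedding:
#   # unet = UNetCFG1d(context_embedding_max_length=77, context_embedding_features=768, ...)
#   # y = unet(x, t, embedding=text_emb, embedding_scale=6.0, batch_cfg=True)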
|
|
|
|
|
class UNetNCCA1d(UNet1d): |
|
|
|
"""UNet1d with Noise Channel Conditioning Augmentation""" |
|
|
|
def __init__(self, context_features: int, **kwargs): |
|
super().__init__(context_features=context_features, **kwargs) |
|
self.embedder = NumberEmbedder(features=context_features) |
|
|
|
def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: |
|
x = x if torch.is_tensor(x) else torch.tensor(x) |
|
return x.expand(shape) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
time: Tensor, |
|
*, |
|
channels_list: Sequence[Tensor], |
|
channels_augmentation: Union[ |
|
bool, Sequence[bool], Sequence[Sequence[bool]], Tensor |
|
] = False, |
|
channels_scale: Union[ |
|
float, Sequence[float], Sequence[Sequence[float]], Tensor |
|
] = 0, |
|
**kwargs, |
|
) -> Tensor: |
|
b, n = x.shape[0], len(channels_list) |
|
channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) |
|
channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) |
|
|
|
|
|
for i in range(n): |
|
scale = channels_scale[:, i] * channels_augmentation[:, i] |
|
scale = rearrange(scale, "b -> b 1 1") |
|
item = channels_list[i] |
|
channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) |
|
|
|
|
|
channels_scale_emb = self.embedder(channels_scale) |
|
channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") |
|
|
|
return super().forward( |
|
x=x, |
|
time=time, |
|
channels_list=channels_list, |
|
features=channels_scale_emb, |
|
**kwargs, |
|
) |
|
|
|
|
|
class UNetAll1d(UNetCFG1d, UNetNCCA1d): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward(self, *args, **kwargs): |
|
return UNetCFG1d.forward(self, *args, **kwargs) |
|
|
|
|
|
def XUNet1d(type: str = "base", **kwargs) -> UNet1d: |
|
if type == "base": |
|
return UNet1d(**kwargs) |
|
elif type == "all": |
|
return UNetAll1d(**kwargs) |
|
elif type == "cfg": |
|
return UNetCFG1d(**kwargs) |
|
elif type == "ncca": |
|
return UNetNCCA1d(**kwargs) |
|
else: |
|
raise ValueError(f"Unknown XUNet1d type: {type}") |
|
|
|
class NumberEmbedder(nn.Module): |
|
def __init__( |
|
self, |
|
features: int, |
|
dim: int = 256, |
|
): |
|
super().__init__() |
|
self.features = features |
|
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) |
|
|
|
def forward(self, x: Union[List[float], Tensor]) -> Tensor: |
|
if not torch.is_tensor(x): |
|
device = next(self.embedding.parameters()).device |
|
x = torch.tensor(x, device=device) |
|
assert isinstance(x, Tensor) |
|
shape = x.shape |
|
x = rearrange(x, "... -> (...)") |
|
embedding = self.embedding(x) |
|
x = embedding.view(*shape, self.features) |
|
return x |
|
|
|
|
|
""" |
|
Audio Transforms |
|
""" |
|
|
|
|
|
class STFT(nn.Module): |
|
"""Helper for torch stft and istft""" |
|
|
|
def __init__( |
|
self, |
|
num_fft: int = 1023, |
|
hop_length: int = 256, |
|
window_length: Optional[int] = None, |
|
length: Optional[int] = None, |
|
use_complex: bool = False, |
|
): |
|
super().__init__() |
|
self.num_fft = num_fft |
|
self.hop_length = default(hop_length, floor(num_fft // 4)) |
|
self.window_length = default(window_length, num_fft) |
|
self.length = length |
|
self.register_buffer("window", torch.hann_window(self.window_length)) |
|
self.use_complex = use_complex |
|
|
|
def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: |
|
b = wave.shape[0] |
|
wave = rearrange(wave, "b c t -> (b c) t") |
|
|
|
stft = torch.stft( |
|
wave, |
|
n_fft=self.num_fft, |
|
hop_length=self.hop_length, |
|
win_length=self.window_length, |
|
window=self.window, |
|
return_complex=True, |
|
normalized=True, |
|
) |
|
|
|
if self.use_complex: |
|
|
|
stft_a, stft_b = stft.real, stft.imag |
|
else: |
|
|
|
magnitude, phase = torch.abs(stft), torch.angle(stft) |
|
stft_a, stft_b = magnitude, phase |
|
|
|
return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) |
|
|
|
def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: |
|
b, l = stft_a.shape[0], stft_a.shape[-1] |
|
length = closest_power_2(l * self.hop_length) |
|
|
|
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") |
|
|
|
if self.use_complex: |
|
real, imag = stft_a, stft_b |
|
else: |
|
magnitude, phase = stft_a, stft_b |
|
real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) |
|
|
|
stft = torch.stack([real, imag], dim=-1) |
|
|
|
wave = torch.istft( |
|
stft, |
|
n_fft=self.num_fft, |
|
hop_length=self.hop_length, |
|
win_length=self.window_length, |
|
window=self.window, |
|
length=default(self.length, length), |
|
normalized=True, |
|
) |
|
|
|
return rearrange(wave, "(b c) t -> b c t", b=b) |
|
|
|
def encode1d( |
|
self, wave: Tensor, stacked: bool = True |
|
) -> Union[Tensor, Tuple[Tensor, Tensor]]: |
|
stft_a, stft_b = self.encode(wave) |
|
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") |
|
return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) |
|
|
|
def decode1d(self, stft_pair: Tensor) -> Tensor: |
|
f = self.num_fft // 2 + 1 |
|
stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) |
|
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) |
|
return self.decode(stft_a, stft_b) |
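# Illustrative round trip (shapes only): encode1d stacks magnitude and phase (or real and
# imaginary parts) along the channel axis so spectrogram frames can be processed by the
# 1d UNet, and decode1d inverts that layout.
#
#   stft = STFT(num_fft=1023, hop_length=256)
#   wave = torch.randn(1, 2, 2**16)
#   spec = stft.encode1d(wave)  # -> (1, 2 * 2 * 512, 257)
#   rec = stft.decode1d(spec)   # -> (1, 2, 65536)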
|
|