cydxg's picture
Upload 73 files
0a948c1 verified
raw
history blame
11.5 kB
from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version
from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
    """Residual wrapper: ``forward(x) = main(x) + skip(x)``.

    Args:
        main: iterable of modules, chained into an ``nn.Sequential``.
        skip: optional module for the shortcut path; identity when omitted.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Explicit None check: `if skip` would call a module's __len__ and treat
        # an empty container module (e.g. nn.Sequential()) as "no skip".
        self.skip = nn.Identity() if skip is None else skip

    def forward(self, input):
        return self.main(input) + self.skip(input)
class ResConvBlock(ResidualBlock):
    """Two same-padded Conv1d layers with GroupNorm + activation, wrapped residually.

    The shortcut is a bias-free 1x1 Conv1d when ``c_in != c_out``; otherwise it is
    left to ``ResidualBlock``'s default. When ``is_last`` is set, the trailing
    norm/activation are replaced by ``nn.Identity`` placeholders. This keeps the
    module indices inside ``main`` the same in both configurations.
    """

    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        # The shortcut conv is built first, as in the original, so parameter
        # initialisation consumes the RNG in the same order.
        if c_in == c_out:
            shortcut = None
        else:
            shortcut = nn.Conv1d(c_in, c_out, 1, bias=False)

        same_pad = kernel_size // 2

        def make_activation(channels):
            if use_snake:
                return Snake1d(channels)
            return nn.GELU()

        layers = [
            nn.Conv1d(c_in, c_mid, kernel_size, padding=same_pad, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            make_activation(c_mid),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=same_pad, bias=conv_bias),
        ]
        if is_last:
            layers.append(nn.Identity())
            layers.append(nn.Identity())
        else:
            layers.append(nn.GroupNorm(1, c_out))
            layers.append(make_activation(c_out))

        super().__init__(layers, shortcut)
class SelfAttention1d(nn.Module):
    """Residual multi-head self-attention over the last axis of an ``(n, c, s)`` tensor.

    ``forward`` returns ``input + dropout(out_proj(attn(norm(input))))``, with the
    same shape as ``input``.

    Args:
        c_in: channel count; must be divisible by ``n_head``.
        n_head: number of attention heads (head_dim = c_in // n_head).
        dropout_rate: dropout probability on the projected attention output.
    """

    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)
        # The fused SDPA path is only considered when CUDA is present and torch >= 2.0.
        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
        if not self.use_flash:
            return
        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        # Tuple is passed positionally to sdp_kernel:
        # (enable_flash, enable_math, enable_mem_efficient).
        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs (compute capability 8.0)
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(self, input):
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        # (n, 3c, s) -> (n, 3*n_head, head_dim, s) -> (n, 3*n_head, s, head_dim)
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)  # each (n, n_head, s, head_dim)
        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                # SDPA's default scale is 1/sqrt(head_dim), matching the manual path.
                y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # Split the 1/sqrt(head_dim) scale across q and k for numerical headroom.
            scale = k.shape[3] ** -0.25
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = att @ v
        # Both branches give (n, n_head, s, head_dim). Move head_dim back next to
        # the head axis before flattening to (n, c, s). The SDPA branch previously
        # skipped this transpose and scrambled channels/time.
        # NOTE(review): checkpoints trained through the old SDPA path learned
        # against that layout; verify them before reuse.
        y = y.transpose(2, 3).contiguous().view([n, c, s])
        return input + self.dropout(self.out_proj(y))
class SkipBlock(nn.Module):
    """Run ``main`` on the input and concatenate the result with the input on dim 1."""

    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        processed = self.main(input)
        # Output first, untouched input second, along the channel axis (dim 1).
        return torch.cat((processed, input), dim=1)
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding: ``[cos(2*pi*x W^T), sin(2*pi*x W^T)]``.

    Args:
        in_features: size of the last input dimension.
        out_features: embedding size; must be even (half cos, half sin).
        std: standard deviation of the Gaussian-initialised frequencies ``W``.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        half = out_features // 2
        self.weight = nn.Parameter(std * torch.randn([half, in_features]))

    def forward(self, input):
        phase = (2 * math.pi * input) @ self.weight.T
        return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def expand_to_planes(input, shape):
    """Append a trailing axis to ``input`` and tile it ``shape[2]`` times along it.

    Intended for 2-D ``input`` (e.g. ``(batch, features)``) -> ``(batch, features, shape[2])``.
    """
    planes = shape[2]
    return input.unsqueeze(-1).repeat(1, 1, planes)
# Fixed symmetric FIR taps, looked up by name in Downsample1d / Upsample1d.
# The taps of each kernel sum to (approximately) 1, i.e. unit DC gain.
_kernels = dict(
    # [1, 3, 3, 1] / 8
    linear=[1 / 8, 3 / 8, 3 / 8, 1 / 8],
    cubic=[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
           0.43359375, 0.11328125, -0.03515625, -0.01171875],
    lanczos3=[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
              -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
              0.44638532400131226, 0.13550527393817902, -0.066637322306633,
              -0.03399861603975296, 0.015056144446134567, 0.003689131001010537],
)
class Downsample1d(nn.Module):
    """Parameter-free 2x downsampling with a fixed FIR kernel from ``_kernels``.

    Every channel is filtered independently with the same kernel and then decimated
    by 2. For the built-in (even-length) kernels the output length is ``L // 2``.

    Args:
        kernel: key into ``_kernels`` ('linear', 'cubic' or 'lanczos3').
        pad_mode: ``F.pad`` mode applied before filtering.
        channels_last: if True, the input/output channel axis is dim 2 instead of
            dim 1 (the tensor is permuted around the convolution).
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
        channels = x.shape[1]
        # Depthwise filter: one (1, K) copy of the kernel per channel with
        # groups=C. This is equivalent to a dense (C, C, K) diagonal weight but
        # avoids allocating O(C^2 * K) zeros and C-fold wasted multiply-adds on
        # every call.
        weight = self.kernel.to(x).repeat(channels, 1, 1)
        x = F.conv1d(x, weight, stride=2, groups=channels)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
class Upsample1d(nn.Module):
    """Parameter-free 2x upsampling with a fixed FIR kernel from ``_kernels``.

    The kernel is scaled by 2 because a stride-2 transposed conv splits the taps
    between two output phases. For the built-in kernels each phase then has unit
    DC gain. For those kernels the output length is ``2 * L``.

    Args:
        kernel: key into ``_kernels`` ('linear', 'cubic' or 'lanczos3').
        pad_mode: ``F.pad`` mode applied before filtering.
        channels_last: if True, the input/output channel axis is dim 2 instead of
            dim 1 (the tensor is permuted around the convolution).
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        channels = x.shape[1]
        # Depthwise transposed filter: conv_transpose1d weight is
        # (in_channels, out_channels / groups, K) = (C, 1, K) with groups=C.
        # This is equivalent to the dense diagonal (C, C, K) weight without the
        # O(C^2 * K) allocation.
        weight = self.kernel.to(x).repeat(channels, 1, 1)
        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1, groups=channels)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Learned strided Conv1d that downsamples the length by ``factor``.

    The kernel spans ``factor * kernel_multiplier + 1`` taps, with symmetric padding
    of ``factor * kernel_multiplier / 2`` on each side.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
    taps = factor * kernel_multiplier + 1
    side_pad = factor * (kernel_multiplier // 2)
    return nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=taps,
        stride=factor,
        padding=side_pad,
    )
def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Learned upsampling of the length axis by ``factor``.

    - ``factor == 1``: a same-length 3-tap Conv1d.
    - ``use_nearest``: nearest-neighbour upsample followed by a 3-tap Conv1d.
    - otherwise: a ConvTranspose1d whose padding/output_padding are chosen so the
      output length is exactly ``factor * L`` for both even and odd factors.
    """

    def smoothing_conv():
        return nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

    if factor == 1:
        return smoothing_conv()

    if not use_nearest:
        odd = factor % 2
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + odd,
            output_padding=odd,
        )

    return nn.Sequential(
        nn.Upsample(scale_factor=factor, mode="nearest"),
        smoothing_conv(),
    )
def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` when it is not None) in place; return ``layer``."""
    nn.init.zeros_(layer.weight)
    if layer.bias is None:
        return layer
    nn.init.zeros_(layer.bias)
    return layer
def rms_norm(x, scale, eps):
    """RMS-normalise ``x`` over its last axis and multiply by ``scale``.

    Statistics are computed in at least float32 (the promotion of x's dtype,
    scale's dtype and float32). The result is returned in ``x.dtype``.

    Note: a second ``rms_norm`` defined later in this module (wrapped by
    ``compile``) rebinds this name. Module-level callers resolve to that one at
    call time.
    """
    compute_dtype = torch.promote_types(torch.promote_types(x.dtype, scale.dtype), torch.float32)
    inv_rms = torch.rsqrt((x.to(compute_dtype) ** 2).mean(dim=-1, keepdim=True) + eps)
    factor = scale.to(compute_dtype) * inv_rms
    return x * factor.to(x.dtype)
class AdaRMSNorm(nn.Module):
    """RMS norm whose per-feature scale is ``1 + Linear(cond)``.

    The projection is zero-initialised, so the module starts as a plain
    (unit-scale) RMS norm.

    Args:
        features: size of the normalised (last) axis of ``x``.
        cond_features: size of the conditioning vector.
        eps: added to the mean square before the rsqrt.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = zero_init(projection)

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        # Insert an axis at dim 1 so the (batch, features) scale broadcasts over
        # x's dim 1 (assumed to be the sequence axis; TODO confirm with callers).
        scale = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, scale, self.eps)
def normalize(x, eps=1e-4):
    """Scale each slice along dim 0 of ``x`` so its RMS magnitude is about 1.

    The L2 norm over all non-leading dims is multiplied by
    ``sqrt(n.numel() / x.numel())``, i.e. divided by sqrt(elements per slice).
    ``eps`` is then added before dividing, which guards against zero slices.
    """
    reduce_dims = list(range(1, x.ndim))
    norms = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True)
    rms_factor = np.sqrt(norms.numel() / x.numel())
    return x / torch.add(eps, norms, alpha=rms_factor)
class ForcedWNConv1d(nn.Module):
    """Bias-free Conv1d with forced weight normalisation and ``'same'`` padding.

    In training mode the stored weight is overwritten in place with its
    normalised version on every forward. The convolution always uses
    ``normalize(weight) / sqrt(fan_in)``, where fan_in = in_channels * kernel_size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))

    def forward(self, x):
        if self.training:
            # Re-project the raw parameter outside autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))
        per_output_fan_in = self.weight[0].numel()
        effective = normalize(self.weight) / math.sqrt(per_output_fan_in)
        return F.conv1d(x, effective, padding='same')
# Kernels
# Global switch: when False, `compile` returns functions unchanged.
use_compile = True


def compile(function, *args, **kwargs):
    """Wrap ``function`` with ``torch.compile`` when enabled, else return it as-is.

    Extra arguments are forwarded to ``torch.compile``. If ``torch.compile`` raises
    ``RuntimeError`` at wrap time, the uncompiled function is returned instead.

    Note: this module-level name shadows Python's builtin ``compile`` within this
    module.
    """
    if use_compile:
        try:
            return torch.compile(function, *args, **kwargs)
        except RuntimeError:
            pass
    return function
@compile
def linear_geglu(x, weight, bias=None):
    """Affine projection followed by GEGLU: split the output in half, return ``a * gelu(b)``.

    ``weight`` has 2 * out rows; the first half of the projected features is the
    value, the second half is the gate.
    """
    projected = x @ weight.mT
    if bias is not None:
        projected = projected + bias
    value, gate = projected.chunk(2, dim=-1)
    return value * F.gelu(gate)
@compile
def rms_norm(x, scale, eps):
    """RMS-normalise ``x`` over its last axis and multiply by ``scale`` (compiled).

    Statistics are computed in at least float32 (the promotion of x's dtype,
    scale's dtype and float32). The result is returned in ``x.dtype``.
    """
    compute_dtype = torch.promote_types(torch.promote_types(x.dtype, scale.dtype), torch.float32)
    inv_rms = torch.rsqrt((x.to(compute_dtype) ** 2).mean(dim=-1, keepdim=True) + eps)
    factor = scale.to(compute_dtype) * inv_rms
    return x * factor.to(x.dtype)
# Layers
class LinearGEGLU(nn.Linear):
    """``nn.Linear`` producing ``out_features`` GEGLU-gated outputs.

    The underlying weight has ``2 * out_features`` rows (value and gate halves).
    ``self.out_features`` is reset to the gated output size.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, 2 * out_features, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)
class RMSNorm(nn.Module):
    """RMS norm over the last axis with a ones-initialised scale of ``shape``.

    Args:
        shape: shape of the scale tensor.
        fix_scale: if True the scale is a non-trainable buffer, else a Parameter.
        eps: added to the mean square before the rsqrt.
    """

    def __init__(self, shape, fix_scale = False, eps=1e-6):
        super().__init__()
        self.eps = eps
        initial_scale = torch.ones(shape)
        if not fix_scale:
            self.scale = nn.Parameter(initial_scale)
        else:
            self.register_buffer("scale", initial_scale)

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)
def snake_beta(x, alpha, beta):
    """SnakeBeta activation: ``x + sin(alpha * x)**2 / (beta + 1e-9)``."""
    periodic = torch.sin(x * alpha) ** 2
    # Small constant keeps the reciprocal finite when beta == 0.
    return x + (1.0 / (beta + 0.000000001)) * periodic
# try:
# snake_beta = torch.compile(snake_beta)
# except RuntimeError:
# pass
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
    """Per-channel SnakeBeta activation with learnable ``alpha`` (frequency) and ``beta``.

    Args:
        in_features: number of channels (size of dim 1 of the input).
        alpha: multiplier applied to the initial parameter values.
        alpha_trainable: whether ``alpha`` and ``beta`` receive gradients.
        alpha_logscale: if True, parameters are stored in log space (initialised
            to zeros) and exponentiated in ``forward``. Otherwise they are stored
            linearly (initialised to ``alpha``).
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start at 0 (exp -> 1); linear-scale ones start at 1 * alpha.
        init_fn = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(init_fn(in_features) * alpha)
        self.beta = nn.Parameter(init_fn(in_features) * alpha)
        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # (C,) -> (1, C, 1) so the parameters broadcast over a [B, C, T] input.
        alpha = self.alpha[None, :, None]
        beta = self.beta[None, :, None]
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        return snake_beta(x, alpha, beta)