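"""BigVGAN vocoder module for FireRedTTS.

Defines the BigVGAN generator and its anti-aliased multi-periodicity (AMP)
residual blocks, using Snake/SnakeBeta periodic activations with an optional
fused CUDA activation kernel.
"""
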
import typing as tp
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
from fireredtts.modules.bigvgan.alias_free_torch import (
Activation1d as TorchActivation1d,
)
from fireredtts.modules.bigvgan.activations import Snake, SnakeBeta


def init_weights(m, mean=0.0, std=0.01):
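    """Initialize the weights of Conv* layers from N(mean, std); use via Module.apply."""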
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
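    """Return the 'same' padding for a stride-1 dilated conv (exact for odd kernel sizes)."""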
return int((kernel_size * dilation - dilation) / 2)


class AMPBlock1(torch.nn.Module):
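    """AMP residual block, variant 1 (BigVGAN's default).

    Three residual paths, each applying: anti-aliased Snake/SnakeBeta
    activation -> dilated conv -> activation -> non-dilated conv, followed
    by a skip connection back to the input.
    """
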
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3, 5),
activation=None,
snake_logscale=True,
use_cuda_kernel=False,
):
super(AMPBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
        self.num_layers = len(self.convs1) + len(self.convs2)  # total conv layers
        # select which Activation1d to use; the CUDA version is imported lazily
        # so the pure-Torch fallback keeps working when the kernel is absent
        if use_cuda_kernel:
            from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        if activation == "snake":
            # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation must be 'snake' or 'snakebeta'; "
                "check the config file's 'activation' field."
            )

    def forward(self, x):
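        # even-indexed activations precede convs1, odd-indexed precede convs2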
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x

    def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)


class AMPBlock2(torch.nn.Module):
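    """AMP residual block, variant 2 (a lighter alternative to AMPBlock1).

    Two residual paths, each applying: anti-aliased Snake/SnakeBeta
    activation -> dilated conv, followed by a skip connection.
    """
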
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3),
activation=None,
snake_logscale=True,
use_cuda_kernel=False,
):
super(AMPBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers
        # select which Activation1d to use; the CUDA version is imported lazily
        # so the pure-Torch fallback keeps working when the kernel is absent
        if use_cuda_kernel:
            from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        if activation == "snake":
            # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation must be 'snake' or 'snakebeta'; "
                "check the config file's 'activation' field."
            )

    def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x

    def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)


class BigVGAN(torch.nn.Module):
    """The main BigVGAN generator: upsamples a mel-spectrogram to a waveform,
    applying anti-aliased periodic activations inside the AMP residual blocks.
    """

def __init__(
self,
num_mels: int,
upsample_initial_channel: int,
resblock_kernel_sizes: tp.List[int],
resblock_dilation_sizes: tp.List[tp.List[int]],
upsample_rates: tp.List[int],
upsample_kernel_sizes: tp.List[int],
resblock_type: str = "1",
snake_logscale: bool = True,
activation: str = "snakebeta",
use_tanh_at_final: bool = False,
use_bias_at_final: bool = False,
use_cuda_kernel: bool = False,
):
super(BigVGAN, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
# pre conv
self.conv_pre = weight_norm(
Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3)
)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if resblock_type == "1" else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList(
[
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
]
)
)
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(
resblock(
ch,
k,
d,
activation=activation,
snake_logscale=snake_logscale,
use_cuda_kernel=use_cuda_kernel,
)
)
        # select which Activation1d to use; the CUDA version is imported lazily
        # so the pure-Torch fallback keeps working when the kernel is absent
        if use_cuda_kernel:
            from fireredtts.modules.bigvgan.alias_free_cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d
        # post conv
        if activation == "snake":
            # periodic nonlinearity with snake function and anti-aliasing
            activation_post = Snake(ch, alpha_logscale=snake_logscale)
        elif activation == "snakebeta":
            # periodic nonlinearity with snakebeta function and anti-aliasing
            activation_post = SnakeBeta(ch, alpha_logscale=snake_logscale)
        else:
            raise NotImplementedError(
                "activation must be 'snake' or 'snakebeta'; "
                "check the config file's 'activation' field."
            )
        self.activation_post = Activation1d(activation=activation_post)
        # whether to use bias in the final conv_post (defaults to False here)
self.use_bias_at_final = use_bias_at_final
self.conv_post = weight_norm(
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
)
# weight initialization
for i in range(len(self.ups)):
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)
        # whether to apply a final tanh (defaults to False here; the raw
        # output is clamped to [-1, 1] instead)
        self.use_tanh_at_final = use_tanh_at_final

    def forward(self, x):
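        # x: (batch, num_mels, frames) -> (batch, 1, frames * prod(upsample_rates))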
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
            # upsampling
            for up_layer in self.ups[i]:
                x = up_layer(x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
# final tanh activation
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, min=-1.0, max=1.0) # bound the output to [-1, 1]
return x

    def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
for l_i in l:
remove_weight_norm(l_i)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
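

# Minimal usage sketch. The hyperparameters below are illustrative assumptions
# chosen for demonstration, not the values shipped with FireRedTTS.
if __name__ == "__main__":
    model = BigVGAN(
        num_mels=80,
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
    )
    model.eval()
    model.remove_weight_norm()  # fuse weight norm for inference

    mel = torch.randn(1, 80, 100)  # (batch, num_mels, frames)
    with torch.no_grad():
        wav = model(mel)
    # 100 frames * (8 * 8 * 2 * 2) total upsampling = 25600 samples
    print(wav.shape)  # torch.Size([1, 1, 25600])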