# ANYANTUDRE
# initial set up
# 78d1101
# raw
# history blame
# 2.62 kB
import torch
from torch import nn
from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder
from TTS.tts.utils.helpers import sequence_mask
class Decoder(nn.Module):
    """Uses glow decoder with some modifications.

    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    This class is a thin wrapper: it trims the input so that its time axis and
    lengths are multiples of the squeeze factor, builds the sequence mask and
    then delegates to ``GlowDecoder``.

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): forwarded unchanged to ``GlowDecoder``. Presumably the
            number of splits used by the invertible 1x1 convolution -- confirm
            against ``GlowDecoder``.
        num_squeeze (int): squeeze factor forwarded to ``GlowDecoder``. It is also
            kept as ``self.n_sqz``; ``preprocess`` floors the time lengths to a
            multiple of it.
        sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
        c_in_channels (int): forwarded unchanged to ``GlowDecoder``. Presumably
            the channel count of the conditioning input ``g`` -- confirm against
            ``GlowDecoder``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()
        self.glow_decoder = GlowDecoder(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_flow_blocks,
            num_coupling_layers,
            dropout_p,
            num_splits,
            num_squeeze,
            sigmoid_scale,
            c_in_channels,
        )
        # Kept locally because ``preprocess`` needs the squeeze factor.
        self.n_sqz = num_squeeze

    def forward(self, x, x_len, g=None, reverse=False):
        """Run the wrapped glow decoder on a length-trimmed, masked input.

        Input shapes:
            - x: :math:`[B, C, T]`
            - x_len :math:`[B]`
            - g: :math:`[B, C]`

        Output shapes:
            - x: :math:`[B, C, T]`
            - x_len :math:`[B]`
            - logdet_tot :math:`[B]`

        The returned ``x_len`` holds the lengths after they were floored to a
        multiple of ``self.n_sqz`` by ``preprocess``.
        """
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        # Mask is built from the trimmed lengths and cast to the dtype of ``x``.
        x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype)
        x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse)
        return x, x_len, logdet_tot

    def preprocess(self, y, y_lengths, y_max_length):
        """Floor the time dimension and the lengths to multiples of ``self.n_sqz``.

        Args:
            y: tensor indexed as ``y[:, :, :t]``; its last axis is truncated when
                ``y_max_length`` is given.
            y_lengths: tensor of lengths; every entry is floored to a multiple of
                ``self.n_sqz``.
            y_max_length: maximum length as a tensor or a plain integer, or
                ``None`` to leave ``y`` untruncated.

        Returns:
            Tuple ``(y, y_lengths, y_max_length)`` with the trimmed tensor, the
            floored lengths and the floored maximum length (``None`` is passed
            through unchanged).
        """
        if y_max_length is not None:
            if torch.is_tensor(y_max_length):
                y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz
            else:
                # ``torch.div`` requires a tensor input, so plain integers use
                # Python floor division instead of raising ``TypeError``.
                y_max_length = (y_max_length // self.n_sqz) * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        """Delegate to ``GlowDecoder.store_inverse`` on the wrapped decoder."""
        self.glow_decoder.store_inverse()