# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
import torch.nn as nn
from .conv import StreamableConv1d, StreamableConvTranspose1d


class StreamableLSTM(nn.Module):
"""LSTM without worrying about the hidden state, nor the layout of the data.
Expects input as convolutional layout.
"""
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
super().__init__()
self.skip = skip
self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # nn.LSTM expects time-first layout [T, B, C]; input arrives as [B, C, T].
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        # Back to convolutional layout [B, C, T].
        y = y.permute(1, 2, 0)
        return y
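
# A minimal usage sketch, not part of the original file: StreamableLSTM keeps
# the convolutional layout [B, C, T] at its interface, so it can sit between
# conv layers without any reshaping on the caller's side.
#
#     lstm = StreamableLSTM(dimension=128, num_layers=2)
#     x = torch.randn(1, 128, 35)   # [batch, channels, time]
#     y = lstm(x)                   # skip connection keeps shape: [1, 128, 35]
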
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
Args:
dim (int): Dimension of the input/output.
kernel_sizes (list): List of kernel sizes for the convolutions.
dilations (list): List of dilations for the convolutions.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function.
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection.
"""
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
super().__init__()
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
act = getattr(nn, activation)
hidden = dim // compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [
act(**activation_params),
StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode),
]
self.block = nn.Sequential(*block)
self.shortcut: nn.Module
if true_skip:
self.shortcut = nn.Identity()
else:
self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        # Residual connection: `shortcut` is an identity (true skip) or a 1x1 conv.
        return self.shortcut(x) + self.block(x)
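
# A minimal usage sketch, not part of the original file: the block is
# shape-preserving, so any number of them can be stacked between layers.
#
#     block = SEANetResnetBlock(dim=64, kernel_sizes=[3, 1], dilations=[1, 1])
#     y = block(torch.randn(1, 64, 100))  # same shape: [1, 64, 100]
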
class SEANetDecoder(nn.Module):
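    """SEANet decoder: maps a latent sequence back to a raw waveform.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers per upsampling block.
        ratios (list): Upsampling ratios; their product is the hop length.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Optional final activation after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each residual layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connections or streamable convolutions as skip connections.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers right after the initial convolution (0 disables the LSTM).
        disable_norm_outer_blocks (int): Number of outer blocks for which normalization is disabled.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
    """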
def __init__(self, channels: int = 1,
dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
activation_params: dict = {'alpha': 1.0},
final_activation: tp.Optional[str] = None,
final_activation_params: tp.Optional[dict] = None,
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {},
kernel_size: int = 7,
last_kernel_size: int = 7, residual_kernel_size: int = 3,
dilation_base: int = 2, causal: bool = False,
pad_mode: str = 'reflect', true_skip: bool = True,
compress: int = 2, lstm: int = 0,
disable_norm_outer_blocks: int = 0,
trim_right_ratio: float = 1.0):
super().__init__()
self.dimension = dimension
self.channels = channels
self.n_filters = n_filters
        self.ratios = ratios
        del ratios  # guard against accidentally using the local instead of self.ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert 0 <= self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be between 0 and the total number of blocks in the network."
act = getattr(nn, activation)
mult = int(2 ** len(self.ratios))
model: tp.List[nn.Module] = [
StreamableConv1d(dimension, mult * n_filters, kernel_size,
norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
]
        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
# Upsample to raw audio scale
for i, ratio in enumerate(self.ratios):
block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
# Add upsampling layers
model += [
act(**activation_params),
StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
kernel_size=ratio * 2, stride=ratio,
norm=block_norm, norm_kwargs=norm_params,
causal=causal, trim_right_ratio=trim_right_ratio),
]
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base ** j, 1],
activation=activation, activation_params=activation_params,
norm=block_norm, norm_params=norm_params, causal=causal,
pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
mult //= 2
# Add final layers
model += [
act(**activation_params),
StreamableConv1d(n_filters, channels, last_kernel_size,
norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
]
        # Add optional final activation to decoder (e.g. tanh)
if final_activation is not None:
final_act = getattr(nn, final_activation)
final_activation_params = final_activation_params or {}
model += [
final_act(**final_activation_params)
]
self.model = nn.Sequential(*model)

    def forward(self, z):
        # z: latent in convolutional layout [B, dimension, T]. The output is
        # raw audio of shape [B, channels, T * hop_length].
        y = self.model(z)
        return y
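
# A hedged smoke test, not part of the original file: it sketches the shape
# contract of the decoder under an arbitrary small configuration. The exact
# output length depends on the padding behaviour of StreamableConv1d and
# StreamableConvTranspose1d; with the streamable setup it is expected to be
# T * hop_length samples.
if __name__ == '__main__':
    import torch

    decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32,
                            ratios=[8, 5, 4, 2])
    z = torch.randn(1, 128, 35)  # [B, dimension, T]
    wav = decoder(z)
    # hop_length = 8 * 5 * 4 * 2 = 320, so roughly 35 * 320 = 11200 samples.
    print(wav.shape)  # expected: torch.Size([1, 1, 11200])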