emage_evaltools / decoders.py
H-Liu1997's picture
Upload 8 files
a92b3a0 verified
# This script is modified from https://github.com/EricGuo5513/TM2T
# Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
import torch.nn as nn
class VQDecoderV3(nn.Module):
    """Convolutional decoder that upsamples a latent sequence.

    The network is a single ``nn.Sequential`` (``self.main``) made of two
    ``ResBlock`` modules, then ``args.vae_layer`` stages of
    nearest-neighbour x2 upsampling + ``Conv1d`` + ``LeakyReLU(0.2)``, and
    a final ``Conv1d``. Every convolution uses kernel size 3, stride 1 and
    padding 1, so only the upsampling stages change the sequence length.
    No custom weight initialisation is applied.

    Args:
        args: Configuration object. The attributes read here are
            ``vae_layer`` (number of upsampling stages; the length
            assertion below only holds when it is at least 1),
            ``vae_length`` (channel width of the latent and of all hidden
            stages) and ``vae_test_dim`` (channel width of the output).
    """

    def __init__(self, args):
        super().__init__()
        num_upsamples = args.vae_layer
        latent_dim = args.vae_length
        num_res_blocks = 2

        # Hidden width for every stage input, output width for the last one.
        channels = [latent_dim] * max(num_upsamples - 1, 0)
        channels += [latent_dim, args.vae_test_dim]
        assert len(channels) == num_upsamples + 1

        stages = []
        # NOTE(review): the latent width and channels[0] both come from
        # args.vae_length, so this projection is never added as written.
        if latent_dim != channels[0]:
            stages.append(
                nn.Conv1d(latent_dim, channels[0], kernel_size=3, stride=1, padding=1)
            )

        for _ in range(num_res_blocks):
            stages.append(ResBlock(channels[0]))

        for level in range(num_upsamples):
            in_width, out_width = channels[level], channels[level + 1]
            stages.append(nn.Upsample(scale_factor=2, mode="nearest"))
            stages.append(
                nn.Conv1d(in_width, out_width, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        out_width = channels[-1]
        stages.append(
            nn.Conv1d(out_width, out_width, kernel_size=3, stride=1, padding=1)
        )
        self.main = nn.Sequential(*stages)

    def forward(self, inputs):
        """Decode ``inputs`` and return the upsampled sequence.

        Axes 1 and 2 are swapped before and after the convolution stack,
        so ``inputs`` must be 3-dimensional; ``Conv1d`` treats axis 1 of
        the swapped tensor as channels, which implies a
        ``(batch, length, args.vae_length)`` layout for ``inputs``.

        Returns:
            Tensor with the same axis layout as ``inputs``; the length axis
            is doubled once per upsampling stage and the last axis has
            ``args.vae_test_dim`` entries.
        """
        channels_first = inputs.permute(0, 2, 1)
        decoded = self.main(channels_first)
        return decoded.permute(0, 2, 1)
class ResBlock(nn.Module):
    """Residual block built from two same-width 1D convolutions.

    Computes ``x + conv(leaky_relu(conv(x)))``. Both convolutions use
    kernel size 3, stride 1 and padding 1, so the output has the same
    shape as the input.

    Args:
        channel: Number of input and output channels of both convolutions.
    """

    def __init__(self, channel):
        super().__init__()
        # Construction order fixes the Sequential indices (0, 1, 2) that
        # appear in the state-dict keys.
        conv_in = nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1)
        activation = nn.LeakyReLU(0.2, inplace=True)
        conv_out = nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(conv_in, activation, conv_out)

    def forward(self, x):
        """Return ``x`` plus the convolutional branch applied to ``x``."""
        branch = self.model(x)
        # In-place add on the freshly computed branch output; ``x`` itself
        # is left untouched.
        branch += x
        return branch