"""
Module containing the decoders.
"""
import numpy as np
import torch
from torch import nn
# ALL decoders should be called Decoder<Model>
def get_decoder(model_type):
    """Look up a decoder class by model name.

    Parameters
    ----------
    model_type : str
        Case-insensitive model name, e.g. "burgess" resolves to
        ``DecoderBurgess``.

    Returns
    -------
    type
        The decoder class (not an instance).

    Raises
    ------
    ValueError
        If no class named ``Decoder<Model>`` exists in this module.
    """
    name = "Decoder{}".format(model_type.lower().capitalize())
    try:
        # Plain globals() lookup instead of eval(): same resolution of
        # module-level names, without executing an arbitrary string.
        return globals()[name]
    except KeyError:
        raise ValueError("Unknown decoder: {}".format(model_type))
class DecoderBurgess(nn.Module):
    def __init__(self, img_size, latent_dim=10):
        r"""Decoder of the model proposed in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        latent_dim : int
            Dimensionality of latent output.

        Model Architecture (transposed for decoder)
        ------------
        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
        - 2 fully connected layers (each of 256 units)
        - Latent distribution:
            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)

        References
        ----------
        [1] Burgess, Christopher P., et al. "Understanding disentangling in
            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
        """
        super(DecoderBurgess, self).__init__()

        # Layer parameters
        hid_channels = 32
        kernel_size = 4
        hidden_dim = 256
        self.img_size = img_size
        # Shape required to start transpose convs
        self.reshape = (hid_channels, kernel_size, kernel_size)
        n_chan = self.img_size[0]

        # Fully connected layers
        self.lin1 = nn.Linear(latent_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        # np.prod, not np.product (removed in NumPy 2.0); int() because
        # nn.Linear expects a Python int, not np.int64.
        self.lin3 = nn.Linear(hidden_dim, int(np.prod(self.reshape)))

        # Convolutional layers; k=4, s=2, p=1 doubles the spatial size
        # at each step: 4 -> 8 -> 16 -> 32 (-> 64 with the extra conv).
        cnn_kwargs = dict(stride=2, padding=1)
        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.convT_64 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT3 = nn.ConvTranspose2d(hid_channels, n_chan, kernel_size, **cnn_kwargs)

    def forward(self, z):
        """Map latent samples ``z`` of shape (batch, latent_dim) to images
        of shape (batch, *img_size) with values in [0, 1] (sigmoid output).
        """
        batch_size = z.size(0)

        # Fully connected layers with ReLu activations
        x = torch.relu(self.lin1(z))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = x.view(batch_size, *self.reshape)

        # Convolutional layers with ReLu activations
        if self.img_size[1] == self.img_size[2] == 64:
            x = torch.relu(self.convT_64(x))
        x = torch.relu(self.convT1(x))
        x = torch.relu(self.convT2(x))
        # Sigmoid activation for final conv layer
        x = torch.sigmoid(self.convT3(x))

        return x