"""
Module containing the main VAE class.
"""
import torch
from torch import nn, optim
from torch.nn import functional as F
from disvae.utils.initialization import weights_init
from .encoders import get_encoder
from .decoders import get_decoder

MODELS = ["Burgess"]


def init_specific_model(model_type, img_size, latent_dim):
"""Return an instance of a VAE with encoder and decoder from `model_type`."""
model_type = model_type.lower().capitalize()
if model_type not in MODELS:
        err = "Unknown model_type={}. Possible values: {}"
raise ValueError(err.format(model_type, MODELS))
encoder = get_encoder(model_type)
decoder = get_decoder(model_type)
model = VAE(img_size, encoder, decoder, latent_dim)
model.model_type = model_type # store to help reloading
return model
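

# Example usage (illustrative sketch; assumes the "Burgess" encoder/decoder
# registered in `.encoders`/`.decoders` accept (3, 64, 64) images, and that
# latent_dim=10 is just a representative choice):
#
#     model = init_specific_model("burgess", img_size=(3, 64, 64), latent_dim=10)
#     reconstruct, latent_dist, latent_sample = model(torch.randn(16, 3, 64, 64))

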
class VAE(nn.Module):
def __init__(self, img_size, encoder, decoder, latent_dim):
"""
Class which defines model and forward pass.
Parameters
----------
img_size : tuple of ints
Size of images. E.g. (1, 32, 32) or (3, 64, 64).
"""
super(VAE, self).__init__()
if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError("{} sized images not supported. Only (None, 32, 32) and "
                               "(None, 64, 64) supported. Build your own architecture or "
                               "reshape images!".format(img_size))
self.latent_dim = latent_dim
self.img_size = img_size
self.num_pixels = self.img_size[1] * self.img_size[2]
self.encoder = encoder(img_size, self.latent_dim)
self.decoder = decoder(img_size, self.latent_dim)
        self.reset_parameters()

    def reparameterize(self, mean, logvar):
"""
Samples from a normal distribution using the reparameterization trick.
Parameters
----------
mean : torch.Tensor
Mean of the normal distribution. Shape (batch_size, latent_dim)
logvar : torch.Tensor
Diagonal log variance of the normal distribution. Shape (batch_size,
latent_dim)
"""
        if self.training:
            # z = mean + std * eps, with std = exp(0.5 * logvar) and eps ~ N(0, I)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode: return the mean of the distribution (deterministic)
            return mean
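
    # Worked example for `reparameterize` (illustrative sketch): a latent unit
    # with logvar = log(0.25) has std = exp(0.5 * log(0.25)) = 0.5, so during
    # training its samples are mean + 0.5 * eps with eps ~ N(0, 1); in eval mode
    # the mean is returned unchanged.
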
def forward(self, x):
"""
Forward pass of model.
Parameters
----------
x : torch.Tensor
Batch of data. Shape (batch_size, n_chan, height, width)
"""
latent_dist = self.encoder(x)
latent_sample = self.reparameterize(*latent_dist)
reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def reset_parameters(self):
        self.apply(weights_init)

    def sample_latent(self, x):
"""
Returns a sample from the latent distribution.
Parameters
----------
x : torch.Tensor
Batch of data. Shape (batch_size, n_chan, height, width)
"""
latent_dist = self.encoder(x)
latent_sample = self.reparameterize(*latent_dist)
return latent_sample
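

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original API):
    # assumes the "Burgess" architecture returned by get_encoder/get_decoder
    # accepts 3x64x64 RGB images; latent_dim=10 is an arbitrary example value.
    model = init_specific_model("burgess", img_size=(3, 64, 64), latent_dim=10)
    model.eval()  # eval mode: reparameterize deterministically returns the mean
    dummy_batch = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        reconstruct, (mean, logvar), latent_sample = model(dummy_batch)
    print("reconstruction:", tuple(reconstruct.shape))
    print("latent mean / logvar:", tuple(mean.shape), tuple(logvar.shape))
    print("latent sample:", tuple(latent_sample.shape))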