"""
Module containing the main VAE class.
"""
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F | |
from disvae.utils.initialization import weights_init | |
from .encoders import get_encoder | |
from .decoders import get_decoder | |
MODELS = ["Burgess"]


def init_specific_model(model_type, img_size, latent_dim):
    """Return an instance of a VAE with encoder and decoder from `model_type`.

    Parameters
    ----------
    model_type : str
        Name of the architecture; case-insensitive. Must normalize to one
        of `MODELS`.
    img_size : tuple of ints
        Size of images. E.g. (1, 32, 32) or (3, 64, 64).
    latent_dim : int
        Dimensionality of the latent space.

    Returns
    -------
    VAE
        Instantiated model. The normalized `model_type` is stored on the
        instance to help reloading.

    Raises
    ------
    ValueError
        If `model_type` does not name a known architecture.
    """
    # Normalize case so e.g. "burgess" / "BURGESS" both resolve to "Burgess".
    model_type = model_type.lower().capitalize()
    if model_type not in MODELS:
        err = "Unknown model_type={}. Possible values: {}"
        raise ValueError(err.format(model_type, MODELS))
    encoder = get_encoder(model_type)
    decoder = get_decoder(model_type)
    model = VAE(img_size, encoder, decoder, latent_dim)
    model.model_type = model_type  # store to help reloading
    return model
class VAE(nn.Module):
    """Variational autoencoder: encoder -> reparameterized sample -> decoder."""

    def __init__(self, img_size, encoder, decoder, latent_dim):
        """
        Class which defines model and forward pass.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        encoder : callable
            Encoder class, instantiated as `encoder(img_size, latent_dim)`.
        decoder : callable
            Decoder class, instantiated as `decoder(img_size, latent_dim)`.
        latent_dim : int
            Dimensionality of the latent space.
        """
        super().__init__()

        # Only 32x32 and 64x64 spatial sizes are supported by the architectures.
        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError("{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(img_size))

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim)
        self.decoder = decoder(img_size, self.latent_dim)

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)
        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape
            (batch_size, latent_dim)
        """
        if not self.training:
            # Evaluation / reconstruction mode: deterministic, use the mean.
            return mean
        # Training: draw noise ~ N(0, I), then shift/scale so gradients flow
        # through mean and logvar (the reparameterization trick).
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mean + noise * std

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        dist_params = self.encoder(x)
        sample = self.reparameterize(*dist_params)
        return self.decoder(sample), dist_params, sample

    def reset_parameters(self):
        """Re-initialize every submodule's weights via `weights_init`."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Returns a sample from the latent distribution.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        return self.reparameterize(*self.encoder(x))