ImageGenerationVAE: Variational Autoencoder for MNIST Image Generation

Model Details

  • Model Architecture: Variational Autoencoder (VAE)
  • Framework: PyTorch
  • Input Shape: (1, 28, 28) (Grayscale MNIST Images)
  • Latent Dimension: 200 (passed as hidden_dim; the constructor default in the code below is 3000)
  • Dataset: MNIST Handwritten Digits

Model Description

The ImageGenerationVAE model is a variational autoencoder (VAE) trained on the MNIST dataset to generate handwritten digit images. A standard autoencoder maps each image to a single latent point. This model instead learns a probabilistic latent representation (a Gaussian over the latent space for each input), so it can generate new, diverse digits by sampling latent vectors and decoding them.

This model can be used for:

  • Image generation
  • Feature learning
  • Latent space interpolation
  • Anomaly detection
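
Image generation and latent-space interpolation only need the decoder. You draw latent vectors from the standard normal prior, or blend two latent codes, and then decode them. The following is a minimal sketch under stated assumptions, not the author's code. To stay self-contained, it rebuilds a stand-in decoder with random weights and the same layers as the decoder in the Model Architecture code. The helper names make_decoder, sample_images and interpolate_latents, and the latent size of 8, are hypothetical. With the published model you would pass model.decoder and its hidden_dim instead.

```python
import torch
import torch.nn as nn


def make_decoder(latent_dim: int) -> nn.Sequential:
    # Hypothetical stand-in with the same layer stack as ImageGenerationVAE.decoder
    return nn.Sequential(
        nn.Linear(latent_dim, 32 * 7 * 7),
        nn.ReLU(),
        nn.Unflatten(1, (32, 7, 7)),
        nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2),
        nn.ReLU(),
        nn.BatchNorm2d(64),
        nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
        nn.Sigmoid(),
    )


@torch.no_grad()
def sample_images(decoder: nn.Module, n: int, latent_dim: int) -> torch.Tensor:
    # Draw z ~ N(0, I) from the prior and decode to (n, 1, 28, 28) images in [0, 1]
    z = torch.randn(n, latent_dim)
    return decoder(z)


@torch.no_grad()
def interpolate_latents(decoder: nn.Module, z_start: torch.Tensor,
                        z_end: torch.Tensor, steps: int) -> torch.Tensor:
    # Linearly blend two latent codes (weights 0 .. 1) and decode every intermediate point
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    z = z_start.unsqueeze(0) + weights * (z_end - z_start).unsqueeze(0)  # (steps, latent_dim)
    return decoder(z)


if __name__ == "__main__":
    latent_dim = 8  # hypothetical; use the model's hidden_dim in practice
    decoder = make_decoder(latent_dim).eval()  # eval(): BatchNorm uses running statistics
    samples = sample_images(decoder, n=16, latent_dim=latent_dim)
    path = interpolate_latents(decoder, torch.randn(latent_dim), torch.randn(latent_dim), steps=10)
    print(samples.shape, path.shape)
```

Calling eval() matters here. It switches BatchNorm2d to its running statistics, so decoding a single latent vector behaves the same as decoding it inside a batch.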

Training Details

  • Loss Function: Smooth L1 Loss + KL Divergence
  • Optimizer: Adam
  • Batch Size: 512
  • Number of Epochs: TBD
  • Regularization: Batch Normalization
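
The training recipe above can be sketched as a single optimization step. It combines the Smooth L1 reconstruction loss with the KL term that forward stores in self.kl, and optimizes with Adam on batches of 512. The card does not state how the Smooth L1 term is reduced, how the KL term is weighted, or which learning rate was used. The per-sample sum, beta=1.0 and lr=1e-3 below are therefore assumptions, and vae_loss and train_step are hypothetical helper names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vae_loss(recon: torch.Tensor, target: torch.Tensor, kl: torch.Tensor,
             beta: float = 1.0) -> torch.Tensor:
    # Smooth L1 summed over pixels and averaged over the batch (assumed reduction),
    # plus the batch-averaged KL term weighted by beta (assumed to be 1.0)
    recon_term = F.smooth_l1_loss(recon, target, reduction="sum") / target.size(0)
    return recon_term + beta * kl


def train_step(model: nn.Module, images: torch.Tensor,
               optimizer: torch.optim.Optimizer) -> float:
    # One optimizer update on a batch of (N, 1, 28, 28) images scaled to [0, 1]
    model.train()
    optimizer.zero_grad()
    recon = model(images)  # forward() also sets model.kl
    loss = vae_loss(recon, images, model.kl)
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage (assumes ImageGenerationVAE and an MNIST DataLoader with batch_size=512):
# model = ImageGenerationVAE(hidden_dim=200)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# for images, _ in loader:
#     train_step(model, images, optimizer)
```

Dividing the summed reconstruction term by the batch size keeps it on the same per-sample scale as the KL value, which forward already averages over the batch. With a plain mean reduction, the reconstruction term would shrink by a factor of 784 (28 × 28 pixels) relative to the KL term.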

Model Architecture

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ImageGenerationVAE(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_dim=3000):
        # hidden_dim is the size of the latent vector z
        super().__init__()

        # Encoder: (N, 1, 28, 28) -> (N, 64, 14, 14) -> (N, 32, 7, 7) -> (N, 1568)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Flatten(),
        )

        # Heads for the approximate posterior q(z | x):
        # mu_nn predicts the mean, sigma_nn predicts the log-variance log(sigma^2)
        self.mu_nn = nn.Linear(32 * 7 * 7, hidden_dim)
        self.sigma_nn = nn.Linear(32 * 7 * 7, hidden_dim)

        # Decoder: (N, hidden_dim) -> (N, 32, 7, 7) -> (N, 64, 14, 14) -> (N, 1, 28, 28)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 32 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

        # KL term from the most recent forward pass; add it to the loss during training
        self.kl = 0

    def forward(self, x):
        # Encoding head
        h = self.encoder(x)
        mu = self.mu_nn(h)
        log_var = self.sigma_nn(h)

        # Reparameterization trick: z = mu + sigma * eps, with sigma = exp(0.5 * log_var)
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon

        # Decoding head: pixel values in [0, 1] via Sigmoid
        output = self.decoder(z)

        # KL(q(z | x) || N(0, I)), summed over latent dimensions and averaged over the batch
        self.kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)

        return output
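
For reference, the sampling step and the KL term computed in forward can be written out. Here d is the latent dimension (hidden_dim) and N is the batch size. Because sigma_nn outputs the log-variance, the standard deviation is exp(0.5 · log σ²); exponentiating the full log-variance would give σ², not σ. The last line traces how the two 2×2 max-pools shrink 28×28 inputs to the 32·7·7 = 1568 features fed to mu_nn and sigma_nn.

```latex
\begin{aligned}
\sigma &= \exp\!\left(\tfrac{1}{2}\log\sigma^2\right), \qquad
  z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \\
\mathrm{KL} &= \frac{1}{N}\sum_{n=1}^{N} \left( -\frac{1}{2}\sum_{j=1}^{d}
  \left(1 + \log\sigma_{nj}^2 - \mu_{nj}^2 - \sigma_{nj}^2\right) \right) \\
28 \times 28 &\xrightarrow{\text{MaxPool}} 14 \times 14
  \xrightarrow{\text{MaxPool}} 7 \times 7, \qquad 32 \cdot 7 \cdot 7 = 1568
\end{aligned}
```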

This model has been pushed to the Hub using the PyTorchModelHubMixin integration:

  • Library: [More Information Needed]
  • Docs: [More Information Needed]

Weights

  • Format: Safetensors
  • Model Size: 39.2k params
  • Tensor Type: F32