ImageGenerationVAE: Variational Autoencoder for MNIST Image Generation
Model Details
- Model Architecture: Variational Autoencoder (VAE)
- Framework: PyTorch
- Input Shape: (1, 28, 28) (Grayscale MNIST Images)
- Latent Dimension: 200
- Dataset: MNIST Handwritten Digits
Model Description
The ImageGenerationVAE model is a variational autoencoder (VAE) trained to generate handwritten digit images from the MNIST dataset. Unlike a standard autoencoder, this model learns a probabilistic latent representation, allowing it to generate diverse samples from the learned distribution.
This model can be used for:
- Image generation
- Feature learning
- Latent space interpolation
- Anomaly detection
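As a rough illustration of latent space interpolation, the sketch below builds evenly spaced points on the straight line between two latent vectors. The helper name `interpolate_latents` is hypothetical and not part of this repository, and the latent size of 200 is taken from Model Details; use whatever `hidden_dim` your checkpoint was built with.

```python
import torch


def interpolate_latents(z_start, z_end, steps=8):
    """Return `steps` latent vectors evenly spaced between z_start and z_end."""
    # Weights run from 0 to 1; unsqueeze so they broadcast over the latent axis.
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    return (1.0 - weights) * z_start + weights * z_end  # (steps, latent_dim)


latent_dim = 200  # assumption: value listed under Model Details
z_a = torch.randn(latent_dim)
z_b = torch.randn(latent_dim)
path = interpolate_latents(z_a, z_b, steps=8)
print(path.shape)  # torch.Size([8, 200])
```

Each row of `path` could then be passed through the model's `decoder` (in `eval()` mode, since the decoder contains batch normalization) to render the intermediate digits.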
Training Details
- Loss Function: Smooth L1 Loss + KL Divergence
- Optimizer: Adam
- Batch Size: 512
- Number of Epochs: TBD
- Regularization: Batch Normalization
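The combined objective listed above might be computed along the following lines. This is a sketch under assumptions: the function name `vae_loss`, the `beta` weight on the KL term, and the choice to sum the reconstruction error and divide by batch size are all illustrative, since the model card does not state how the two terms are reduced or weighted.

```python
import torch
import torch.nn.functional as F


def vae_loss(recon, target, mu, log_var, beta=1.0):
    """Smooth L1 reconstruction term plus a beta-weighted KL term (hypothetical helper)."""
    batch_size = target.size(0)
    # Reconstruction error, summed over pixels and averaged over the batch.
    recon_term = F.smooth_l1_loss(recon, target, reduction="sum") / batch_size
    # Closed-form KL between N(mu, exp(log_var)) and N(0, I), averaged over the batch.
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
    return recon_term + beta * kl_term


target = torch.rand(4, 1, 28, 28)
mu = torch.zeros(4, 200)
log_var = torch.zeros(4, 200)
# A perfect reconstruction with a standard-normal posterior gives zero loss.
loss = vae_loss(target.clone(), target, mu, log_var)
print(loss.item())
```

The KL expression matches the one the model stores in `self.kl` during `forward`, so in practice that attribute can be added to the reconstruction term directly.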
Model Architecture
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class ImageGenerationVAE(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_dim=3000):
        super(ImageGenerationVAE, self).__init__()
        # Encoder: (1, 28, 28) -> (64, 14, 14) -> (32, 7, 7) -> flat vector
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Flatten(),
        )
        # Two linear heads: mean and log-variance of the latent distribution
        self.mu_nn = nn.Linear(32 * 7 * 7, hidden_dim)
        self.sigma_nn = nn.Linear(32 * 7 * 7, hidden_dim)
        # Decoder: latent vector -> (32, 7, 7) -> (64, 14, 14) -> (1, 28, 28)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 32 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )
        self.kl = 0

    def forward(self, x):
        # encoding head
        x = self.encoder(x)
        mu = self.mu_nn(x)
        log_var = self.sigma_nn(x)
        # std = sqrt(var) = exp(log_var / 2)
        std = torch.exp(0.5 * log_var)
        # reparameterization trick: z = mu + std * epsilon, epsilon ~ N(0, I)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        # decoding head
        output = self.decoder(z)
        # KL divergence between N(mu, var) and N(0, I)
        self.kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        self.kl = self.kl / x.size(0)  # normalize by batch size
        return output
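To see why both linear heads take `32 * 7 * 7` inputs and why the decoder returns a 28x28 image, the spatial sizes can be traced with the standard output-size formulas for convolution, pooling, and transposed convolution (dilation 1, no output padding). The helper names below are hypothetical and exist only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution along one spatial axis."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder: conv(3, pad 1) keeps the size, each 2x2 max pool halves it.
size = 28
size = conv_out(size, kernel=3, stride=1, padding=1)  # 28
size = conv_out(size, kernel=2, stride=2)             # 14
size = conv_out(size, kernel=3, stride=1, padding=1)  # 14
size = conv_out(size, kernel=2, stride=2)             # 7
flat_features = 32 * size * size
print(size, flat_features)  # 7 1568

# Decoder: each transposed conv with kernel 2, stride 2 doubles the size.
out = conv_transpose_out(size, kernel=2, stride=2)    # 14
out = conv_transpose_out(out, kernel=2, stride=2)     # 28
print(out)  # 28
```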
This model has been pushed to the Hub using the PyTorchModelHubMixin integration:
- Library: [More Information Needed]
- Docs: [More Information Needed]