Model Card: Time-Conditioned U-Net for MNIST

Model Details

  • Architecture: Time-Conditioned U-Net
  • Dataset: MNIST
  • Batch Size: 256
  • Image Size: 28x28
  • Loss Function: Mean Squared Error (MSE)
  • Optimizer: Adam (learning rate = 1e-4)

Model Architecture

This model is a U-Net that is conditioned on the timestep through a sinusoidal embedding refined by an MLP. The architecture is designed for small grayscale images (e.g., MNIST) and consists of:

Encoder (Contracting Path):

  • Three DoubleConv blocks with 32, 64, and 128 output channels, respectively.
  • The time embedding is passed to every DoubleConv block.
  • 2x2 max pooling (stride 2) between blocks halves the spatial dimensions.

Decoder (Expanding Path):

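  • Upsampling via bilinear interpolation (scale factor 2).
  • Skip connections concatenate each encoder output with the upsampled decoder features at the same resolution.
  • Two DoubleConv blocks that take 128+64 and 64+32 input channels and produce 64 and 32 output channels, respectively.
  • Final 1x1 convolution to map the 32 feature channels to the output channels.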
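
The encoder and decoder paths above can be traced with a little arithmetic. The sketch below is pure Python and only tracks channel counts and spatial sizes; `trace_shapes` is a hypothetical helper, and it assumes that each DoubleConv block preserves the spatial size (for example 3x3 convolutions with padding 1), which the card does not state.

```python
def trace_shapes(size=28, out_channels=1):
    """Return (stage, channels, side_length) tuples for the described U-Net.

    Assumption: DoubleConv keeps height and width unchanged, so only
    max pooling and upsampling change the spatial size.
    """
    stages = []
    s = size
    stages.append(("down_conv1", 32, s))
    s //= 2  # max pooling
    stages.append(("down_conv2", 64, s))
    s //= 2  # max pooling
    stages.append(("down_conv3", 128, s))
    s *= 2  # bilinear upsampling
    stages.append(("concat_x2", 128 + 64, s))
    stages.append(("up_conv2", 64, s))
    s *= 2  # bilinear upsampling
    stages.append(("concat_x1", 64 + 32, s))
    stages.append(("up_conv1", 32, s))
    stages.append(("final_conv", out_channels, s))
    return stages


for name, channels, side in trace_shapes():
    print(f"{name:12s} channels={channels:3d} size={side}x{side}")
```

For a 28x28 input the bottleneck is 7x7, and two upsampling steps restore 28x28, so the skip connections line up without cropping or padding.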

Time Embedding:

  • Uses a sinusoidal positional encoding to represent timestep information.
  • An MLP refines the embedding before passing it to convolutional layers.
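
Written out, the sinusoidal encoding computed by the TimeEmbedding implementation in this card is as follows, where d is `embedding_dim` (32 by default) and t is the timestep. The symbols ω_i and e(t) are notation introduced here for illustration only.

```latex
\begin{align}
\omega_i &= \exp\!\left(-\frac{i \,\ln 10000}{d/2 - 1}\right),
  \qquad i = 0, \dots, d/2 - 1 \\
e(t) &= \bigl[\sin(t\,\omega_0), \dots, \sin(t\,\omega_{d/2-1}),\;
  \cos(t\,\omega_0), \dots, \cos(t\,\omega_{d/2-1})\bigr]
\end{align}
```

The frequencies fall geometrically from 1 to 1/10000, and the d-dimensional vector e(t) is what the MLP receives.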

Implementation

U-Net

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class UNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels=1, out_channels=1, time_embedding_dim=32):
        super().__init__()

        # Time embedding layer
        self.time_embedding = TimeEmbedding(time_embedding_dim)

        # Encoder
        self.down_conv1 = DoubleConv(in_channels, 32, time_embedding_dim)
        self.down_conv2 = DoubleConv(32, 64, time_embedding_dim)
        self.down_conv3 = DoubleConv(64, 128, time_embedding_dim)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        # Decoder
        self.up_conv2 = DoubleConv(128 + 64, 64, time_embedding_dim)
        self.up_conv1 = DoubleConv(64 + 32, 32, time_embedding_dim)
        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)

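    def forward(self, x, timesteps):
        # Embed the timesteps once and share the result across all blocks
        t = self.time_embedding(timesteps)

        # Encoder: max pooling halves the resolution between blocks
        x1 = self.down_conv1(x, t)
        x2 = self.down_conv2(self.maxpool(x1), t)
        x3 = self.down_conv3(self.maxpool(x2), t)

        # Decoder: upsample, concatenate the matching encoder features, convolve
        x = self.upsample(x3)
        x = torch.cat([x2, x], dim=1)
        x = self.up_conv2(x, t)

        x = self.upsample(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_conv1(x, t)

        # 1x1 convolution maps the 32 feature channels to out_channels
        return self.final_conv(x)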
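
The UNet class relies on a DoubleConv block whose definition is not included in this card. The following is a minimal sketch of what such a block might look like, not the author's implementation: it assumes two 3x3 convolutions with padding 1, batch normalization, ReLU activations, and a linear projection (`time_proj`, a hypothetical name) that adds the time embedding as a per-channel bias after the first convolution.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Hypothetical time-conditioned double convolution block (an assumption)."""

    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        # padding=1 keeps height and width unchanged for 3x3 kernels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()
        # Projects the time embedding to one value per output channel
        self.time_proj = nn.Linear(time_embedding_dim, out_channels)

    def forward(self, x, t):
        h = self.act(self.norm1(self.conv1(x)))
        # Broadcast the (batch, channels) projection over height and width
        h = h + self.time_proj(t)[:, :, None, None]
        return self.act(self.norm2(self.conv2(h)))


block = DoubleConv(in_channels=1, out_channels=32, time_embedding_dim=32)
out = block(torch.randn(4, 1, 28, 28), torch.randn(4, 32))
print(out.shape)
```

Any block with the same constructor arguments and a `forward(x, t)` signature that preserves the spatial size would fit the UNet code above equally well.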

Time Embedding

import math

class TimeEmbedding(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # Stored because forward() needs it to size the sinusoidal encoding
        self.embedding_dim = embedding_dim
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, t):
        # t: tensor of shape (batch,) holding one timestep per sample
        half_dim = self.embedding_dim // 2
        # Frequencies fall geometrically from 1 to 1/10000
        scale = math.log(10000.0) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        return self.mlp(embeddings)

Training Configuration

  • Batch Size: 256
  • Image Size: 28x28
  • Loss Function: Mean Squared Error (MSE)
  • Optimizer: Adam (learning rate = 1e-4)
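
As an illustration of the configuration above, here is a rough sketch of a single optimization step with MSE loss and Adam at a learning rate of 1e-4. The card does not say what the regression target is or how timesteps are sampled, so `target`, the timestep range of 1000, and the `TinyStandIn` model (a placeholder that only mimics the `forward(x, timesteps)` signature) are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyStandIn(nn.Module):
    """Hypothetical placeholder with the same call signature as the U-Net."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, timesteps):
        # The stand-in ignores timesteps; the real model conditions on them
        return self.conv(x)


def train_step(model, optimizer, x, timesteps, target):
    """Run one MSE optimization step and return the scalar loss."""
    optimizer.zero_grad()
    loss = F.mse_loss(model(x, timesteps), target)
    loss.backward()
    optimizer.step()
    return loss.item()


model = TinyStandIn()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(256, 1, 28, 28)             # one batch of 28x28 grayscale images
timesteps = torch.randint(0, 1000, (256,))  # assumed timestep range
target = torch.randn_like(x)                # assumed regression target
loss_value = train_step(model, optimizer, x, timesteps, target)
print(loss_value)
```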

This model has been pushed to the Hub using the PyTorchModelHubMixin integration:

  • Library: [More Information Needed]
  • Docs: [More Information Needed]

Model size: 485k params, stored as Safetensors with tensor type F32.