---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
datasets:
- zalando-datasets/fashion_mnist
metrics:
- accuracy
library_name: pytorch
pipeline_tag: image-classification
---
# mlp-fashion-mnist
A multi-layer perceptron (MLP) trained on the Fashion-MNIST dataset.
It is a PyTorch adaptation of the TensorFlow model in Chapter 10 of Aurélien Géron's book *Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow*.
Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/mlp_fashion_mnist.ipynb
Experiment tracking: https://wandb.ai/sadhaklal/mlp-fashion-mnist
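For a sense of scale, the size of this network follows directly from the layer widths in the `MLP` class under Usage (784 → 300 → 100 → 10). Each `nn.Linear` layer holds an `n_out × n_in` weight matrix plus `n_out` biases. The sketch below does only that arithmetic in plain Python. The helper names are illustrative and are not part of the model's code.

```python
# Layer widths taken from the MLP class in the Usage section:
# 28*28 flattened pixels -> 300 -> 100 -> 10 class logits.
LAYER_SIZES = [28 * 28, 300, 100, 10]


def dense_param_count(n_in: int, n_out: int) -> int:
    """Parameters of one fully connected layer: n_out x n_in weights plus n_out biases."""
    return n_in * n_out + n_out


def mlp_param_count(sizes: list[int]) -> list[int]:
    """Per-layer parameter counts for a stack of consecutive fully connected layers."""
    return [dense_param_count(a, b) for a, b in zip(sizes, sizes[1:])]


per_layer = mlp_param_count(LAYER_SIZES)
total = sum(per_layer)
print(per_layer)  # [235500, 30100, 1010]
print(total)      # 266610
```

Nearly 90% of the weights sit in the first layer, because every one of the 784 pixels connects to each of the 300 hidden units.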
## Usage
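```python
# Notebook syntax (Jupyter/Colab); from a shell, run `pip install -q datasets` instead.
# torch, torchvision, huggingface_hub and matplotlib are assumed to be installed already.
!pip install -q datasets

from datasets import load_dataset

# Load Fashion-MNIST and build an index -> class-name mapping from the label feature.
fashion_mnist = load_dataset("zalando-datasets/fashion_mnist")
features = fashion_mnist['train'].features
id2label = {i: label for i, label in enumerate(features['label'].names)}

import torch
import torchvision.transforms.v2 as v2

# Convert a PIL image to a float32 tensor of shape (1, 28, 28) with values scaled to [0, 1].
tfms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

device = torch.device("cpu")

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

# The architecture must match the one the weights were trained with:
# 784 inputs -> 300 -> 100 -> 10 logits, with ReLU between layers.
class MLP(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each 28x28 image into a 784-vector
        act = torch.relu(self.fc1(x))
        act = torch.relu(self.fc2(act))
        return self.fc3(act)  # raw logits; softmax is applied at inference time below

# Download the trained weights from the Hub.
model = MLP.from_pretrained("sadhaklal/mlp-fashion-mnist")
model.to(device)

example = fashion_mnist['test'][0]

import matplotlib.pyplot as plt

plt.imshow(example['image'], cmap='gray')
plt.show()
print(f"Ground truth: {id2label[example['label']]}")

img = tfms(example['image'])
x_batch = img.unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)

model.eval()
x_batch = x_batch.to(device)
with torch.no_grad():
    logits = model(x_batch)
    proba = torch.softmax(logits, dim=-1)

confidence, pred = proba.max(dim=-1)
print(f"Predicted class: {id2label[pred[0].item()]}")
print(f"Predicted confidence: {round(confidence[0].item(), 4)}")
```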
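The last step of the usage code turns the model's 10 logits into probabilities with `torch.softmax` and then takes the maximum. The maximum gives both the confidence and the predicted class index. The following dependency-free sketch shows the same computation for a single example. It assumes the logits are a plain Python list, and the logit values in it are made up. It uses the standard max-subtraction trick so that large logits do not overflow `math.exp`.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def predict(logits):
    """Return (confidence, class index), mirroring proba.max(dim=-1) in the usage code."""
    proba = softmax(logits)
    pred = max(range(len(proba)), key=proba.__getitem__)
    return proba[pred], pred


# Hypothetical logits for one image (10 classes); real values come from model(x_batch).
logits = [-2.1, -3.0, 0.4, -1.2, 1.5, -4.0, 0.9, -2.5, 3.2, -0.7]
confidence, pred = predict(logits)
print(pred, round(confidence, 4))

# Softmax is monotonic, so the predicted class equals the argmax of the raw logits.
assert pred == max(range(len(logits)), key=logits.__getitem__)
```

Because of that monotonicity, softmax only matters when you want a calibrated-looking confidence. For the class label alone, `logits.argmax(dim=-1)` is enough.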
## Metric
Accuracy on the test set: 0.8829
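The card does not say exactly how this figure was computed. Presumably it is plain top-1 accuracy: the fraction of Fashion-MNIST's 10,000 test images whose argmax prediction matches the label. On that reading, 0.8829 corresponds to 8,829 correct predictions. Below is a minimal sketch of the metric, plus a per-class breakdown. It assumes predictions and labels are lists of class indices, which you could collect by running the usage code's forward pass over `fashion_mnist['test']`. The toy values are illustrative, not real model output.

```python
from collections import Counter


def accuracy(preds, labels):
    """Top-1 accuracy: fraction of examples whose predicted index equals the label."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    if not labels:
        raise ValueError("need at least one example")
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


def per_class_accuracy(preds, labels):
    """Accuracy restricted to the examples of each true class, keyed by class index."""
    totals = Counter(labels)
    hits = Counter(y for p, y in zip(preds, labels) if p == y)
    return {c: hits[c] / totals[c] for c in sorted(totals)}


# Toy example with made-up predictions.
preds = [0, 1, 2, 2, 9, 9]
labels = [0, 1, 1, 2, 9, 7]
print(accuracy(preds, labels))
print(per_class_accuracy(preds, labels))
```

Class indices can be mapped back to names with the `id2label` dictionary from the usage code.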
This model has been pushed to the Hub using the PyTorchModelHubMixin integration.