---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
datasets:
- zalando-datasets/fashion_mnist
metrics:
- accuracy
library_name: pytorch
pipeline_tag: image-classification
---

# mlp-fashion-mnist

A multi-layer perceptron (MLP) trained on the Fashion-MNIST dataset.

It is a PyTorch adaptation of the TensorFlow model in Chapter 10 of Aurelien Geron's book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow'.

Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/mlp_fashion_mnist.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/mlp-fashion-mnist

## Usage

```
!pip install -q datasets

from datasets import load_dataset

# Load Fashion-MNIST and build an index -> class-name mapping.
fashion_mnist = load_dataset("zalando-datasets/fashion_mnist")
features = fashion_mnist['train'].features
id2label = {i: label for i, label in enumerate(features['label'].names)}

import torch
import torchvision.transforms.v2 as v2

# Convert a PIL image to a float32 tensor of shape (1, 28, 28) with values in [0, 1].
tfms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

device = torch.device("cpu")

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MLP(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each image to a 784-dimensional vector
        act = torch.relu(self.fc1(x))
        act = torch.relu(self.fc2(act))
        return self.fc3(act)  # raw logits for the 10 classes

# Download the trained weights from the Hub.
model = MLP.from_pretrained("sadhaklal/mlp-fashion-mnist")
model.to(device)

example = fashion_mnist['test'][0]

import matplotlib.pyplot as plt
plt.imshow(example['image'], cmap='gray')
print(f"Ground truth: {id2label[example['label']]}")

img = tfms(example['image'])
x_batch = img.unsqueeze(0)  # add a batch dimension: (1, 1, 28, 28)

model.eval()
x_batch = x_batch.to(device)
with torch.no_grad():
    logits = model(x_batch)
    proba = torch.softmax(logits, dim=-1)

confidence, pred = proba.max(dim=-1)
print(f"Predicted class: {id2label[pred[0].item()]}")
print(f"Predicted confidence: {round(confidence[0].item(), 4)}")
```

## Metric

Accuracy on the test set: 0.8829

---

This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration.
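
As a quick sanity check on the `MLP` architecture in the Usage section (784 → 300 → 100 → 10), its parameter count can be worked out by hand: each `nn.Linear(in_features, out_features)` layer holds an `in_features × out_features` weight matrix plus `out_features` biases. The sketch below does that arithmetic in plain Python; the helper names `linear_params` and `mlp_params` are hypothetical and not part of the card's code.

```python
# Layer sizes of the card's MLP: 28*28 inputs -> 300 -> 100 -> 10 classes.
LAYER_SIZES = [28 * 28, 300, 100, 10]


def linear_params(in_features: int, out_features: int) -> int:
    """Parameters in one fully connected layer: weights plus one bias per output."""
    return in_features * out_features + out_features


def mlp_params(sizes: list[int]) -> list[int]:
    """Per-layer parameter counts for each consecutive (in, out) pair in `sizes`."""
    return [linear_params(a, b) for a, b in zip(sizes, sizes[1:])]


per_layer = mlp_params(LAYER_SIZES)
total = sum(per_layer)
print(per_layer)  # [235500, 30100, 1010]
print(total)      # 266610
```

With PyTorch available, `sum(p.numel() for p in model.parameters())` on the loaded model should give the same total, since the Hub mixin adds no parameters of its own.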
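
The accuracy reported under Metric is the fraction of test images whose highest-scoring class matches the label. Below is a minimal sketch of that computation in batches. It assumes the images have already been converted to a float tensor of shape `(N, 1, 28, 28)` with the same `tfms` as in the Usage section. The function name `batched_accuracy`, its `batch_size` default, and the synthetic stand-in model `FirstTenPixels` are illustrative assumptions. This is not the author's evaluation code, which lives in the linked notebook.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # gradients are not needed for evaluation
def batched_accuracy(model: nn.Module, images: torch.Tensor, labels: torch.Tensor,
                     batch_size: int = 256) -> float:
    """Fraction of samples whose argmax logit equals the label."""
    model.eval()  # standard for inference (this MLP has no dropout or batch norm)
    correct = 0
    for start in range(0, len(images), batch_size):
        x = images[start:start + batch_size]
        y = labels[start:start + batch_size]
        # softmax is monotonic, so the argmax of the logits gives the same prediction
        preds = model(x).argmax(dim=-1)
        correct += (preds == y).sum().item()
    return correct / len(images)


# Hypothetical stand-in model for a self-check on synthetic data (not Fashion-MNIST):
# its "logits" are just the first 10 pixels of each flattened image.
class FirstTenPixels(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(-1, 28 * 28)[:, :10]


# Image i has a single bright pixel at position i, so the stand-in predicts class i.
flat = torch.zeros(10, 28 * 28)
flat[torch.arange(10), torch.arange(10)] = 1.0
images = flat.view(10, 1, 28, 28)
labels = torch.arange(10)

half_wrong = labels.clone()
half_wrong[:5] = (half_wrong[:5] + 1) % 10  # corrupt the first five labels

print(batched_accuracy(FirstTenPixels(), images, labels, batch_size=4))      # 1.0
print(batched_accuracy(FirstTenPixels(), images, half_wrong, batch_size=4))  # 0.5
```

Applied to the loaded `model` and the full Fashion-MNIST test split (10,000 images), the result should be close to the reported figure, provided the preprocessing matches the one used in training.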
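
The Usage code relies on a few reshapes that are easy to miss. `tfms` turns a 28×28 grayscale PIL image into a `(1, 28, 28)` float tensor in [0, 1], `unsqueeze(0)` adds a batch dimension, and `MLP.forward` flattens each image with `view(-1, 28 * 28)`. The sketch below traces those shapes with plain PyTorch. For 8-bit input, it emulates `ToImage` followed by `ToDtype(torch.float32, scale=True)` by adding a channel axis and dividing by 255. That emulation and the helper name `to_model_input` are assumptions for illustration, not torchvision's implementation.

```python
import torch


def to_model_input(pixels_uint8: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for tfms plus unsqueeze(0) on one grayscale image.

    (H, W) uint8 -> (1, H, W) float32 in [0, 1] -> (1, 1, H, W) batch of one.
    """
    img = pixels_uint8.unsqueeze(0).to(torch.float32) / 255.0  # channel axis + scaling
    return img.unsqueeze(0)  # batch axis, as in x_batch = img.unsqueeze(0)


pixels = torch.randint(0, 256, (28, 28), dtype=torch.uint8)
x_batch = to_model_input(pixels)
flat = x_batch.view(-1, 28 * 28)  # the first step of MLP.forward

print(tuple(x_batch.shape))  # (1, 1, 28, 28)
print(tuple(flat.shape))     # (1, 784)
```

Because `view(-1, 28 * 28)` infers the batch size, the same `forward` also accepts larger batches. For example, an input of shape `(64, 1, 28, 28)` becomes `(64, 784)`.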