{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cifar10-title",
   "metadata": {},
   "source": [
    "# CIFAR-10 training data pipeline\n",
    "\n",
    "Builds a `DataLoader` over the CIFAR-10 training split with standard augmentation (random horizontal flip, 4-pixel padded random crop) and per-channel normalization, then displays one augmented sample as a sanity check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cifar10-imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cifar10-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = './data/'\n",
    "SEED = 0\n",
    "batch_size = 192\n",
    "\n",
    "# Per-channel (R, G, B) constants passed to transforms.Normalize and reused below to undo it.\n",
    "# NOTE(review): these std values are the widely copied ones; the per-channel std measured\n",
    "# over the CIFAR-10 training set is closer to (0.2470, 0.2435, 0.2616). Confirm which is intended.\n",
    "CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR10_STD = (0.2023, 0.1994, 0.2010)\n",
    "\n",
    "# Seeds the global torch RNG, which drives both the loader's shuffling and the random\n",
    "# flip/crop augmentations (the DataLoader below uses no worker processes).\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cifar10-load-header",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1efa9df0-5f50-415c-b574-fae1236cb2b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transform = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], layout (C, H, W)\n",
    "    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),\n",
    "])\n",
    "\n",
    "dataset = CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)\n",
    "train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cifar10-preview-header",
   "metadata": {},
   "source": [
    "## Sanity check: one augmented sample\n",
    "\n",
    "Batches from the loader are normalized, so their values fall outside the `[0, 1]` range `imshow` expects for float RGB data. Plotting them directly makes matplotlib clip the values and distorts the colors, so the normalization is inverted before plotting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9254b0b-0b70-4d87-b9eb-62a37212ba5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "images, labels = next(iter(train_loader))\n",
    "first_image = images[0]\n",
    "assert first_image.shape == (3, 32, 32), first_image.shape\n",
    "\n",
    "# Undo Normalize (x = x_norm * std + mean), broadcasting the constants over (C, H, W),\n",
    "# and clamp away floating-point overshoot so imshow does not clip or warn.\n",
    "mean = torch.tensor(CIFAR10_MEAN).view(3, 1, 1)\n",
    "std = torch.tensor(CIFAR10_STD).view(3, 1, 1)\n",
    "display_image = (first_image * std + mean).clamp(0, 1)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(display_image.permute(1, 2, 0), interpolation='nearest')  # (C, H, W) -> (H, W, C)\n",
    "ax.set_title(dataset.classes[labels[0].item()])\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}