{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.14","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from datasets import load_dataset\n\n# MNIST from the Hugging Face Hub; the train split is used for training below\ndataset = load_dataset(\"mnist\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:22.773326Z","iopub.execute_input":"2024-11-26T03:35:22.773705Z","iopub.status.idle":"2024-11-26T03:35:33.281299Z","shell.execute_reply.started":"2024-11-26T03:35:22.773675Z","shell.execute_reply":"2024-11-26T03:35:33.280473Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/6.97k [00:00, 5)"},"metadata":{}}],"execution_count":5},{"cell_type":"code","source":"import matplotlib.pyplot as plt\n\n# Peek at the first training example (PIL image + integer label). Defined here so\n# this cell does not depend on a variable left over from another cell.\nsample = dataset[\"train\"][0][\"image\"]\nsample_label = dataset[\"train\"][0][\"label\"]\nprint(sample_label)\nplt.imshow(sample)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:37.280112Z","iopub.execute_input":"2024-11-26T03:35:37.280422Z","iopub.status.idle":"2024-11-26T03:35:37.521924Z","shell.execute_reply.started":"2024-11-26T03:35:37.280388Z","shell.execute_reply":"2024-11-26T03:35:37.521067Z"}},"outputs":[{"name":"stdout","text":"5\n","output_type":"stream"},{"execution_count":6,"output_type":"execute_result","data":{"text/plain":""},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"
","image/png":"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcTUlEQVR4nO3df3DU9b3v8dcCyQqaLI0hv0rAgD+wAvEWJWZAxJJLSOc4gIwHf3QGvF4cMXiKaPXGUZHWM2nxjrV6qd7TqURnxB+cEaiO5Y4GE441oQNKGW7blNBY4iEJFSe7IUgIyef+wXXrQgJ+1l3eSXg+Zr4zZPf75vvx69Znv9nNNwHnnBMAAOfYMOsFAADOTwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGG9gFP19vbq4MGDSktLUyAQsF4OAMCTc04dHR3Ky8vTsGH9X+cMuAAdPHhQ+fn51ssAAHxDzc3NGjt2bL/PD7gApaWlSZJm6vsaoRTj1QAAfJ1Qtz7QO9H/nvcnaQFat26dnnrqKbW2tqqwsFDPPfecpk+ffta5L7/tNkIpGhEgQAAw6Pz/O4ye7W2UpHwI4fXXX9eqVau0evVqffTRRyosLFRpaakOHTqUjMMBAAahpATo6aef1rJly3TnnXfqO9/5jl544QWNGjVKL774YjIOBwAYhBIeoOPHj2vXrl0qKSn5x0GGDVNJSYnq6upO27+rq0uRSCRmAwAMfQkP0Geffaaenh5lZ2fHPJ6dna3W1tbT9q+srFQoFIpufAIOAM4P5j+IWlFRoXA4HN2am5utlwQAOAcS/im4zMxMDR8+XG1tbTGPt7W1KScn57T9g8GggsFgopcBABjgEn4FlJqaqmnTpqm6ujr6WG9vr6qrq1VcXJzowwEABqmk/BzQqlWrtGTJEl1zzTWaPn26nnnmGXV2durOO+9MxuEAAINQUgK0ePFi/f3vf9fjjz+u1tZWXX311dq6detpH0wAAJy/As45Z72Ir4pEIgqFQpqt+dwJAQAGoROuWzXaonA4rPT09H73M/8UHADg/ESAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGG9AGAgCYzw/5/E8DGZSVhJYjQ8eElccz2jer1nxk885D0z6t6A90zr06neMx9d87r3jCR91tPpPVO08QHvmUtX1XvPDAVcAQEATBAgAICJhAfoiSeeUCAQiNkmTZqU6MMAAAa5pLwHdNVVV+m99977x0Hi+L46AGBoS0oZRowYoZycnGT81QCAISIp7wHt27dPeXl5mjBhgu644w4dOHCg3327uroUiURiNgDA0JfwABUVFamqqkpbt27V888/r6amJl1//fXq6Ojoc//KykqFQqHolp+fn+glAQAGoIQHqKysTLfccoumTp2q0tJSvfPOO2pvb9cbb7zR5/4VFRUKh8PRrbm5OdFLAgAMQEn/dMDo0aN1+eWXq7Gxsc/ng8GggsFgspcBABhgkv5zQEeOHNH+/fuVm5ub7EMBAAaRhAfowQcfVG1trT755BN9+OGHWrhwoYYPH67bbrst0YcCAAxiCf8W3KeffqrbbrtNhw8f1pgxYzRz5kzV19drzJgxiT4UAGAQS3iAXnvttUT/lRighl95mfeMC6Z4zxy8YbT3zBfX+d9EUpIyQv5z/1EY340uh5rfHk3znvnZ/5rnPbNjy
gbvmabuL7xnJOmnbf/VeybvP1xcxzofcS84AIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBE0n8hHQa+ntnfjWvu6ap13jOXp6TGdSycW92ux3vm8eeWes+M6PS/cWfxxhXeM2n/ecJ7RpKCn/nfxHTUzh1xHet8xBUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHA3bCjYcDCuuV3H8r1nLk9pi+tYQ80DLdd5z/z1SKb3TNXEf/eekaRwr/9dqrOf/TCuYw1k/mcBPrgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNS6ERLa1xzz/3sFu+Zf53X6T0zfM9F3jN/uPc575l4PfnZVO+ZxpJR3jM97S3eM7cX3+s9I0mf/Iv/TIH+ENexcP7iCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHNSBG3jPV13jNj3rrYe6bn8OfeM1dN/m/eM5L0f2e96D3zm3+7wXsmq/1D75l4BOriu0Fogf+/WsAbV0AAABMECABgwjtA27dv10033aS8vDwFAgFt3rw55nnnnB5//HHl5uZq5MiRKikp0b59+xK1XgDAEOEdoM7OThUWFmrdunV9Pr927Vo9++yzeuGFF7Rjxw5deOGFKi0t1bFjx77xYgEAQ4f3hxDKyspUVlbW53POOT3zzDN69NFHNX/+fEnSyy+/rOzsbG3evFm33nrrN1stAGDISOh7QE1NTWptbVVJSUn0sVAopKKiItXV9f2xmq6uLkUikZgNADD0JTRAra2tkqTs7OyYx7Ozs6PPnaqyslKhUCi65efnJ3JJAIAByvxTcBUVFQqHw9GtubnZekkAgHMgoQHKycmRJLW1tcU83tbWFn3uVMFgUOnp6TEbAGDoS2iACgoKlJOTo+rq6uhjkUhEO3bsUHFxcSIPBQAY5Lw/BXfkyBE1NjZGv25qatLu3buVkZGhcePGaeXKlXryySd12WWXqaCgQI899pjy8vK0YMGCRK4bADDIeQdo586duvHGG6Nfr1q1SpK0ZMkSVVVV6aGHHlJnZ6fuvvtutbe3a+bMmdq6dasuuOCCxK0aADDoBZxzznoRXxWJRBQKhTRb8zUikGK9HAxSf/nf18Y3908veM/c+bc53jN/n9nhPaPeHv8ZwMAJ160abVE4HD7j+/rmn4IDAJyfCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYML71zEAg8GVD/8lrrk7p/jf2Xr9+Oqz73SKG24p955Je73eewYYyLgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNSDEk97eG45g4vv9J75sBvvvCe+R9Pvuw9U/HPC71n3Mch7xlJyv/XOv8h5+I6Fs5fXAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSnwFb1/+JP3zK1rfuQ988rq/+k9s/s6/xuY6jr/EUm66sIV3jOX/arFe+bEXz/xnsHQwRUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGAi4Jxz1ov4qkgkolAopNmarxGBFOvlAEnhZlztPZP+00+9Z16d8H+8Z+I16f3/7j1zxZqw90zPvr96z+DcOuG6VaMtCofDSk9P73c/roAAACYIEADAhHeAtm/frptuukl5eXkKBALavHlzzPNLly5VIBCI2ebNm5eo9QIAhgjvAHV2d
qqwsFDr1q3rd5958+appaUlur366qvfaJEAgKHH+zeilpWVqays7Iz7BINB5eTkxL0oAMDQl5T3gGpqapSVlaUrrrhCy5cv1+HDh/vdt6urS5FIJGYDAAx9CQ/QvHnz9PLLL6u6ulo/+9nPVFtbq7KyMvX09PS5f2VlpUKhUHTLz89P9JIAAAOQ97fgzubWW2+N/nnKlCmaOnWqJk6cqJqaGs2ZM+e0/SsqKrRq1aro15FIhAgBwHkg6R/DnjBhgjIzM9XY2Njn88FgUOnp6TEbAGDoS3qAPv30Ux0+fFi5ubnJPhQAYBDx/hbckSNHYq5mmpqatHv3bmVkZCgjI0Nr1qzRokWLlJOTo/379+uhhx7SpZdeqtLS0oQuHAAwuHkHaOfOnbrxxhujX3/5/s2SJUv0/PPPa8+ePXrppZfU3t6uvLw8zZ07Vz/5yU8UDAYTt2oAwKDHzUiBQWJ4dpb3zMHFl8Z1rB0P/8J7Zlgc39G/o2mu90x4Zv8/1oGBgZuRAgAGNAIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJhI+K/kBpAcPW2HvGeyn/WfkaRjD53wnhkVSPWe+dUlb3vP/NPCld4zozbt8J5B8nEFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GakgIHemVd7z+y/5QLvmclXf+I9I8V3Y9F4PPf5f/GeGbVlZxJWAgtcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgZKfAVgWsme8/85V/8b9z5qxkvec/MuuC498y51OW6vWfqPy/wP1Bvi/8MBiSugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFAPeiILx3jP778yL61hPLH7Ne2bRRZ/FdayB7JG2a7xnan9xnffMt16q857B0MEVEADABAECAJjwClBlZaWuvfZapaWlKSsrSwsWLFBDQ0PMPseOHVN5ebkuvvhiXXTRRVq0aJHa2toSumgAwODnFaDa2lqVl5ervr5e7777rrq7uzV37lx1dnZG97n//vv11ltvaePGjaqtrdXBgwd18803J3zhAIDBzetDCFu3bo35uqqqSllZWdq1a5dmzZqlcDisX//619qwYYO+973vSZLWr1+vK6+8UvX19bruOv83KQEAQ9M3eg8oHA5LkjIyMiRJu3btUnd3t0pKSqL7TJo0SePGjVNdXd+fdunq6lIkEonZAABDX9wB6u3t1cqVKzVjxgxNnjxZktTa2qrU1FSNHj06Zt/s7Gy1trb2+fdUVlYqFApFt/z8/HiXBAAYROIOUHl5ufbu3avXXvP/uYmvqqioUDgcjm7Nzc3f6O8DAAwOcf0g6ooVK/T2229r+/btGjt2bPTxnJwcHT9+XO3t7TFXQW1tbcrJyenz7woGgwoGg/EsAwAwiHldATnntGLFCm3atEnbtm1TQUFBzPPTpk1TSkqKqquro481NDTowIEDKi4uTsyKAQBDgtcVUHl5uTZs2KAtW7YoLS0t+r5OKBTSyJEjFQqFdNddd2nVqlXKyMhQenq67rvvPhUXF/MJOABADK8APf/885Kk2bNnxzy+fv16LV26VJL085//XMOGDdOiRYvU1dWl0tJS/fKXv0zIYgEAQ0fAOeesF/FVkUhEoVBIszVfIwIp1svBGYy4ZJz3THharvfM4h9vPftOp7hn9F+9Zwa6B1r8v4tQ90v/m4pKUkbV7/2HenviOhaGnhOuWzXaonA4rPT09H73415wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBHXb0TFwDUit+/fPHsmn
794YVzHWl5Q6z1zW1pbXMcayFb850zvmY+ev9p7JvPf93rPZHTUec8A5wpXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5Geo4cL73Gf+b+z71nHrn0He+ZuSM7vWcGuraeL+Kam/WbB7xnJj36Z++ZjHb/m4T2ek8AAxtXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5Geo58ssC/9X+ZsjEJK0mcde0TvWd+UTvXeybQE/CemfRkk/eMJF3WtsN7pieuIwHgCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBFwzjnrRXxVJBJRKBTSbM3XiECK9XIAAJ5OuG7VaIvC4bDS09P73Y8rIACACQIEADDhFaDKykpde+21SktLU1ZWlhYsWKCGhoaYfWbPnq1AIBCz3XPPPQldNABg8PMKUG1trcrLy1VfX693331X3d3dmjt3rjo7O2P2W7ZsmVpaWqLb2rVrE7poAMDg5/UbUbdu3RrzdVVVlbKysrRr1y7NmjUr+vioUaOUk5OTmBUCAIakb/QeUDgcliRlZGTEPP7KK68oMzNTkydPVkVFhY4ePdrv39HV1aVIJBKzAQCGPq8roK/q7e3VypUrNWPGDE2ePDn6+O23367x48crLy9Pe/bs0cMPP6yGhga9+eabff49lZWVWrNmTbzLAAAMUnH/HNDy5cv129/+Vh988IHGjh3b737btm3TnDlz1NjYqIkTJ572fFdXl7q6uqJfRyIR5efn83NAADBIfd2fA4rrCmjFihV6++23tX379jPGR5KKiookqd8ABYNBBYPBeJYBABjEvALknNN9992nTZs2qaamRgUFBWed2b17tyQpNzc3rgUCAIYmrwCVl5drw4YN2rJli9LS0tTa2ipJCoVCGjlypPbv368NGzbo+9//vi6++GLt2bNH999/v2bNmqWpU6cm5R8AADA4eb0HFAgE+nx8/fr1Wrp0qZqbm/WDH/xAe/fuVWdnp/Lz87Vw4UI9+uijZ/w+4FdxLzgAGNyS8h7Q2VqVn5+v2tpan78SAHCe4l5wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATI6wXcCrnnCTphLolZ7wYAIC3E+qW9I//nvdnwAWoo6NDkvSB3jFeCQDgm+jo6FAoFOr3+YA7W6LOsd7eXh08eFBpaWkKBAIxz0UiEeXn56u5uVnp6elGK7THeTiJ83AS5+EkzsNJA+E8OOfU0dGhvLw8DRvW/zs9A+4KaNiwYRo7duwZ90lPTz+vX2Bf4jycxHk4ifNwEufhJOvzcKYrny/xIQQAgAkCBAAwMagCFAwGtXr1agWDQeulmOI8nMR5OInzcBLn4aTBdB4G3IcQAADnh0F1BQQAGDoIEADABAECAJggQAAAE4MmQOvWrdMll1yiCy64QEVFRfr9739vvaRz7oknnlAgEIjZJk2aZL2spNu+fbtuuukm5eXlKRAIaPPmzTHPO+f0+OOPKzc3VyNHjlRJSYn27dtns9gkOtt5WLp06Wmvj3nz5tksNkkqKyt17bXXKi0tTVlZWVqwYIEaGhpi9jl27JjKy8t18cUX66KLLtKiRYvU1tZmtOLk+DrnYfbs2ae9Hu655x6jFfdtUATo9ddf16pVq7R69Wp99NFHKiwsVGlpqQ4dOmS9tHPuqquuUktLS3T74IMPrJeUdJ2dnSosLNS6dev6fH7t2rV69tln9cILL2jHjh268MILVVpaqmPHjp3jlSbX2c6DJM2bNy/m9
fHqq6+ewxUmX21trcrLy1VfX693331X3d3dmjt3rjo7O6P73H///Xrrrbe0ceNG1dbW6uDBg7r55psNV514X+c8SNKyZctiXg9r1641WnE/3CAwffp0V15eHv26p6fH5eXlucrKSsNVnXurV692hYWF1sswJclt2rQp+nVvb6/LyclxTz31VPSx9vZ2FwwG3auvvmqwwnPj1PPgnHNLlixx8+fPN1mPlUOHDjlJrra21jl38t99SkqK27hxY3SfP/3pT06Sq6urs1pm0p16Hpxz7oYbbnA//OEP7Rb1NQz4K6Djx49r165dKikpiT42bNgwlZSUqK6uznBlNvbt26e8vDxNmDBBd9xxhw4cOGC9JFNNTU1qbW2NeX2EQiEVFRWdl6+PmpoaZWVl6YorrtDy5ct1+PBh6yUlVTgcliRlZGRIknbt2qXu7u6Y18OkSZM0bty4If16OPU8fOmVV15RZmamJk+erIqKCh09etRief0acDcjPdVnn32mnp4eZWdnxzyenZ2tP//5z0arslFUVKSqqipdccUVamlp0Zo1a3T99ddr7969SktLs16eidbWVknq8/Xx5XPni3nz5unmm29WQUGB9u/fr0ceeURlZWWqq6vT8OHDrZeXcL29vVq5cqVmzJihyZMnSzr5ekhNTdXo0aNj9h3Kr4e+zoMk3X777Ro/frzy8vK0Z88ePfzww2poaNCbb75puNpYAz5A+IeysrLon6dOnaqioiKNHz9eb7zxhu666y7DlWEguPXWW6N/njJliqZOnaqJEyeqpqZGc+bMMVxZcpSXl2vv3r3nxfugZ9Lfebj77rujf54yZYpyc3M1Z84c7d+/XxMnTjzXy+zTgP8WXGZmpoYPH37ap1ja2tqUk5NjtKqBYfTo0br88svV2NhovRQzX74GeH2cbsKECcrMzBySr48VK1bo7bff1vvvvx/z61tycnJ0/Phxtbe3x+w/VF8P/Z2HvhQVFUnSgHo9DPgApaamatq0aaquro4+1tvbq+rqahUXFxuuzN6RI0e0f/9+5ebmWi/FTEFBgXJycmJeH5FIRDt27DjvXx+ffvqpDh8+PKReH845rVixQps2bdK2bdtUUFAQ8/y0adOUkpIS83poaGjQgQMHhtTr4WznoS+7d++WpIH1erD+FMTX8dprr7lgMOiqqqrcH//4R3f33Xe70aNHu9bWVuulnVMPPPCAq6mpcU1NTe53v/udKykpcZmZme7QoUPWS0uqjo4O9/HHH7uPP/7YSXJPP/20+/jjj93f/vY355xzP/3pT93o0aPdli1b3J49e9z8+fNdQUGB++KLL4xXnlhnOg8dHR3uwQcfdHV1da6pqcm999577rvf/a677LLL3LFjx6yXnjDLly93oVDI1dTUuJaWluh29OjR6D733HOPGzdunNu2bZvbuXOnKy4udsXFxYarTryznYfGxkb34x//2O3cudM1NTW5LVu2uAkTJrhZs2YZrzzWoAiQc84999xzbty4cS41NdVNnz7d1dfXWy/pnFu8eLHLzc11qamp7tvf/rZbvHixa2xstF5W0r3//vtO0mnbkiVLnHMnP4r92GOPuezsbBcMBt2cOXNcQ0OD7aKT4Ezn4ejRo27u3LluzJgxLiUlxY0fP94tW7ZsyP2ftL7++SW59evXR/f54osv3L333uu+9a1vuVGjRrmFCxe6lpYWu0UnwdnOw4EDB9ysWbNcRkaGCwaD7tJLL3U/+tGPXDgctl34Kfh1DAAAEwP+PSAAwNBEgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJj4f4W4/AnknuSPAAAAAElFTkSuQmCC"},"metadata":{}}],"execution_count":6},{"cell_type":"markdown","source":"## Transform Dataset for 
Training","metadata":{}},{"cell_type":"code","source":"from torchvision import transforms\n\npreprocess = transforms.Compose([\n transforms.ToTensor(),\n transforms.Pad(2), ## pad 2 px on every side so the 28x28 digits become 32x32\n ## https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html\n transforms.Normalize([0.5],[0.5]) ## (x - 0.5) / 0.5 maps the [0, 1] tensor range to [-1, 1]\n])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:37.522925Z","iopub.execute_input":"2024-11-26T03:35:37.523175Z","iopub.status.idle":"2024-11-26T03:35:41.609646Z","shell.execute_reply.started":"2024-11-26T03:35:37.523150Z","shell.execute_reply":"2024-11-26T03:35:41.608888Z"}},"outputs":[],"execution_count":7},{"cell_type":"markdown","source":"Check the shape of data after transformation","metadata":{}},{"cell_type":"code","source":"import torch\nbatch_size = 512\n\ndef transform(examples):\n ## Batched transform attached via with_transform below: turns each PIL image into a\n ## normalized 1x32x32 tensor and returns the images/labels keys the training loop reads.\n ## https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert\n ## convert PIL Image to L mode (GrayScale)\n images = [preprocess(image.convert(\"L\")) for image in examples[\"image\"]]\n\n return {\"images\":images, \"labels\":examples[\"label\"]}\n\n## Attach the transform to the train split, then batch and shuffle it\ntrain_dataset = dataset['train'].with_transform(transform)\n\ntrain_dataloader = torch.utils.data.DataLoader(\n train_dataset, batch_size, shuffle=True\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.610657Z","iopub.execute_input":"2024-11-26T03:35:41.611079Z","iopub.status.idle":"2024-11-26T03:35:41.619049Z","shell.execute_reply.started":"2024-11-26T03:35:41.611052Z","shell.execute_reply":"2024-11-26T03:35:41.618176Z"}},"outputs":[],"execution_count":8},{"cell_type":"code","source":"batch = next(iter(train_dataloader))\nprint('Shape:', batch['images'].shape,\n '\\nBounds:', batch['images'].min().item(), 'to', 
batch['images'].max().item())","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.620141Z","iopub.execute_input":"2024-11-26T03:35:41.620492Z","iopub.status.idle":"2024-11-26T03:35:41.867613Z","shell.execute_reply.started":"2024-11-26T03:35:41.620464Z","shell.execute_reply":"2024-11-26T03:35:41.866695Z"}},"outputs":[{"name":"stdout","text":"Shape: torch.Size([512, 1, 32, 32]) \nBounds: -1.0 to 1.0\n","output_type":"stream"}],"execution_count":9},{"cell_type":"markdown","source":"## Build the Model","metadata":{}},{"cell_type":"code","source":"from diffusers import UNet2DModel\n\nunet = UNet2DModel(\n in_channels=1, # single grayscale channel\n out_channels=1, # same shape as the input; the training loop compares it to the added noise\n sample_size=32, # matches the 32x32 padded images\n block_out_channels=(32,64,128,256), # channel width at each resolution level\n norm_num_groups=8, # GroupNorm groups; must divide every entry of block_out_channels\n num_class_embeds=10 # one learned embedding per digit class 0-9\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.868891Z","iopub.execute_input":"2024-11-26T03:35:41.869235Z","iopub.status.idle":"2024-11-26T03:35:54.845071Z","shell.execute_reply.started":"2024-11-26T03:35:41.869194Z","shell.execute_reply":"2024-11-26T03:35:54.844383Z"}},"outputs":[{"name":"stderr","text":"The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. 
You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"0it [00:00, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2c6b13746434472e86d24dfdc207adc2"}},"metadata":{}}],"execution_count":10},{"cell_type":"markdown","source":"Smoke-test a forward pass on random noise and check that the output shape matches the input.","metadata":{}},{"cell_type":"code","source":"# Build the inputs on the model's device so this cell still works after training moved unet to the GPU\nnoised_x = torch.randn((1, 1, 32, 32), device=unet.device)\nwith torch.no_grad():\n    out = unet(noised_x, timestep=7, class_labels=torch.tensor([2], device=unet.device)).sample\n\nout.shape","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:54.846131Z","iopub.execute_input":"2024-11-26T03:35:54.846830Z","iopub.status.idle":"2024-11-26T03:35:55.110234Z","shell.execute_reply.started":"2024-11-26T03:35:54.846788Z","shell.execute_reply":"2024-11-26T03:35:55.109407Z"}},"outputs":[{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"torch.Size([1, 1, 32, 32])"},"metadata":{}}],"execution_count":11},{"cell_type":"markdown","source":"## Training","metadata":{}},{"cell_type":"code","source":"import torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom diffusers import DDPMScheduler\n\ndef train(num_epochs=30, lr=1e-4, device=\"cuda\"):\n    \"\"\"Train the notebook-level class-conditional `unet` to predict the noise added by DDPM.\n\n    Reads the global `unet` and `train_dataloader` defined in earlier cells.\n\n    Args:\n        num_epochs: number of passes over `train_dataloader`.\n        lr: AdamW learning rate.\n        device: device the model and every batch are moved to.\n\n    Returns:\n        List of per-batch MSE losses as Python floats, in training order.\n    \"\"\"\n    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)\n    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)\n    losses = []  # per-step loss values for later plotting\n    unet.to(device)\n    unet.train()\n\n    # Train the model (this takes a while!)\n    for epoch in range(num_epochs):\n        # Iterating the DataLoader directly (no enumerate) lets tqdm show the total batch count\n        for batch in tqdm(train_dataloader):\n            # Load the input images and their digit labels\n            clean_images = batch[\"images\"].to(device)\n            class_labels = batch[\"labels\"].to(device)\n\n            # Sample noise directly on the images' device, with their shape and dtype\n            noise = torch.randn_like(clean_images)\n\n            # Sample a random timestep for each image\n            timesteps = torch.randint(\n                0,\n                scheduler.config.num_train_timesteps,\n                (clean_images.shape[0],),\n                device=clean_images.device,\n            ).long()\n\n            # Add noise to the clean images according to each image's timestep\n            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)\n\n            # Predict the added noise, conditioned on the timestep and class label\n            noise_pred = unet(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]\n\n            # Compare the prediction with the actual noise\n            loss = F.mse_loss(noise_pred, noise)\n\n            # Plain backward() on the scalar loss; passing the loss itself as the\n            # gradient argument would scale every gradient by the current loss value.\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n            # Store a detached Python float so no autograd graph / GPU tensor is retained\n            losses.append(loss.item())\n        print(f\"Epoch {epoch}: loss={losses[-1]}\")\n    return losses",
"metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:55.111240Z","iopub.execute_input":"2024-11-26T03:35:55.111602Z","iopub.status.idle":"2024-11-26T03:35:55.122148Z","shell.execute_reply.started":"2024-11-26T03:35:55.111575Z","shell.execute_reply":"2024-11-26T03:35:55.121344Z"}},"outputs":[],"execution_count":12},{"cell_type":"code","source":"losses = train()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:55.123319Z","iopub.execute_input":"2024-11-26T03:35:55.123671Z","iopub.status.idle":"2024-11-26T04:31:56.851922Z","shell.execute_reply.started":"2024-11-26T03:35:55.123634Z","shell.execute_reply":"2024-11-26T04:31:56.851052Z"}},"outputs":[{"name":"stderr","text":"118it [01:52, 1.04it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 0: loss=0.1325972080230713\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 1: loss=0.09378191083669662\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 2: loss=0.07209588587284088\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 
1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 3: loss=0.05439606308937073\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 4: loss=0.06066245958209038\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 5: loss=0.04885260760784149\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 6: loss=0.0416167750954628\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 7: loss=0.047721683979034424\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 8: loss=0.033292364329099655\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 9: loss=0.045422039926052094\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 10: loss=0.03524807095527649\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 11: loss=0.03403984382748604\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 12: loss=0.030451234430074692\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 13: loss=0.027445441111922264\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 14: loss=0.0382767878472805\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 
1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 15: loss=0.0306419488042593\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 16: loss=0.02459515444934368\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 17: loss=0.023863770067691803\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 18: loss=0.022374501451849937\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 19: loss=0.02972579002380371\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 20: loss=0.022356227040290833\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 21: loss=0.022434819489717484\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 22: loss=0.029154803603887558\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 23: loss=0.024483010172843933\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 24: loss=0.024230940267443657\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 25: loss=0.027546880766749382\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 26: loss=0.02587004564702511\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 
1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 27: loss=0.020630789920687675\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 28: loss=0.01809917762875557\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]","output_type":"stream"},{"name":"stdout","text":"Epoch 29: loss=0.015931110829114914\n","output_type":"stream"},{"name":"stderr","text":"\n","output_type":"stream"}],"execution_count":13},{"cell_type":"code","source":"import copy\n\nfrom kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\ntoken = user_secrets.get_secret(\"HF_TOKEN\")\n\n# `variant` only changes the saved weight file name; it does not convert dtypes.\n# Push a half-precision copy so the fp16 variant really contains fp16 weights,\n# while `unet` itself stays fp32 for any later cells.\nunet_fp16 = copy.deepcopy(unet).to(torch.float16)\nunet_fp16.push_to_hub(\"unet-mnist-32\", variant=\"fp16\", token=token)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T04:56:57.055248Z","iopub.execute_input":"2024-11-26T04:56:57.055643Z","iopub.status.idle":"2024-11-26T04:56:59.514266Z","shell.execute_reply.started":"2024-11-26T04:56:57.055607Z","shell.execute_reply":"2024-11-26T04:56:59.513419Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/2.70k [00:00