{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.14","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from datasets import load_dataset\ndataset = load_dataset(\"mnist\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:22.773326Z","iopub.execute_input":"2024-11-26T03:35:22.773705Z","iopub.status.idle":"2024-11-26T03:35:33.281299Z","shell.execute_reply.started":"2024-11-26T03:35:22.773675Z","shell.execute_reply":"2024-11-26T03:35:33.280473Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/6.97k [00:00, 5)"},"metadata":{}}],"execution_count":5},
{"cell_type":"code","source":"import matplotlib.pyplot as plt\n\n# Define the sample in this cell so it does not depend on variables created\n# elsewhere in the notebook: take the first MNIST training example.\nsample_example = dataset[\"train\"][0]\nsample = sample_example[\"image\"]\nsample_label = sample_example[\"label\"]\n\nprint(sample_label)\nplt.imshow(sample, cmap=\"gray\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:37.280112Z","iopub.execute_input":"2024-11-26T03:35:37.280422Z","iopub.status.idle":"2024-11-26T03:35:37.521924Z","shell.execute_reply.started":"2024-11-26T03:35:37.280388Z","shell.execute_reply":"2024-11-26T03:35:37.521067Z"}},"outputs":[{"name":"stdout","text":"5\n","output_type":"stream"},{"execution_count":6,"output_type":"execute_result","data":{"text/plain":""},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"
","image/png":""},"metadata":{}}],"execution_count":6},{"cell_type":"markdown","source":"## Transform Dataset for Training","metadata":{}},
{"cell_type":"code","source":"from torchvision import transforms\n\npreprocess = transforms.Compose([\n    transforms.ToTensor(),  ## PIL image -> float tensor of shape (C, H, W) with values in [0, 1]\n    transforms.Pad(2),  ## pad 2 px on every side so the 28x28 MNIST digits become 32x32\n    ## https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html\n    transforms.Normalize([0.5], [0.5]),  ## (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]\n])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:37.522925Z","iopub.execute_input":"2024-11-26T03:35:37.523175Z","iopub.status.idle":"2024-11-26T03:35:41.609646Z","shell.execute_reply.started":"2024-11-26T03:35:37.523150Z","shell.execute_reply":"2024-11-26T03:35:41.608888Z"}},"outputs":[],"execution_count":7},
{"cell_type":"markdown","source":"Apply the transform to the dataset, then check the shape of the transformed data","metadata":{}},
{"cell_type":"code","source":"import torch\nbatch_size = 512\n\ndef transform(examples):\n    \"\"\"Batch transform for `Dataset.with_transform`: grayscale PIL images -> normalized tensors plus labels.\"\"\"\n    ## https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert\n    ## convert PIL Image to L mode (GrayScale)\n    images = [preprocess(image.convert(\"L\")) for image in examples[\"image\"]]\n\n    return {\"images\": images, \"labels\": examples[\"label\"]}\n\n## with_transform applies `transform` on the fly whenever examples are accessed\ntrain_dataset = dataset['train'].with_transform(transform)\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset, batch_size=batch_size, shuffle=True\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.610657Z","iopub.execute_input":"2024-11-26T03:35:41.611079Z","iopub.status.idle":"2024-11-26T03:35:41.619049Z","shell.execute_reply.started":"2024-11-26T03:35:41.611052Z","shell.execute_reply":"2024-11-26T03:35:41.618176Z"}},"outputs":[],"execution_count":8},
{"cell_type":"code","source":"batch = next(iter(train_dataloader))\nprint('Shape:', batch['images'].shape,\n      '\\nBounds:', batch['images'].min().item(), 'to', batch['images'].max().item())\n\n## sanity checks: single-channel 32x32 images normalized into [-1, 1]\nassert batch['images'].shape[1:] == (1, 32, 32)\nassert batch['images'].min().item() >= -1.0 and batch['images'].max().item() <= 1.0","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.620141Z","iopub.execute_input":"2024-11-26T03:35:41.620492Z","iopub.status.idle":"2024-11-26T03:35:41.867613Z","shell.execute_reply.started":"2024-11-26T03:35:41.620464Z","shell.execute_reply":"2024-11-26T03:35:41.866695Z"}},"outputs":[{"name":"stdout","text":"Shape: torch.Size([512, 1, 32, 32]) \nBounds: -1.0 to 1.0\n","output_type":"stream"}],"execution_count":9},{"cell_type":"markdown","source":"## Build the Model","metadata":{}},
{"cell_type":"code","source":"from diffusers import UNet2DModel\n\nunet = UNet2DModel(\n    in_channels=1,  # grayscale input\n    out_channels=1,  # predicted noise has the same shape as the input\n    sample_size=32,  # 32x32 images after padding\n    block_out_channels=(32, 64, 128, 256),\n    norm_num_groups=8,\n    num_class_embeds=10,  # one embedding per digit class, selected via `class_labels`\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:41.868891Z","iopub.execute_input":"2024-11-26T03:35:41.869235Z","iopub.status.idle":"2024-11-26T03:35:54.845071Z","shell.execute_reply.started":"2024-11-26T03:35:41.869194Z","shell.execute_reply":"2024-11-26T03:35:54.844383Z"}},"outputs":[{"name":"stderr","text":"The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. 
You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"0it [00:00, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2c6b13746434472e86d24dfdc207adc2"}},"metadata":{}}],"execution_count":10},{"cell_type":"markdown","source":"Test the inference and the output shape","metadata":{}},
{"cell_type":"code","source":"# Run a single forward pass on random noise. Inputs are created on the model's device\n# so this cell also works after `train()` has moved `unet` to the GPU.\nnoised_x = torch.randn((1, 1, 32, 32), device=unet.device)\nwith torch.no_grad():\n    out = unet(noised_x, timestep=7, class_labels=torch.tensor([2], device=unet.device)).sample\n\nout.shape","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:54.846131Z","iopub.execute_input":"2024-11-26T03:35:54.846830Z","iopub.status.idle":"2024-11-26T03:35:55.110234Z","shell.execute_reply.started":"2024-11-26T03:35:54.846788Z","shell.execute_reply":"2024-11-26T03:35:55.109407Z"}},"outputs":[{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"torch.Size([1, 1, 32, 32])"},"metadata":{}}],"execution_count":11},{"cell_type":"markdown","source":"## Training","metadata":{}},
{"cell_type":"code","source":"import torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom diffusers import DDPMScheduler\n\ndef train(num_epochs=30, lr=1e-4, device=\"cuda\"):\n    \"\"\"Train the global class-conditional `unet` to predict the noise added by DDPM.\n\n    Batches come from the global `train_dataloader`. Returns the MSE loss of\n    every optimisation step as a list of Python floats.\n    \"\"\"\n    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)\n    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)\n    losses = []  # per-step loss values (floats) for later plotting\n    unet.to(device)\n    unet.train()\n\n    # Train the model (this takes a while!)\n    for epoch in range(num_epochs):\n        for batch in tqdm(train_dataloader, desc=f\"Epoch {epoch}\"):\n            # Load the input images and their class labels\n            clean_images = batch[\"images\"].to(device)\n            class_labels = batch[\"labels\"].to(device)\n\n            # Sample noise with the same shape, dtype and device as the images\n            noise = torch.randn_like(clean_images)\n\n            # Sample a random timestep for each image\n            timesteps = torch.randint(\n                0,\n                scheduler.config.num_train_timesteps,\n                (clean_images.shape[0],),\n                device=clean_images.device,\n            ).long()\n\n            # Add noise to the clean images according to their timesteps\n            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)\n\n            # Predict the added noise and compare it with the actual noise\n            noise_pred = unet(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]\n            loss = F.mse_loss(noise_pred, noise)\n\n            # Update the parameters. Call plain backward(): passing `loss` as the\n            # `gradient` argument would scale every gradient by the loss value.\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # Store a float rather than the tensor so the list does not keep each\n            # step's autograd history and device memory alive.\n            losses.append(loss.item())\n        print(f\"Epoch {epoch}: loss={losses[-1]}\")\n    return losses","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:55.111240Z","iopub.execute_input":"2024-11-26T03:35:55.111602Z","iopub.status.idle":"2024-11-26T03:35:55.122148Z","shell.execute_reply.started":"2024-11-26T03:35:55.111575Z","shell.execute_reply":"2024-11-26T03:35:55.121344Z"}},"outputs":[],"execution_count":12},
{"cell_type":"code","source":"losses = train()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T03:35:55.123319Z","iopub.execute_input":"2024-11-26T03:35:55.123671Z","iopub.status.idle":"2024-11-26T04:31:56.851922Z","shell.execute_reply.started":"2024-11-26T03:35:55.123634Z","shell.execute_reply":"2024-11-26T04:31:56.851052Z"}},"outputs":[{"name":"stderr","text":"118it [01:52, 1.04it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 0: loss=0.1325972080230713\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 1: loss=0.09378191083669662\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 2: loss=0.07209588587284088\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 
1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 3: loss=0.05439606308937073\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 4: loss=0.06066245958209038\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 5: loss=0.04885260760784149\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 6: loss=0.0416167750954628\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 7: loss=0.047721683979034424\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 8: loss=0.033292364329099655\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 9: loss=0.045422039926052094\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 10: loss=0.03524807095527649\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 11: loss=0.03403984382748604\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 12: loss=0.030451234430074692\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 13: loss=0.027445441111922264\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 14: loss=0.0382767878472805\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 
1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 15: loss=0.0306419488042593\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 16: loss=0.02459515444934368\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 17: loss=0.023863770067691803\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 18: loss=0.022374501451849937\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 19: loss=0.02972579002380371\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 20: loss=0.022356227040290833\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 21: loss=0.022434819489717484\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 22: loss=0.029154803603887558\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 23: loss=0.024483010172843933\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 24: loss=0.024230940267443657\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 25: loss=0.027546880766749382\n","output_type":"stream"},{"name":"stderr","text":"118it [01:52, 1.05it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 26: loss=0.02587004564702511\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 
1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 27: loss=0.020630789920687675\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.06it/s]\n","output_type":"stream"},{"name":"stdout","text":"Epoch 28: loss=0.01809917762875557\n","output_type":"stream"},{"name":"stderr","text":"118it [01:51, 1.05it/s]","output_type":"stream"},{"name":"stdout","text":"Epoch 29: loss=0.015931110829114914\n","output_type":"stream"},{"name":"stderr","text":"\n","output_type":"stream"}],"execution_count":13},{"cell_type":"code","source":"from kaggle_secrets import UserSecretsClient\nuser_secrets = UserSecretsClient()\ntoken = user_secrets.get_secret(\"HF_TOKEN\")\n\nunet.push_to_hub(\"unet-mnist-32\", variant=\"fp16\", token=token)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2024-11-26T04:56:57.055248Z","iopub.execute_input":"2024-11-26T04:56:57.055643Z","iopub.status.idle":"2024-11-26T04:56:59.514266Z","shell.execute_reply.started":"2024-11-26T04:56:57.055607Z","shell.execute_reply":"2024-11-26T04:56:59.513419Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/2.70k [00:00