# Michael Peres ~ 09/01/2024
# BERT-Based Transformer Model for Image Classification
# ----------------------------------------------------------------------------------------------------------------------
# Import Modules
# pip install transformers torchvision
from transformers import BertModel, BertConfig
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from tqdm.notebook import tqdm, trange
from torch.optim import AdamW
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math, torch
import torch.nn as nn
# ----------------------------------------------------------------------------------------------------------------------
# This is a simple implementation where the first hidden state,
# i.e. the encoded class token, is used as the input to an MLP head for classification.
# The model is trained on the CIFAR-10 dataset, which consists of 60,000 32x32 colour images in 10 classes,
# with 6,000 images per class.
# This model only contains the encoder part of the BERT model, plus the classification head.
# ----------------------------------------------------------------------------------------------------------------------
# Some understanding of the BERT model is required to follow this code; here are the dimensions and documentation.
# From the documentation, https://huggingface.co/transformers/v3.0.2/model_doc/bert.html
# BERT parameters include:
# - hidden_size: 256
# - intermediate_size: 1024
# - num_hidden_layers: 12
# - num_attention_heads: 8
# - max_position_embeddings: 256
# - vocab_size: 100
# - bos_token_id: 101
# - eos_token_id: 102
# - cls_token_id: 103
# But what do all of these mean in practice?
# hidden_size is the dimensionality D of the token embeddings.
# intermediate_size is the number of neurons in the hidden layer of the feed-forward sublayer;
# the feed-forward has dims hidden_size D -> intermediate_size -> hidden_size D (see the sketch below).
# num_hidden_layers is the number of hidden layers in the transformer encoder;
# layers refer to transformer blocks, so more means more transformer blocks in the model.
# num_attention_heads is the number of attention heads in the multi-head attention module of each layer.
# max_position_embeddings is the maximum input length the model can handle; this should be larger for models that take longer inputs.
# vocab_size is the size of the token vocabulary the model is trained on;
# in our case it is 100, which may look confusing given pixel intensities of 0-255,
# but since we feed patch embeddings directly via inputs_embeds, the word-embedding vocabulary is never used.
# bos_token_id is the beginning-of-sentence token id, useful for marking sentence boundaries in text generation tasks.
# eos_token_id is the end-of-sentence token id, which does not appear in the BertConfig documentation.
# cls_token_id is the id of the CLS token that is prepended to the beginning of each input instance.
# output_hidden_states=True means the model returns the hidden states of every layer for us to inspect.
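# ----------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (not part of training; shapes assumed from the config above): the feed-forward
# sublayer inside each transformer block maps hidden_size -> intermediate_size -> hidden_size,
# with a GELU activation in between.
# ffn = nn.Sequential(
#     nn.Linear(256, 1024),  # hidden_size -> intermediate_size
#     nn.GELU(),             # BERT's default activation
#     nn.Linear(1024, 256),  # intermediate_size -> hidden_size
# )
# x = torch.randn(1, 65, 256)  # (batch, tokens, hidden); 65 tokens = 64 patches + 1 CLS, as built below
# print(ffn(x).shape)          # torch.Size([1, 65, 256])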
# ----------------------------------------------------------------------------------------------------------------------
# Preparing the CIFAR-10 Image Dataset, and DataLoaders for Training and Testing
dataset = CIFAR10(root='./data/', train=True, download=True,
                  transform=transforms.Compose([
                      transforms.RandomHorizontalFlip(),
                      transforms.RandomCrop(32, padding=4),
                      transforms.ToTensor(),
                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                  ]))  # augmentations are important here: without them the model overfits quickly and never reaches good generalization accuracy
val_dataset = CIFAR10(root='./data/', train=False, download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                      ]))

# Model Configuration and Hyperparameters
# (the token ids are stored on the config but unused here, since we pass embeddings directly)
config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12, num_attention_heads=8,
                    max_position_embeddings=256, vocab_size=100, bos_token_id=101, eos_token_id=102,
                    cls_token_id=103, output_hidden_states=False)
model = BertModel(config).cuda()
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()  # 32x32 image -> 8x8 grid of 4x4 patches
CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size))
readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                        nn.GELU(),
                        nn.Linear(config.hidden_size, 10)).cuda()
# everything above is already created on the GPU, so no extra .cuda() loop is needed
optimizer = AdamW([*model.parameters(), *patch_embed.parameters(), *readout.parameters(), CLS_token], lr=5e-4)

# DataLoaders
batch_size = 192  # 96
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# ----------------------------------------------------------------------------------------------------------------------
# Understanding the CLS Token:
# print("CLASS TOKEN shape:")
# print(CLS_token.shape)
#
# reshaped_cls = CLS_token.expand(192, 1, -1)
# print("CLS Reshaped shape", reshaped_cls.shape)  # (192, 1, 256)
# # We expand the CLS token across the batch so it matches the patch embeddings and can be concatenated.
#
# imgs, labels = next(iter(train_loader))
# patch_embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
#
# input_embs = torch.cat([reshaped_cls, patch_embs], dim=1)
# print("Patch Embeddings Shape", patch_embs.shape)  # (192, 64, 256)
# print("Input Embedding Shape", input_embs.shape)   # (192, 65, 256)
# ----------------------------------------------------------------------------------------------------------------------
# Understanding the Output of the Transformer:
# hidden_states: a tuple of 13 tensors (the embedding output plus one per layer), each of shape (192, 65, 256)
# last_hidden_state: (192, 65, 256)
# pooler_output: (192, 256)
# The pooler takes the last hidden state of the first (CLS) token and passes it through a linear
# layer with a tanh activation, giving one vector per sample; the other tokens' outputs are discarded.
# (See the sketch below for inspecting hidden states without running out of memory.)
#
# We should understand the output of the model:
# representations = output.last_hidden_state[:, 0, :]
# print(output.last_hidden_state.shape)  # Out of memory.
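# ----------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (assumed workflow, not part of training): to inspect all hidden states without the
# out-of-memory issue noted above, request them for a single small batch under torch.no_grad(), so no
# autograd graph is kept alive.
# with torch.no_grad():
#     imgs, labels = next(iter(val_loader))
#     small = imgs[:8].cuda()                                # probe with 8 images instead of the full batch
#     embs = patch_embed(small).flatten(2).permute(0, 2, 1)  # (8, 64, 256)
#     embs = torch.cat([CLS_token.expand(8, 1, -1), embs], dim=1)
#     out = model(inputs_embeds=embs, output_hidden_states=True)
#     print(len(out.hidden_states))       # 13: embedding output + 12 encoder layers
#     print(out.hidden_states[-1].shape)  # torch.Size([8, 65, 256])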
# print(representations.shape)
# ----------------------------------------------------------------------------------------------------------------------
# Training Loop
EPOCHS = 30
loss_list = []
acc_list = []
for epoch in trange(EPOCHS, leave=False):
    model.train()  # switch back to train mode after the previous epoch's validation
    correct_cnt = 0
    total_loss = 0
    pbar = tqdm(train_loader, leave=False)
    for i, (imgs, labels) in enumerate(pbar):
        patch_embs = patch_embed(imgs.cuda())  # patch embeddings
        # print("patch embs shape ", patch_embs.shape)  # (192, 256, 8, 8), 192 per batch
        patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW, hidden=256)
        # print(patch_embs.shape)
        input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
        # print(input_embs.shape)
        output = model(inputs_embeds=input_embs)
        # print(dir(output))
        # print("output hidden states", output.hidden_states)  # out of memory error
        # print("output last hidden state shape", output.last_hidden_state.shape)  # (192, 65, 256)
        # print("output pooler output shape", output.pooler_output.shape)
        logit = readout(output.last_hidden_state[:, 0, :])
        loss = F.cross_entropy(logit, labels.cuda())
        # print(loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss: {loss.item():.4f}")
        total_loss += loss.item() * imgs.shape[0]
        correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()
    loss_list.append(round(total_loss / len(dataset), 4))
    acc_list.append(round(correct_cnt / len(dataset), 4))

    # test on the validation set; use eval mode and disable gradients so no graph is built
    model.eval()
    correct_cnt = 0
    total_loss = 0
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(val_loader):
            patch_embs = patch_embed(imgs.cuda())
            patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW, hidden)
            input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
            output = model(inputs_embeds=input_embs)
            logit = readout(output.last_hidden_state[:, 0, :])
            loss = F.cross_entropy(logit, labels.cuda())
            total_loss += loss.item() * imgs.shape[0]
            correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()
    print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")

# Plotting Loss and Accuracy
plt.figure()
plt.plot(loss_list, label="loss")
plt.plot(acc_list, label="accuracy")
plt.legend()
plt.show()
# ----------------------------------------------------------------------------------------------------------------------
# Saving Model Parameters
# Note: the classifier also needs patch_embed, readout, and the CLS token to be reconstructed,
# so we save all of them together (a reloading sketch follows at the end of the file).
torch.save({"model": model.state_dict(),
            "patch_embed": patch_embed.state_dict(),
            "readout": readout.state_dict(),
            "CLS_token": CLS_token}, "bert.pth")
# ----------------------------------------------------------------------------------------------------------------------
# Reference: Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch
# ----------------------------------------------------------------------------------------------------------------------
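# ----------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (assumes the checkpoint layout saved above): reloading all components for inference.
# ckpt = torch.load("bert.pth", map_location="cuda")
# model.load_state_dict(ckpt["model"])
# patch_embed.load_state_dict(ckpt["patch_embed"])
# readout.load_state_dict(ckpt["readout"])
# CLS_token.data.copy_(ckpt["CLS_token"].data)
# model.eval()
# with torch.no_grad():
#     imgs, labels = next(iter(val_loader))
#     embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
#     embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), embs], dim=1)
#     preds = readout(model(inputs_embeds=embs).last_hidden_state[:, 0, :]).argmax(dim=1)
#     print("first 10 predictions:", preds[:10].tolist())
# ----------------------------------------------------------------------------------------------------------------------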