import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
from torch.optim.lr_scheduler import OneCycleLR
from torch_lr_finder import LRFinder

from . import config
from .visualize import plot_incorrect_preds


class Net(pl.LightningModule):
    def __init__(
        self,
        num_classes=10,
        dropout_percentage=0,
        norm="bn",
        num_groups=2,
        criterion=F.nll_loss,
        learning_rate=0.001,
        weight_decay=0.0,
    ):
        super().__init__()

        # Select the normalization layer factory
        if norm == "bn":
            self.norm = nn.BatchNorm2d
        elif norm == "gn":
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=num_groups, num_channels=in_dim
            )
        elif norm == "ln":
            # LayerNorm over (C, H, W) expressed as GroupNorm with one group
            self.norm = lambda in_dim: nn.GroupNorm(num_groups=1, num_channels=in_dim)
        else:
            raise ValueError(f"Unknown norm {norm!r}; expected 'bn', 'gn' or 'ln'")

        # Define the loss criterion. forward() returns log-probabilities, so
        # nll_loss is the matching default (cross_entropy would apply
        # log_softmax a second time).
        self.criterion = criterion

        # Define the metrics
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task="multiclass", num_classes=num_classes
        )

        # Define the optimizer hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Prediction storage (long dtype so torch.cat keeps integer labels)
        self.pred_store = {
            "test_preds": torch.tensor([], dtype=torch.long),
            "test_labels": torch.tensor([], dtype=torch.long),
            "test_incorrect": [],
        }
        self.log_store = {
            "train_loss_epoch": [],
            "train_acc_epoch": [],
            "val_loss_epoch": [],
            "val_acc_epoch": [],
            "test_loss_epoch": [],
            "test_acc_epoch": [],
        }

        # This defines the structure of the NN.
        # Prep layer
        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 32x32x3 | 1 -> 32x32x64 | 3
            self.norm(64),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 32x32x128 | 5
            nn.MaxPool2d(2, 2),  # 16x16x128 | 6
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l1res = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 10
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 14
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 16x16x256 | 18
            nn.MaxPool2d(2, 2),  # 8x8x256 | 19
            self.norm(256),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # 8x8x512 | 27
            nn.MaxPool2d(2, 2),  # 4x4x512 | 28
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3res = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 36
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 44
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.maxpool = nn.MaxPool2d(4, 4)
        # Classifier
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prep_layer(x)
        x = self.l1(x)
        x = x + self.l1res(x)
        x = self.l2(x)
        x = self.l3(x)
        x = x + self.l3res(x)
        x = self.maxpool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self(data)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate the metrics
        accuracy = self.accuracy(pred, target)
        self.log_dict(
            {"train_loss": loss, "train_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss
    def validation_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self(data)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate the metrics
        accuracy = self.accuracy(pred, target)
        self.log_dict(
            {"val_loss": loss, "val_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def test_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self(data)
        argmax_pred = pred.argmax(dim=1)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate the metrics
        accuracy = self.accuracy(pred, target)
        self.log_dict(
            {"test_loss": loss, "test_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        # Update the confusion matrix
        self.confusion_matrix.update(pred, target)

        # Store the predictions, labels and incorrect predictions
        data, target, pred, argmax_pred = (
            data.cpu(),
            target.cpu(),
            pred.cpu(),
            argmax_pred.cpu(),
        )
        self.pred_store["test_preds"] = torch.cat(
            (self.pred_store["test_preds"], argmax_pred), dim=0
        )
        self.pred_store["test_labels"] = torch.cat(
            (self.pred_store["test_labels"], target), dim=0
        )
        for d, t, p, o in zip(data, target, argmax_pred, pred):
            if not p.eq(t.view_as(p)).item():
                self.pred_store["test_incorrect"].append((d, t, p, o[p.item()]))
        return loss

    def find_bestLR_LRFinder(self, optimizer):
        lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
        lr_finder.range_test(
            self.trainer.datamodule.train_dataloader(),
            end_lr=config.LRFINDER_END_LR,
            num_iter=config.LRFINDER_NUM_ITERATIONS,
            step_mode=config.LRFINDER_STEP_MODE,
        )
        best_lr = None
        try:
            _, best_lr = lr_finder.plot()  # inspect the loss-learning rate graph
        except Exception:
            pass
        lr_finder.reset()  # reset the model and optimizer to their initial state
        return best_lr

    def configure_optimizers(self):
        optimizer = self.get_only_optimizer()
        # Use the LR finder suggestion; fall back to a previously found value
        # if the finder fails to produce one
        best_lr = self.find_bestLR_LRFinder(optimizer) or 1.47e-03
        scheduler = OneCycleLR(
            optimizer,
            max_lr=best_lr,
            # total_steps=self.trainer.estimated_stepping_batches,
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=5 / config.NUM_EPOCHS,
            div_factor=config.OCLR_DIV_FACTOR,
            three_phase=config.OCLR_THREE_PHASE,
            final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
            anneal_strategy=config.OCLR_ANNEAL_STRATEGY,
        )
        return [optimizer], [
            {"scheduler": scheduler, "interval": "step", "frequency": 1}
        ]

    def get_only_optimizer(self):
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer

    def on_test_end(self) -> None:
        super().on_test_end()

        # Confusion matrix (optionally row-normalized by true-label counts)
        confmat = self.confusion_matrix.cpu().compute().numpy()
        if config.NORM_CONF_MAT:
            df_confmat = pd.DataFrame(
                confmat / np.sum(confmat, axis=1)[:, None],
                index=list(config.CLASSES),
                columns=list(config.CLASSES),
            )
        else:
            df_confmat = pd.DataFrame(
                confmat,
                index=list(config.CLASSES),
                columns=list(config.CLASSES),
            )
        plt.figure(figsize=(7, 5))
        sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.show()

    def plot_incorrect_predictions_helper(self, num_imgs=10):
        return plot_incorrect_preds(
            self.pred_store["test_incorrect"], config.CLASSES, num_imgs
        )
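
# ---------------------------------------------------------------------------
# Note on the "ln" branch in Net.__init__: GroupNorm with a single group
# normalizes each sample over all of (C, H, W), which is the LayerNorm
# behaviour this model wants for conv feature maps. A minimal sketch of the
# equivalence (illustrative only; nothing here is used by the model):
#
#   x = torch.randn(4, 8, 16, 16)
#   gn = nn.GroupNorm(num_groups=1, num_channels=8, affine=False)
#   mean = x.mean(dim=(1, 2, 3), keepdim=True)
#   var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
#   manual = (x - mean) / torch.sqrt(var + gn.eps)
#   assert torch.allclose(gn(x), manual, atol=1e-5)
# ---------------------------------------------------------------------------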
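
# ---------------------------------------------------------------------------
# Usage sketch. Assumption not defined in this file: a LightningDataModule
# named `CIFAR10DataModule` in a sibling `datamodule` module (the real name
# in this repo may differ). configure_optimizers() reads
# `self.trainer.datamodule`, so the datamodule must be passed to fit()/test()
# rather than raw dataloaders. Run as a module: `python -m <package>.model`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from .datamodule import CIFAR10DataModule  # hypothetical import

    model = Net(num_classes=config.NUM_CLASSES, dropout_percentage=0.05, norm="bn")
    datamodule = CIFAR10DataModule()
    trainer = pl.Trainer(max_epochs=config.NUM_EPOCHS, accelerator="auto", devices=1)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)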