# madhurjindal's picture
# Upload 19 files
# 06a7cdc
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
from torch.optim.lr_scheduler import OneCycleLR
from torch_lr_finder import LRFinder

from . import config
from .visualize import plot_incorrect_preds
class Net(pl.LightningModule):
    """Residual convolutional classifier for 3-channel images.

    Architecture::

        prep_layer  conv 3->64
        l1          conv 64->128,  maxpool/2   (+ l1res residual block)
        l2          conv 128->256, maxpool/2
        l3          conv 256->512, maxpool/2   (+ l3res residual block)
        maxpool/4 -> flatten -> linear(512 -> num_classes) -> log_softmax

    The per-layer shape comments assume 32x32 inputs (``HxWxC``); the number
    after ``|`` appears to be the receptive field. For 32x32 inputs the final
    feature map is 1x1x512.

    Args:
        num_classes: Number of output classes; sizes the classifier and the
            accuracy / confusion-matrix metrics.
        dropout_percentage: Dropout probability applied after every conv block.
        norm: Normalization layer to use: ``"bn"`` (BatchNorm2d), ``"gn"``
            (GroupNorm with ``num_groups`` groups) or ``"ln"`` (GroupNorm with a
            single group, i.e. layer norm over C, H, W).
        num_groups: Number of groups when ``norm == "gn"``.
        criterion: Loss called as ``criterion(log_probs, target)``. Note that
            ``forward`` already returns log-probabilities; the default
            ``F.cross_entropy`` re-applies log_softmax, which leaves
            log-probabilities unchanged, so it behaves like ``F.nll_loss``.
        learning_rate: Initial learning rate for Adam.
        weight_decay: Weight decay for Adam.
        max_lr: Peak learning rate for the OneCycleLR schedule. If ``None``,
            the LR-finder suggestion is used instead.

    Raises:
        ValueError: If ``norm`` is not one of ``"bn"``, ``"gn"``, ``"ln"``.
    """

    def __init__(
        self,
        num_classes=10,
        dropout_percentage=0,
        norm="bn",
        num_groups=2,
        criterion=F.cross_entropy,
        learning_rate=0.001,
        weight_decay=0.0,
        max_lr=1.47e-03,
    ):
        super().__init__()
        if norm == "bn":
            self.norm = nn.BatchNorm2d
        elif norm == "gn":
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=num_groups, num_channels=in_dim
            )
        elif norm == "ln":
            # One group => normalize over (C, H, W) per sample: layer norm for conv maps.
            self.norm = lambda in_dim: nn.GroupNorm(num_groups=1, num_channels=in_dim)
        else:
            raise ValueError(
                f"Unsupported norm {norm!r}; expected one of 'bn', 'gn', 'ln'"
            )

        # Loss criterion
        self.criterion = criterion

        # Metrics (both sized from the constructor argument so they agree with
        # the classifier head below).
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task="multiclass", num_classes=num_classes
        )

        # Optimizer / scheduler hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_lr = max_lr

        # Test-time prediction storage (filled by test_step).
        self._reset_pred_store()
        self.log_store = {
            "train_loss_epoch": [],
            "train_acc_epoch": [],
            "val_loss_epoch": [],
            "val_acc_epoch": [],
            "test_loss_epoch": [],
            "test_acc_epoch": [],
        }

        # Prep layer
        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 32x32x3 | 1 -> 32x32x64 | 3
            self.norm(64),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 32x32x128 | 5
            nn.MaxPool2d(2, 2),  # 16x16x128 | 6
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l1res = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 10
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 14
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 16x16x256 | 18
            nn.MaxPool2d(2, 2),  # 8x8x256 | 19
            self.norm(256),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # 8x8x512 | 27
            nn.MaxPool2d(2, 2),  # 4x4x512 | 28
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3res = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 36
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 44
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.maxpool = nn.MaxPool2d(4, 4)
        # Classifier
        self.linear = nn.Linear(512, num_classes)

    def _reset_pred_store(self):
        """(Re)create the containers that accumulate test-time predictions."""
        self.pred_store = {
            "test_preds": torch.tensor([]),
            "test_labels": torch.tensor([]),
            "test_incorrect": [],
        }

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, num_classes)."""
        x = self.prep_layer(x)
        x = self.l1(x)
        x = x + self.l1res(x)
        x = self.l2(x)
        x = self.l3(x)
        x = x + self.l3res(x)
        x = self.maxpool(x)
        # flatten keeps the batch dimension intact; a view(-1, 512) would silently
        # fold a larger-than-1x1 feature map into extra "batch" rows.
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

    def _shared_step(self, batch, stage):
        """Run forward + loss + accuracy for one batch and log them under ``stage``.

        Logs ``{stage}_loss`` and ``{stage}_acc`` per step and per epoch.

        Returns:
            Tuple ``(data, target, pred, loss)`` where ``pred`` are log-probabilities.
        """
        data, target = batch
        pred = self(data)
        loss = self.criterion(pred, target)
        accuracy = self.accuracy(pred, target)
        self.log_dict(
            {f"{stage}_loss": loss, f"{stage}_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return data, target, pred, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        _, _, _, loss = self._shared_step(batch, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        _, _, _, loss = self._shared_step(batch, "val")
        return loss

    def on_test_start(self) -> None:
        """Clear state left over from any previous test run before accumulating anew."""
        super().on_test_start()
        self._reset_pred_store()
        self.confusion_matrix.reset()

    def test_step(self, batch, batch_idx):
        """Log test metrics, update the confusion matrix and record predictions.

        Appends the argmax predictions and labels to ``pred_store`` and, for
        every misclassified sample, a tuple
        ``(image, label, predicted_class, log_prob_of_predicted_class)``.
        """
        data, target, pred, loss = self._shared_step(batch, "test")

        # Update the confusion matrix (on the metric's device).
        self.confusion_matrix.update(pred, target)

        # Move to CPU once so the stored history does not hold device memory.
        data, target, pred = data.cpu(), target.cpu(), pred.cpu()
        argmax_pred = pred.argmax(dim=1)

        self.pred_store["test_preds"] = torch.cat(
            (self.pred_store["test_preds"], argmax_pred), dim=0
        )
        self.pred_store["test_labels"] = torch.cat(
            (self.pred_store["test_labels"], target), dim=0
        )

        wrong_indices = (argmax_pred != target).nonzero(as_tuple=True)[0].tolist()
        for idx in wrong_indices:
            predicted = argmax_pred[idx]
            self.pred_store["test_incorrect"].append(
                (data[idx], target[idx], predicted, pred[idx, int(predicted)])
            )
        return loss

    def find_bestLR_LRFinder(self, optimizer):
        """Run an LR range test on the training dataloader and return its suggestion.

        Plots loss vs. learning rate as a side effect. The model and optimizer
        are always restored to their pre-test state, even if the test raises.

        Args:
            optimizer: Optimizer whose learning rate is swept.

        Returns:
            The suggested learning rate, or ``None`` if no suggestion was produced
            (a warning is emitted in that case).
        """
        lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
        best_lr = None
        try:
            lr_finder.range_test(
                self.trainer.datamodule.train_dataloader(),
                end_lr=config.LRFINDER_END_LR,
                num_iter=config.LRFINDER_NUM_ITERATIONS,
                step_mode=config.LRFINDER_STEP_MODE,
            )
            try:
                plot_result = lr_finder.plot()  # inspect the loss-vs-LR graph
            except Exception as err:  # plotting is best-effort; don't abort training
                warnings.warn(f"LR finder plot failed: {err!r}")
            else:
                # NOTE(review): torch_lr_finder's plot() appears to return
                # (ax, suggested_lr) only when it can suggest an LR, and just the
                # axes otherwise -- confirm against the installed version.
                if isinstance(plot_result, tuple):
                    _, best_lr = plot_result
                else:
                    warnings.warn("LR finder did not produce a learning-rate suggestion")
        finally:
            lr_finder.reset()  # restore model and optimizer to their initial state
        return best_lr

    def configure_optimizers(self):
        """Create Adam + a per-step OneCycleLR schedule.

        The LR range test is always run (for its plot). Its suggestion is used as
        the schedule's peak LR only when ``self.max_lr`` is ``None``.

        Raises:
            ValueError: If ``max_lr`` is ``None`` and the LR finder produced no
                suggestion.
        """
        optimizer = self.get_only_optimizer()
        best_lr = self.find_bestLR_LRFinder(optimizer)
        max_lr = self.max_lr if self.max_lr is not None else best_lr
        if max_lr is None:
            raise ValueError(
                "max_lr is None and the LR finder did not suggest a learning rate"
            )
        scheduler = OneCycleLR(
            optimizer,
            max_lr=max_lr,
            # NOTE(review): assumes one optimizer step per training batch;
            # self.trainer.estimated_stepping_batches would also account for
            # gradient accumulation / multiple devices.
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=5 / config.NUM_EPOCHS,  # warm up over the first 5 epochs
            div_factor=config.OCLR_DIV_FACTOR,
            three_phase=config.OCLR_THREE_PHASE,
            final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
            anneal_strategy=config.OCLR_ANNEAL_STRATEGY,
        )
        return [optimizer], [
            {"scheduler": scheduler, "interval": "step", "frequency": 1}
        ]

    def get_only_optimizer(self):
        """Return an Adam optimizer over all parameters (no scheduler)."""
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer

    def on_test_end(self) -> None:
        """Plot the (optionally row-normalized) test confusion matrix as a heatmap."""
        super().on_test_end()
        # compute() then .cpu() moves only the result, not the metric module itself.
        confmat = self.confusion_matrix.compute().cpu().numpy()
        if config.NORM_CONF_MAT:
            row_sums = confmat.sum(axis=1, keepdims=True)
            # Classes absent from the test set have an all-zero row; leave those
            # at 0 instead of producing NaN from a 0/0 division.
            values = np.divide(
                confmat,
                row_sums,
                out=np.zeros(confmat.shape, dtype=float),
                where=row_sums != 0,
            )
        else:
            values = confmat
        class_names = list(config.CLASSES)
        df_confmat = pd.DataFrame(values, index=class_names, columns=class_names)

        plt.figure(figsize=(7, 5))
        sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        # Lay out after the labels exist so they are not clipped.
        plt.tight_layout()
        plt.show()

    def plot_incorrect_predictions_helper(self, num_imgs=10):
        """Plot up to ``num_imgs`` misclassified test samples collected by test_step."""
        return plot_incorrect_preds(
            self.pred_store["test_incorrect"], config.CLASSES, num_imgs
        )