Spaces: srikanthp07 / S12 (Sleeping)

srikanthp07 committed · Commit 9022436 · 1 Parent(s): a306de4

Upload 27 files ('created app')
- README.md +0 -13
- app.py +156 -0
- examples/bird.png +0 -0
- examples/car.png +0 -0
- examples/cat.png +0 -0
- examples/deer.png +0 -0
- examples/dog.png +0 -0
- examples/frog.png +0 -0
- examples/horse.png +0 -0
- examples/plane.png +0 -0
- examples/ship.png +0 -0
- examples/truck.png +0 -0
- model.py +272 -0
- modelp.ckpt +3 -0
- requirements.txt +6 -0
- utils/__pycache__/dataloader.cpython-310.pyc +0 -0
- utils/__pycache__/dataset.cpython-310.pyc +0 -0
- utils/__pycache__/gradcam.cpython-310.pyc +0 -0
- utils/__pycache__/transforms.cpython-310.pyc +0 -0
- utils/__pycache__/utils.cpython-310.pyc +0 -0
- utils/dataloader.py +13 -0
- utils/dataset.py +26 -0
- utils/find_LR.py +13 -0
- utils/gradcam.py +175 -0
- utils/one_cycle_lr.py +14 -0
- utils/transforms.py +43 -0
- utils/utils.py +125 -0
README.md
CHANGED
@@ -1,13 +0,0 @@
---
title: S12
emoji: 🐠
colorFrom: gray
colorTo: red
sdk: gradio
sdk_version: 3.39.0
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,156 @@
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
from torchvision import transforms
from model import CustomResNet
from utils.utils import wrong_predictions
from utils.dataloader import get_dataloader
import random
from collections import OrderedDict
import os

test_o = get_dataloader()

examples_dir = os.path.join(os.getcwd(), 'examples')
examples = [[os.path.join(examples_dir, img), 0.5] for img in os.listdir(examples_dir)]

model = CustomResNet()
model.load_state_dict(torch.load('modelp.ckpt')['state_dict'])  # , strict=False)
model = model.cpu()

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)
misclassified_images, all_predictions = wrong_predictions(model, test_o, norm_mean, norm_std, classes, 'cpu')

layers = ['layer_1', 'layer_3']
# layers = [model.layer_1, model.layer_2, model.layer_3]

def inference(input_img, transparency, layer_num, top_classes):
    input_img_ori = input_img.copy()
    transform = transforms.ToTensor()
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
    #     mean=[0.485, 0.456, 0.406],
    #     std=[0.229, 0.224, 0.255]
    # )])
    # inverse of the ImageNet normalization (std = 0.229, 0.224, 0.225)
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    input_img = transform(input_img)
    # input_img = input_img.to(device)
    input_img = input_img.unsqueeze(0)
    outputs = model(input_img)
    _, prediction = torch.max(outputs, 1)
    softmax = torch.nn.Softmax(dim=0)
    outputs = softmax(outputs.flatten())
    confidences = {classes[i]: float(outputs[i]) for i in range(10)}
    confidences = OrderedDict(sorted(confidences.items(), key=lambda x: x[1], reverse=True))
    filtered_confidences = {}  # OrderedDict()
    # keep only the top_classes most confident entries
    for i, (key, val) in enumerate(confidences.items()):
        if i == top_classes:
            break
        filtered_confidences[key] = val

    if layer_num == 1:
        target_layers = [model.layer_1]
    elif layer_num == 2:
        target_layers = [model.layer_2]
    else:
        target_layers = [model.layer_3]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    img = input_img.squeeze(0)
    img = inv_normalize(img)
    rgb_img = np.transpose(img, (1, 2, 0))
    rgb_img = np.array(np.clip(rgb_img, 0, 1), np.float32)
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

    return filtered_confidences, visualization

def get_misclassified_images(num):
    outputimgs = []
    for i in range(int(num)):
        inv_normalize = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
            std=[1/0.229, 1/0.224, 1/0.225]
        )
        # pick a random sample from the pool of collected misclassifications
        inv_tensor = np.array(inv_normalize(misclassified_images[random.randint(2, 98)][0]).cpu().permute(1, 2, 0)*255, dtype='uint8')
        outputimgs.append(inv_tensor)
    return outputimgs


def get_gradcam_images(num, transparency, layer_num):
    outcoms = []
    for i in range(int(num)):
        input_img = all_predictions[random.randint(2, 98)][0]
        inv_normalize = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
            std=[1/0.229, 1/0.224, 1/0.225]
        )
        input_img = input_img.unsqueeze(0)
        if layer_num == 1:
            target_layers = [model.layer_1]
        elif layer_num == 2:
            target_layers = [model.layer_2]
        else:
            target_layers = [model.layer_3]

        cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
        grayscale_cam = cam(input_tensor=input_img, targets=None)
        grayscale_cam = grayscale_cam[0, :]
        img = input_img.squeeze(0)
        img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0))
        rgb_img = np.array(np.clip(rgb_img, 0, 1), np.float32)
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
        outcoms.append(visualization)
    return outcoms

inference_new_image = gr.Interface(
    inference,
    inputs=[gr.Image(shape=(32, 32), label="Input Image"),
            gr.Slider(0, 1, value=0.3, label="transparency?"),
            gr.Slider(1, 3, value=1, step=1, label="layer?"),
            gr.Slider(1, 10, value=3, step=1, label="top classes?")],
    outputs=[gr.Label(), gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300)],
    title='gradio app',
    description='for dl purposes',
    examples=examples,
)

misclassified_interface = gr.Interface(
    get_misclassified_images,
    inputs=[gr.Number(value=10, label="images number")],
    outputs=[gr.Gallery(label="misclassified images")],
    title='gradio app',
    description='for dl purposes'
)

gradcam_images = gr.Interface(
    get_gradcam_images,
    inputs=[gr.Number(value=10, label="images number"),
            gr.Slider(0, 1, value=0.3, label="transparency?"),
            gr.Slider(1, 3, value=1, step=1, label="layer?")],
    outputs=[gr.Gallery(label="gradcam images")],
    title='gradio app',
    description='for dl purposes'
)


demo = gr.TabbedInterface([inference_new_image, misclassified_interface, gradcam_images],
                          tab_names=["Input image", "Misclassified Images", "grad cam images"],
                          title="customresnet gradcam")
demo.launch()
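The app logic can also be sanity-checked without launching the Gradio UI by calling inference directly. A minimal smoke-test sketch (not part of this commit), assuming examples/ and modelp.ckpt sit next to app.py; note that importing app first runs wrong_predictions over the CIFAR-10 test loader, so the dataset download happens on import:

# smoke_test.py — illustrative only
import numpy as np
from PIL import Image

from app import inference

img = np.array(Image.open('examples/cat.png').convert('RGB').resize((32, 32)))
confidences, cam_overlay = inference(img, transparency=0.5, layer_num=3, top_classes=3)
print(confidences)        # top-3 class probabilities, highest first
print(cam_overlay.shape)  # (32, 32, 3) Grad-CAM overlay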
examples/bird.png
ADDED
examples/car.png
ADDED
examples/cat.png
ADDED
examples/deer.png
ADDED
examples/dog.png
ADDED
examples/frog.png
ADDED
examples/horse.png
ADDED
examples/plane.png
ADDED
examples/ship.png
ADDED
examples/truck.png
ADDED
model.py
ADDED
@@ -0,0 +1,272 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import os

from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10

# from utils.dataloader import get_dataloader
from utils.dataset import get_dataset

import matplotlib.pyplot as plt

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64


# transforms with albumentations
# find_lr coupled with one_cycle lr


class BasicBlock(LightningModule):

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class CustomBlock(LightningModule):
    def __init__(self, in_channels, out_channels):
        super(CustomBlock, self).__init__()

        self.inner_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        self.res_block = BasicBlock(out_channels, out_channels)

    def forward(self, x):
        x = self.inner_layer(x)
        r = self.res_block(x)

        out = x + r

        return out


class CustomResNet(LightningModule):
    def __init__(self, num_classes=10, data_dir=PATH_DATASETS, hidden_size=16, lr=2e-4):
        super(CustomResNet, self).__init__()

        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = lr

        self.train_losses = []
        self.test_losses = []
        self.train_acc = []
        self.test_acc = []

        self.lr_change = []
        # self.outputs=[]
        self.train_step_losses = []
        self.train_step_acc = []

        self.val_step_losses = []
        self.val_step_acc = []

        test_incorrect_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}

        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.prep_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.layer_1 = CustomBlock(in_channels=64, out_channels=128)

        self.layer_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        self.layer_3 = CustomBlock(in_channels=256, out_channels=512)

        self.max_pool = nn.Sequential(nn.MaxPool2d(kernel_size=4))

        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prep_layer(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.max_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).cpu().float().mean()

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        self.train_step_acc.append(acc)
        self.train_step_losses.append(loss.cpu().item())

        return {'loss': loss, 'train_acc': acc}

    def on_train_epoch_end(self):
        # batch_losses = [x["train_loss"] for x in outputs]  # This part
        epoch_loss = sum(self.train_step_losses) / len(self.train_step_losses)
        # batch_accs = [x["train_acc"] for x in outputs]  # This part
        epoch_acc = sum(self.train_step_acc) / len(self.train_step_acc)
        self.log("train_loss_epoch", epoch_loss, prog_bar=True)
        self.log("train_acc_epoch", epoch_acc, prog_bar=True)
        self.train_acc.append(epoch_acc)
        self.train_losses.append(epoch_loss)
        self.lr_change.append(self.scheduler.get_last_lr()[0])
        self.train_step_losses.clear()
        self.train_step_acc.clear()
        return epoch_acc

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).cpu().float().mean()

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.val_step_acc.append(acc)
        self.val_step_losses.append(loss.cpu().item())
        return {'val_loss': loss, 'val_acc': acc}

    def on_validation_epoch_end(self):
        # batch_losses = [x["val_loss"] for x in outputs]  # This part
        epoch_loss = sum(self.val_step_losses) / len(self.val_step_losses)
        # batch_accs = [x["val_acc"] for x in outputs]  # This part
        epoch_acc = sum(self.val_step_acc) / len(self.val_step_acc)
        self.log("val_loss_epoch", epoch_loss, prog_bar=True)
        self.log("val_acc_epoch", epoch_acc, prog_bar=True)
        self.test_acc.append(epoch_acc)
        self.test_losses.append(epoch_loss)
        self.val_step_losses.clear()
        self.val_step_acc.clear()
        return epoch_acc

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        self.optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9, weight_decay=5e-4)

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.learning_rate, epochs=30, steps_per_epoch=len(self.cifar_full)//BATCH_SIZE)
        lr_scheduler = {'scheduler': self.scheduler, 'interval': 'step'}
        return {'optimizer': self.optimizer, 'lr_scheduler': lr_scheduler}

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            # cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_full = get_dataset()[0]
            self.cifar_train, self.cifar_val = random_split(self.cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            # self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
            self.cifar_test = get_dataset()[1]

    def train_dataloader(self):
        cifar_full = get_dataset()[0]
        return DataLoader(cifar_full, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
        # return get_dataloader()[0]


    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
        # return get_dataloader()[1]

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
        # return get_dataloader()[1]

    def draw_graphs(self):
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        axs[0, 0].plot(self.train_losses)
        axs[0, 0].set_title("Training Loss")
        axs[1, 0].plot(self.train_acc)
        axs[1, 0].set_title("Training Accuracy")
        axs[0, 1].plot(self.test_losses)
        axs[0, 1].set_title("Test Loss")
        axs[1, 1].plot(self.test_acc)
        axs[1, 1].set_title("Test Accuracy")

    def draw_graphs_lr(self):
        # fig, axs = plt.subplots(1,1,figsize=(15,10))
        plt.plot(self.lr_change)
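Since the module bundles its own data hooks, optimizer, and OneCycleLR schedule, training reduces to handing it to a Lightning Trainer. A minimal sketch (not part of this commit, argument names per recent pytorch-lightning releases), assuming get_dataset() is restored to return a (train, test) pair as setup() expects:

# train.py — illustrative only
from pytorch_lightning import Trainer
from model import CustomResNet

model = CustomResNet(lr=2e-4)
trainer = Trainer(max_epochs=30, accelerator="auto", devices=1)  # 30 matches the OneCycleLR epochs
trainer.fit(model)                       # uses the LightningModule's own dataloaders
trainer.save_checkpoint("modelp.ckpt")   # Lightning checkpoints carry the 'state_dict' key app.py loads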
modelp.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:362e063f506824562fe59f8732ac5ae0db714cff4829400c05718fbac0f68b3e
size 52634750
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch
torchvision
torch-lr-finder
grad-cam
pillow
numpy
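Note that the committed code imports several packages not pinned here: gradio (app.py; provided by the Space SDK but needed locally), pytorch-lightning and torchmetrics (model.py), albumentations and opencv (utils/transforms.py, utils/gradcam.py), and matplotlib. A fuller list for a local run might look like:

torch
torchvision
torch-lr-finder
grad-cam
pillow
numpy
gradio
pytorch-lightning
torchmetrics
albumentations
opencv-python
matplotlib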
utils/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (499 Bytes)
utils/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (1.18 kB)
utils/__pycache__/gradcam.cpython-310.pyc
ADDED
Binary file (5.72 kB)
utils/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (1.02 kB)
utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.19 kB)
utils/dataloader.py
ADDED
@@ -0,0 +1,13 @@
import torch
from utils.dataset import get_dataset

batch_size = 10

kwargs = {'batch_size': batch_size, 'shuffle': True, 'num_workers': 2, 'pin_memory': True}

def get_dataloader():
    test_data = get_dataset()
    test_loader = torch.utils.data.DataLoader(test_data, **kwargs)
    # train_loader = torch.utils.data.DataLoader(train_data, **kwargs)

    return test_loader
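For reference, the loader as configured yields shuffled, normalized CIFAR-10 test batches of 10; a minimal check:

from utils.dataloader import get_dataloader

test_loader = get_dataloader()
images, labels = next(iter(test_loader))
print(images.shape, labels.shape)  # torch.Size([10, 3, 32, 32]) torch.Size([10])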
utils/dataset.py
ADDED
@@ -0,0 +1,26 @@
import torch
from torchvision import datasets, transforms

from .transforms import test_transforms, train_transforms


class Cifar10SearchDataset(datasets.CIFAR10):
    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            # albumentations pipelines take the image as a keyword argument
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label

def get_dataset():
    # train_data = Cifar10SearchDataset(
    #     root='./data/cifar10', train=True, download=True, transform=train_transforms)
    # NOTE: model.py indexes get_dataset()[0] / get_dataset()[1], which assumes a
    # (train, test) tuple; as written this returns only the test set, the
    # single-dataset form that utils/dataloader.py relies on.
    test_data = Cifar10SearchDataset(
        root='./data/cifar10', train=False, download=True, transform=test_transforms)

    return test_data
utils/find_LR.py
ADDED
@@ -0,0 +1,13 @@
from torch_lr_finder import LRFinder

def find_lr(model, optimizer, criterion, device, train_loader):
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    lr_finder.range_test(
        train_loader,
        step_mode="exp",
        end_lr=10,
        num_iter=200,
    )
    mx_lr = lr_finder.plot(suggest_lr=True, skip_start=0, skip_end=0)
    lr_finder.reset()
    return mx_lr
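A sketch of wiring find_lr to the pieces in this repo (not part of the commit; the SGD settings and cross-entropy loss mirror model.py's configure_optimizers, and the repo only exposes a test loader, so it stands in for the train loader here):

import torch
import torch.nn as nn

from model import CustomResNet
from utils.dataloader import get_dataloader
from utils.find_LR import find_lr

model = CustomResNet()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
# recent torch-lr-finder versions return an (axes, suggested_lr) pair from plot(suggest_lr=True)
mx_lr = find_lr(model, optimizer, criterion, 'cpu', get_dataloader())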
utils/gradcam.py
ADDED
@@ -0,0 +1,175 @@
from torch.nn import functional as F
import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np

def denormalize(img):
    mean = (0.49139968, 0.48215841, 0.44653091)
    std = (0.24703223, 0.24348513, 0.26158784)
    img = img.cpu().numpy().astype(dtype=np.float32)

    for i in range(img.shape[0]):
        img[i] = (img[i]*std[i])+mean[i]

    return np.transpose(img, (1, 2, 0))

class GradCAM:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers
    target_layers = list of convolution layer index as shown in summary
    """
    def __init__(self, model, candidate_layers=None):
        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # a set of hook function handlers
        self.fmap_pool = {}
        self.grad_pool = {}
        self.candidate_layers = candidate_layers  # list

        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                # register_backward_hook is deprecated in newer PyTorch in favor of register_full_backward_hook
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _encode_one_hot(self, ids):
        one_hot = torch.zeros_like(self.nll).to(self.device)
        print(one_hot.shape)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        self.image_shape = image.shape[2:]  # HxW
        self.nll = self.model(image)
        # self.probs = F.softmax(self.logits, dim=1)
        return self.nll.sort(dim=1, descending=True)  # ordered results

    def backward(self, ids):
        """
        Class-specific backpropagation
        """
        one_hot = self._encode_one_hot(ids)
        self.model.zero_grad()
        self.nll.backward(gradient=one_hot, retain_graph=True)

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()

    def _find(self, pool, target_layer):
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        weights = F.adaptive_avg_pool2d(grads, 1)

        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        # image size is captured during the forward pass (self.image_shape)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        # scale output between 0,1
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)

        return gcam

def generate_gradcam(misclassified_images, model, target_layers, device):
    images = []
    labels = []
    for i, (img, pred, correct) in enumerate(misclassified_images):
        images.append(img)
        labels.append(correct)

    model.eval()

    # map input to device
    images = torch.stack(images).to(device)

    # set up grad cam
    gcam = GradCAM(model, target_layers)

    # forward pass
    probs, ids = gcam.forward(images)

    # outputs against which to compute gradients
    ids_ = torch.LongTensor(labels).view(len(images), -1).to(device)

    # backward pass
    gcam.backward(ids=ids_)
    layers = []
    for i in range(len(target_layers)):
        target_layer = target_layers[i]
        print("Generating Grad-CAM @{}".format(target_layer))
        # Grad-CAM
        layers.append(gcam.generate(target_layer=target_layer))

    # remove hooks when done
    gcam.remove_hook()
    return layers, probs, ids

def plot_gradcam_images(gcam_layers, target_layers, classes, image_size, predicted, misclassified_images):

    images = []
    labels = []
    for i, (img, pred, correct) in enumerate(misclassified_images):
        images.append(img)
        labels.append(correct)

    c = len(images) + 1
    r = len(target_layers) + 2
    fig = plt.figure(figsize=(60, 30))
    fig.subplots_adjust(hspace=0.01, wspace=0.01)
    ax = plt.subplot(r, c, 1)
    ax.text(0.3, -0.5, "INPUT", fontsize=28)
    plt.axis('off')
    for i in range(len(target_layers)):
        target_layer = target_layers[i]
        ax = plt.subplot(r, c, c*(i+1)+1)
        ax.text(0.3, -0.5, target_layer, fontsize=28)
        plt.axis('off')

        for j in range(len(images)):
            img = np.uint8(255 * denormalize(images[j].view(image_size)))
            if i == 0:
                ax = plt.subplot(r, c, j+2)
                ax.text(0, 0.2, f"actual: {classes[labels[j]]} \npred: {classes[predicted[j][0]]}", fontsize=18)
                plt.axis('off')
                plt.subplot(r, c, c+j+2)
                plt.imshow(img)
                plt.axis('off')


            heatmap = 1 - gcam_layers[i][j].cpu().numpy()[0]  # reverse the color map
            heatmap = np.uint8(255 * heatmap)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            superimposed_img = cv2.resize(cv2.addWeighted(img, 0.5, heatmap, 0.5, 0), (128, 128))
            plt.subplot(r, c, (i+2)*c+j+2)
            plt.imshow(superimposed_img, interpolation='bilinear')

            plt.axis('off')
    plt.show()
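This hook-based implementation is not exercised by app.py (which uses the pytorch_grad_cam package instead), so the following is illustrative only: a sketch of feeding it the output of utils.utils.wrong_predictions, with layer names matching the modules registered on CustomResNet:

import torch
from model import CustomResNet
from utils.dataloader import get_dataloader
from utils.gradcam import generate_gradcam
from utils.utils import wrong_predictions

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
model = CustomResNet()
model.load_state_dict(torch.load('modelp.ckpt')['state_dict'])

wrong, _ = wrong_predictions(model, get_dataloader(), (0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010), classes, 'cpu')
# layer names must be module names registered on CustomResNet
gcam_layers, probs, ids = generate_gradcam(wrong[:5], model, ['layer_1', 'layer_3'], 'cpu')
print(gcam_layers[0].shape)  # (5, 1, 32, 32): one heat map per sample for the first layer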
utils/one_cycle_lr.py
ADDED
@@ -0,0 +1,14 @@
import torch.optim as optim


def get_onecycle_scheduler(optimizer, mx_lr, train_loader, num_epochs):
    return optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=mx_lr,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=5/num_epochs,
        div_factor=100,
        three_phase=False,
        final_div_factor=100,
        anneal_strategy='linear')
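A sketch of how this pairs with the rest of the repo in a manual training loop (illustrative only; mx_lr=0.01 is a placeholder for a find_lr suggestion, and the test loader stands in for a train loader since the repo exposes no other):

import torch
import torch.nn as nn

from model import CustomResNet
from utils.dataloader import get_dataloader
from utils.one_cycle_lr import get_onecycle_scheduler

model = CustomResNet()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
loader = get_dataloader()
scheduler = get_onecycle_scheduler(optimizer, mx_lr=0.01, train_loader=loader, num_epochs=30)

for epoch in range(30):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR steps once per batch (steps_per_epoch=len(loader))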
utils/transforms.py
ADDED
@@ -0,0 +1,43 @@
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
# import torchvision.transforms as transforms

norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

train_transforms = A.Compose(
    [
        A.Sequential([
            A.PadIfNeeded(
                min_height=40,
                min_width=40,
                border_mode=cv2.BORDER_CONSTANT,
                value=(norm_mean[0]*255, norm_mean[1]*255, norm_mean[2]*255)
            ),
            A.RandomCrop(
                height=32,
                width=32
            )
        ], p=1),
        A.CoarseDropout(
            max_holes=2,
            max_height=16,
            max_width=16,
            min_holes=1,
            min_height=8,
            min_width=8,
            fill_value=tuple((x * 255.0 for x in norm_mean)),
            p=0.8,
        ),
        A.Normalize(norm_mean, norm_std),
        ToTensorV2()
    ]
)

test_transforms = A.Compose(
    [
        A.Normalize(norm_mean, norm_std, always_apply=True),
        ToTensorV2()
    ]
)
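Albumentations pipelines take the image as a keyword argument and operate on HWC uint8 numpy arrays, which is why Cifar10SearchDataset calls self.transform(image=image). A minimal check with a stand-in image:

import numpy as np
from utils.transforms import train_transforms

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # stand-in CIFAR-sized image
out = train_transforms(image=img)["image"]
print(out.shape, out.dtype)  # torch.Size([3, 32, 32]) torch.float32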
utils/utils.py
ADDED
@@ -0,0 +1,125 @@
import torch
# import torch.nn as nn
import torch.nn.functional as F
# import torch.optim as optim
import numpy as np

# from tqdm import tqdm

import matplotlib.pyplot as plt


def return_dataset_images(train_loader, total_images):
    batch_data, batch_label = next(iter(train_loader))

    fig = plt.figure()

    for i in range(total_images):
        plt.subplot(3, 4, i+1)
        plt.tight_layout()
        # plt.imshow(batch_data[i].squeeze(0), cmap='gray')
        plt.imshow(batch_data[i].permute(1, 2, 0), cmap='gray')
        plt.title(batch_label[i].item())
        plt.xticks([])
        plt.yticks([])

def GetCorrectPredCount(pPrediction, pLabels):
    return pPrediction.argmax(dim=1).eq(pLabels).sum().item()


def get_incorrect_predictions(model, loader, device):
    """Get all incorrect predictions

    Args:
        model (Net): Trained model
        loader (DataLoader): instance of data loader
        device (str): Which device to use cuda/cpu

    Returns:
        list: list of all incorrect predictions and their corresponding details
    """
    model.eval()
    incorrect = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            pred = output.argmax(dim=1)
            for d, t, p, o in zip(data, target, pred, output):
                if p.eq(t.view_as(p)).item() == False:
                    incorrect.append(
                        [d.cpu(), t.cpu(), p.cpu(), o[p.item()].cpu()])

    return incorrect

def plot_incorrect_predictions(predictions, class_map, count=10):
    """Plot incorrect predictions

    Args:
        predictions (list): List of all incorrect predictions
        class_map (dict): Label mapping
        count (int, optional): Number of samples to print, multiple of 5. Defaults to 10.
    """
    print(f'Total Incorrect Predictions {len(predictions)}')

    if not count % 5 == 0:
        print("Count should be a multiple of 5")
        return

    classes = list(class_map.values())

    fig = plt.figure(figsize=(10, 5))
    for i, (d, t, p, o) in enumerate(predictions):
        ax = fig.add_subplot(int(count/5), 5, i + 1, xticks=[], yticks=[])
        ax.set_title(f'{classes[t.item()]}/{classes[p.item()]}')
        plt.imshow(d.cpu().numpy().transpose(1, 2, 0))
        if i+1 == 5*(count/5):
            break

def wrong_predictions(model, test_loader, norm_mean, norm_std, classes, device):
    wrong_images = []
    wrong_label = []
    correct_label = []

    correct_images = []
    correct_images_labels = []

    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True).squeeze()  # get the index of the max log-probability

            wrong_pred = (pred.eq(target.view_as(pred)) == False)
            wrong_images.append(data[wrong_pred])
            wrong_label.append(pred[wrong_pred])
            correct_label.append(target.view_as(pred)[wrong_pred])

            # wrong_pred = (pred.eq(target.view_as(pred)) == False)
            correct_images.append(data)
            correct_images_labels.append(pred)

            # stop collecting once more than 100 wrong predictions have accumulated
            wrong_predictions = list(zip(torch.cat(wrong_images), torch.cat(wrong_label), torch.cat(correct_label)))
            all_predictions = list(zip(torch.cat(correct_images), torch.cat(correct_images_labels), torch.cat(correct_images_labels)))
            if len(wrong_predictions) > 100:
                break
    print(f'Total wrong predictions are {len(wrong_predictions)}')

    # plot_misclassified(wrong_predictions, norm_mean, norm_std, classes)

    return wrong_predictions, all_predictions

def plot_misclassified(wrong_predictions, norm_mean, norm_std, classes):
    fig = plt.figure(figsize=(10, 12))
    fig.tight_layout()
    for i, (img, pred, correct) in enumerate(wrong_predictions[:20]):
        img, pred, target = img.cpu().numpy().astype(dtype=np.float32), pred.cpu(), correct.cpu()
        for j in range(img.shape[0]):
            img[j] = (img[j]*norm_std[j])+norm_mean[j]

        img = np.transpose(img, (1, 2, 0))  # / 2 + 0.5
        ax = fig.add_subplot(5, 5, i+1)
        ax.axis('off')
        ax.set_title(f'\nactual : {classes[target.item()]}\npredicted : {classes[pred.item()]}', fontsize=10)
        ax.imshow(img)

    plt.show()