import warnings
warnings.filterwarnings('ignore')
import torch, numpy as np, os
from torch import nn
from transformers import AutoModelForImageClassification, AutoConfig, AutoImageProcessor
import matplotlib.pyplot as plt
from PIL import Image
import saliency.core as saliency
import io
import gradio as gr
import PIL
# default model (index into model_names)
model_choice = 3
model_names = ["nvidia/mit-b0", "facebook/convnext-base-224", "microsoft/resnet-18", "microsoft/swin-tiny-patch4-window7-224"]
model_name = model_names[model_choice]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Model(nn.Module):
    """Thin wrapper that exposes the classifier's logits and label mappings."""
    def __init__(self, MODEL_NAME=model_name):
        super().__init__()
        self.config = AutoConfig.from_pretrained(MODEL_NAME, finetuning_task="image-classification")
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
        self.class_len = self.config.num_labels
        self.id2label = self.config.id2label
        self.label2id = self.config.label2id

    def forward(self, x):
        # accept numpy arrays, single images, and channels-last batches
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        if x.shape[-1] == 3:
            # NHWC -> NCHW
            x = x.permute(0, 3, 1, 2)
        x = x.to(device)
        x = self.model(x)
        return x.logits
def conv_layer_forward_hook(module, input, output):
    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
    global last_conv_layer_outputs
    last_conv_layer_outputs[saliency.base.CONVOLUTION_LAYER_VALUES] = torch.movedim(output, 3, 1).detach().cpu().numpy()

def conv_layer_backward_hook(module, grad_input, grad_output):
    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
    global last_conv_layer_outputs
    last_conv_layer_outputs[saliency.base.CONVOLUTION_OUTPUT_GRADIENTS] = torch.movedim(grad_output[0], 3, 1).detach().cpu().numpy()
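# The two hooks above populate last_conv_layer_outputs with the activations and
# gradients of the last Conv2d layer, keyed the way the saliency library's
# GradCam expects (CONVOLUTION_LAYER_VALUES / CONVOLUTION_OUTPUT_GRADIENTS).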
auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs = None, None, None, None, None
def swap_models(name):
    """Load a new checkpoint and re-register the Grad-CAM hooks on its last Conv2d layer."""
    global model, auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs
    auto_transformer = AutoImageProcessor.from_pretrained(name)
    model = Model(MODEL_NAME=name)
    model = model.to(device).eval()
    # register the hooks for the last convolution layer for Grad-CAM
    named_modules = dict(model.model.named_modules())
    last_conv_layer_name = None
    for module_name, module in named_modules.items():
        if isinstance(module, torch.nn.Conv2d):
            last_conv_layer_name = module_name
    last_conv_layer = named_modules[last_conv_layer_name]
    last_conv_layer_outputs = {}
    last_conv_layer.register_forward_hook(conv_layer_forward_hook)
    last_conv_layer.register_backward_hook(conv_layer_backward_hook)
    class_to_id = {v: k for k, v in model.model.config.id2label.items()}
    id_to_class = {k: v for k, v in model.model.config.id2label.items()}
swap_models(model_name)
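# saliency_graph preprocesses the image, computes a SmoothGrad-smoothed
# Integrated Gradients mask for the top predicted class, and renders a top-5
# probability bar chart; both are returned as images for the Gradio UI.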
def saliency_graph(img1, steps=25):
    # preprocess with the model's image processor and rescale to [0, 1]
    img1 = auto_transformer(img1)
    img1 = np.squeeze(np.array(img1.pixel_values))
    if img1.shape[0] < img1.shape[1]:
        # channels-first -> channels-last
        img1 = np.moveaxis(img1, 0, -1)
    img1 = (img1 - np.min(img1)) / (np.max(img1) - np.min(img1))
    class_idx_str = 'class_idx_str'

    def gradcam_call(images, call_model_args=None, expected_keys=None):
        if not isinstance(images, (np.ndarray, torch.Tensor, PIL.Image.Image)):
            # return two blank images
            im1 = np.zeros((224, 224, 3))
            im2 = np.zeros((224, 224, 3))
            return im1, im2
        if len(images.shape) == 3:
            images = np.expand_dims(images, 0)
        images = torch.tensor(images, dtype=torch.float32)
        images = images.requires_grad_(True)
        target_class_idx = call_model_args[class_idx_str]
        y_pred = model(images)
        if saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys:
            out = y_pred[:, target_class_idx]
            # move actual color channel to the 1st dimension
            # images = torch.movedim(images, 3, 1)
            grads = torch.autograd.grad(out, images, grad_outputs=torch.ones_like(out))
            grads = grads[0].detach().cpu().numpy()
            return {saliency.base.INPUT_OUTPUT_GRADIENTS: grads}
        else:
            # Grad-CAM path: backprop a one-hot gradient so the hooks capture
            # the last conv layer's activations and gradients
            hot = torch.zeros_like(y_pred)
            hot[:, target_class_idx] = 1
            model.zero_grad()
            y_pred.backward(gradient=hot, retain_graph=True)
            return last_conv_layer_outputs

    im = img1.astype(np.float32)
    base = np.zeros(img1.shape)
    pred = model(torch.from_numpy(im))
    class_pred = pred.argmax(dim=1).item()
    call_model_args = {class_idx_str: class_pred}
    gradients = saliency.IntegratedGradients()
    s = gradients.GetSmoothedMask(im, gradcam_call, call_model_args, x_steps=steps, x_baseline=base, batch_size=25)
    smoothgrad_mask_grayscale = saliency.VisualizeImageGrayscale(s)
    with torch.no_grad():
        output = model.forward(img1)
        output = torch.nn.functional.softmax(output, dim=1)
        output = output.cpu().numpy()
    top_5 = [(id_to_class[int(i)], output[0][i]) for i in np.argsort(output)[0][-5:][::-1]]
    # Render the top-5 predictions as a horizontal bar chart.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.barh([x[0] for x in top_5], [x[1] for x in top_5])
    ax.set_title('Top 5 Predictions')
    buf = io.BytesIO()
    fig.savefig(buf, format='jpg')
    buf.seek(0)
    fig_img = Image.open(buf)
    plt.close(fig)
    return smoothgrad_mask_grayscale, fig_img
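# Example usage outside the UI (assumes a hypothetical 'example.jpg' next to this script):
#   mask, chart = saliency_graph(Image.open('example.jpg'), steps=20)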
# Gradio interface
def gradio_interface(img):
    smoothgrad_mask_grayscale, fig_img = saliency_graph(img, steps=20)
    return smoothgrad_mask_grayscale, fig_img
with gr.Blocks() as iface:
    #examples = gr.Examples(examples=["ex1.jpg", "ex2.jpg", "ex3.jpg", "ex4.jpg"], label="Examples", inputs="image", examples_per_page=4)
    gr.Markdown("This demo highlights the pixels that matter most to a model's class prediction. A well-generalizing model should base its prediction on the expected object, whereas a poorly generalizing model often relies on environmental cues and ignores the most informative pixels. Visualizing these pixels helps explain, and build trust in, whether a given model uses the right features for its predictions.")
    with gr.Row():
        with gr.Column():
            test_image = gr.Image(label="Input Image")
            input_btn = gr.Button("Classify image")
            model_select_dropdown = gr.Radio(model_names, label="Model to test", interactive=True)
        with gr.Column():
            output = gr.Image(label="Pixels used for classification")
            output2 = gr.Image(label="Top 5 Predictions")
    input_btn.click(gradio_interface, test_image, outputs=[output, output2])
    model_select_dropdown.change(swap_models, inputs=[model_select_dropdown])
    examples = gr.Examples(
        examples=[os.path.join('./', x) for x in os.listdir('./') if x.endswith('.jpg')],
        inputs=test_image,  # point the examples at the visible input component
        label="Examples",
        fn=gradio_interface,
        cache_examples=True,
        run_on_click=True,
        postprocess=True,
        preprocess=True,
        outputs=[output, output2])

iface.launch()