|
|
|
|
|
import os, sys |
|
import math |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torchvision |
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WideBasic(nn.Module):
    """Pre-activation residual block used as the building unit of the network.

    The residual branch is BatchNorm -> ReLU -> 3x3 conv -> BatchNorm -> ReLU
    -> Dropout -> 3x3 conv. Its output is added to a shortcut branch, which is
    empty (identity) unless the channel count or the stride changes, in which
    case a 1x1 convolution brings the input to the matching shape.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both branches.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # The position of each layer inside the Sequential becomes part of its
        # state_dict key ("residual.0", "residual.2", ...), so the order below
        # must stay fixed for previously saved checkpoints to keep loading.
        residual_layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(),  # uses the library's default drop probability
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.residual = nn.Sequential(*residual_layers)

        # An empty Sequential passes its input through unchanged.
        shapes_match = in_channels == out_channels and stride == 1
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            )

    def forward(self, x):
        """Return the sum of the residual branch and the shortcut branch."""
        return self.residual(x) + self.shortcut(x)
|
|
|
class WideResNet(nn.Module):
    """Wide residual network for 3-channel images.

    The network is a 3x3 stem convolution with 16 output channels, followed by
    three stages of ``block`` modules with widths ``16 * widen_factor``,
    ``32 * widen_factor`` and ``64 * widen_factor`` (the second and third
    stage start with stride 2), then BatchNorm, ReLU, global average pooling
    and a linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
        block: Callable invoked as ``block(in_channels, out_channels, stride)``
            that returns one residual block module.
        depth: Controls the number of blocks per stage, computed as
            ``int((depth - 4) / 6)``.
        widen_factor: Multiplier applied to the base stage widths.
    """

    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        super().__init__()

        self.depth = depth
        blocks_per_stage = int((depth - 4) / 6)
        stage_widths = (16 * widen_factor, 32 * widen_factor, 64 * widen_factor)

        # ``in_channels`` tracks the width flowing into the next block and is
        # advanced by ``_make_layer`` as stages are built.
        self.in_channels = 16
        self.init_conv = nn.Conv2d(
            3, self.in_channels, kernel_size=3, stride=1, padding=1
        )

        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv2 = self._make_layer(block, stage_widths[0], blocks_per_stage, 1)
        self.conv3 = self._make_layer(block, stage_widths[1], blocks_per_stage, 2)
        self.conv4 = self._make_layer(block, stage_widths[2], blocks_per_stage, 2)

        self.bn = nn.BatchNorm2d(stage_widths[2])
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(stage_widths[2], num_classes)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the linear layer."""
        out = self.init_conv(x)
        for stage in (self.conv2, self.conv3, self.conv4):
            out = stage(out)

        out = self.relu(self.bn(out))
        out = self.avg_pool(out)

        # Collapse everything but the leading dimension before the classifier.
        out = out.view(out.size(0), -1)
        return self.linear(out)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage as a Sequential of ``block`` modules.

        The first block maps ``self.in_channels`` to ``out_channels`` with the
        given ``stride``; the remaining ``num_blocks - 1`` blocks keep
        ``out_channels`` and use stride 1. At least one block is always
        created. ``self.in_channels`` is updated to ``out_channels``.
        """
        stage_blocks = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        stage_blocks.extend(
            block(out_channels, out_channels, 1) for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*stage_blocks)
|
|
|
|
|
# Build the depth-40, widen-factor-10 network with 10 outputs and restore its
# trained parameters onto the CPU.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
model = WideResNet(num_classes=10, block=WideBasic, depth=40, widen_factor=10)
model.load_state_dict(
    torch.load(
        "weights/cifar10_wide_resnet_model.pt",
        map_location=torch.device("cpu"),
    )
)

# Inference only: put BatchNorm/Dropout layers into evaluation mode.
model.eval()
|
|
|
import gradio as gr |
|
from torchvision import transforms |
|
|
|
import os |
|
import glob |
|
|
|
# Folder with the demo images, and every PNG file currently found in it.
# ``example_files`` is not referenced by the rest of the visible code.
examples_dir = './examples'
example_files = glob.glob(
    os.path.join(examples_dir, '*.png')
)
|
|
|
# Per-channel normalization applied after conversion to a tensor.
# NOTE(review): these look like the CIFAR-10 training-set channel statistics;
# they should match whatever was used when the checkpoint was trained.
normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
)

# Preprocessing pipeline used by the predictor: image -> tensor -> normalized.
transform = transforms.Compose([transforms.ToTensor(), normalize])
|
|
|
# Class names in model-output order: index i labels the i-th model output.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
|
|
|
def predict(image):
    """Classify one image and return a confidence per class name.

    Args:
        image: The image handed over by the UI (whatever the module-level
            ``transform`` pipeline accepts), or ``None`` when no image has
            been provided.

    Returns:
        A dict mapping each entry of ``classes`` to the softmax probability of
        the corresponding model output, or ``None`` when ``image`` is ``None``
        so that the output component is simply cleared instead of the request
        failing inside the preprocessing step.
    """
    # The button can be pressed before any image is loaded; in that case the
    # callback receives None and there is nothing to classify.
    if image is None:
        return None

    # Add a leading batch dimension: the model expects a batch of images.
    batch = transform(image).unsqueeze(dim=0)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
        prob = torch.nn.functional.softmax(logits[0], dim=0)

    # Pair names with probabilities directly rather than assuming a fixed
    # count of 10 outputs.
    return dict(zip(classes, prob.tolist()))
|
|
|
|
|
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; confirm which components belong inside each Row.
with gr.Blocks(
    css=".gradio-container {background:honeydew;}",
    title="WideResNet - CIFAR10 Classification",
) as demo:
    # Page heading.
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">WideResNet - CIFAR10 Classification</div>""")

    # Input image next to the label showing the three most likely classes.
    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="RGB", shape=(32, 32))
        output_label = gr.Label(label="Probabilities", num_top_classes=3)

    send_btn = gr.Button("Infer")

    # One clickable example per row entry; each fills the input image.
    with gr.Row():
        gr.Examples(
            examples=['./examples/cifar10_test00.png'],
            inputs=input_image,
            label='dog',
        )
        gr.Examples(
            examples=['./examples/cifar10_test01.png'],
            inputs=input_image,
            label='ship',
        )

    # Run the classifier on the current image when the button is pressed.
    send_btn.click(fn=predict, inputs=input_image, outputs=output_label)


demo.launch()
|
|
|
|
|
|