ClassCat's picture
update app.py
f3f78e5
raw
history blame
5.1 kB
# common
import os, sys
import math
#import numpy as np
#from random import randrange
# torch
import torch
from torch import nn
#from torch import einsum
import torch.nn.functional as F
#from torch import optim
#from torch.optim import lr_scheduler
#from torch.utils.data import DataLoader
#from torch.utils.data.sampler import SubsetRandomSampler
# torchVision
import torchvision
from torchvision import transforms
#from torchvision import models
#from torchvision.datasets import CIFAR10, CIFAR100
# torchinfo
#from torchinfo import summary
# Define model
class WideBasic(nn.Module):
    """Pre-activation wide residual block.

    residual: BN -> ReLU -> Conv3x3(stride) -> BN -> ReLU -> Dropout -> Conv3x3
    shortcut: identity when the shapes already match, otherwise a 1x1
    convolution projection with the same stride.
    The block returns ``residual(x) + shortcut(x)``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut.
        dropout: drop probability of the dropout layer between the two
            convolutions. The default 0.5 equals the ``nn.Dropout()`` default
            the block always used, so existing behavior is unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, dropout=0.5):
        super().__init__()
        # Submodule names and Sequential indices determine the state_dict keys
        # (e.g. "residual.2.weight", "shortcut.0.bias"). Keep them stable so
        # saved checkpoints keep loading.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            # In-place is safe: it overwrites the BatchNorm output, not the
            # block input that the shortcut still needs.
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        )
        if in_channels != out_channels or stride != 1:
            # The 1x1 projection makes the shortcut match the residual's
            # channel count and spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride)
            )
        else:
            # Parameter-free pass-through (no state_dict entries, same as the
            # previous empty nn.Sequential()).
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        residual = self.residual(x)
        shortcut = self.shortcut(x)
        return residual + shortcut
class WideResNet(nn.Module):
    """Wide residual network (WRN-depth-k) for 3-channel images.

    The network has a 3x3 stem convolution to 16 channels, then three stages
    of ``block`` with 16k, 32k and 64k output channels (the second and third
    stages start with stride 2). After the stages come BN -> ReLU, global
    average pooling and a linear layer. ``forward`` returns raw logits (no
    softmax).

    Args:
        num_classes: number of output logits.
        block: residual block class, called as
            ``block(in_channels, out_channels, stride)``.
        depth: network depth. Each stage gets ``(depth - 4) // 6`` blocks,
            with a minimum of one block. WRN depths have the form 6n + 4
            (16, 28, 40, ...); other values are floored. For example, the
            default 50 gives 7 blocks per stage.
        widen_factor: channel multiplier k.
    """

    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        super().__init__()
        self.depth = depth
        k = widen_factor
        # Blocks per stage. Integer floor division replaces
        # int(float division); the network built is the same.
        n = (depth - 4) // 6
        # Running channel count consumed and advanced by _make_layer.
        self.in_channels = 16
        # Attribute names define the state_dict keys of saved checkpoints.
        self.init_conv = nn.Conv2d(3, self.in_channels, 3, 1, padding=1)
        self.conv2 = self._make_layer(block, 16 * k, n, 1)
        self.conv3 = self._make_layer(block, 32 * k, n, 2)
        self.conv4 = self._make_layer(block, 64 * k, n, 2)
        self.bn = nn.BatchNorm2d(64 * k)
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(64 * k, num_classes)

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to logits of shape (N, num_classes)."""
        x = self.init_conv(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.avg_pool(x)           # (N, 64k, 1, 1)
        x = torch.flatten(x, 1)        # (N, 64k)
        x = self.linear(x)
        return x

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks into one stage; only the first uses ``stride``.

        This also sets ``self.in_channels`` to ``out_channels`` so that the
        next stage starts from this stage's width. At least one block is
        always built, even when ``num_blocks`` <= 0.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
# Build WRN-40-10 (depth 40, widen factor 10) with 10 outputs, one per
# CIFAR-10 class, and restore its trained weights onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# The path is resolved relative to the current working directory.
model = WideResNet(num_classes=10, block=WideBasic, depth=40, widen_factor=10)
model.load_state_dict(
    torch.load("weights/cifar10_wide_resnet_model.pt", map_location="cpu")
)
# Inference mode: dropout is disabled and BatchNorm uses its running statistics.
model.eval()
import gradio as gr
from torchvision import transforms
import os
import glob
# Directory of sample PNGs for the demo, and the PNG files found there.
# glob returns files in arbitrary, filesystem-dependent order, so the list is
# sorted to keep any example gallery built from it stable across runs.
examples_dir = './examples'
example_files = sorted(glob.glob(os.path.join(examples_dir, '*.png')))
# Input preprocessing. It must match what the network saw during training.
# mean/std are per-channel (R, G, B) values, presumably the CIFAR-10
# training-set statistics used when the weights were produced.
normalize = transforms.Normalize(
    std=[0.2470, 0.2435, 0.2616],
    mean=[0.4914, 0.4822, 0.4465],
)
# ToTensor converts an 8-bit RGB PIL image to a float (C, H, W) tensor in
# [0, 1]; each channel is then standardized by `normalize`.
transform = transforms.Compose([transforms.ToTensor(), normalize])
# Human-readable label for each model output: classes[i] names logit i.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
def predict(image):
    """Classify one image and return per-class probabilities.

    Args:
        image: the picture to classify, as delivered by the ``gr.Image``
            input (``type="pil"``, ``image_mode="RGB"``). It is ``None`` when
            "Infer" is clicked before an image has been provided.

    Returns:
        A dict mapping each name in ``classes`` to its softmax probability as
        a Python float, the format ``gr.Label`` consumes. The dict is empty
        when ``image`` is ``None``.
    """
    if image is None:
        # Without this guard, transforms.ToTensor raises TypeError on None.
        # Return an empty result so the UI shows nothing instead of an error.
        return {}
    batch = transform(image).unsqueeze(dim=0)  # add the batch dimension
    model.eval()  # defensive: ensure dropout/BatchNorm are in inference mode
    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits[0], dim=0).tolist()
    # Pair names with probabilities by position instead of a hard-coded 10.
    return {name: p for name, p in zip(classes, probs)}
# Gradio UI: an image input, a top-3 probability label and an "Infer" button
# wired to predict(). Components are placed in the order they are created
# inside the `with` contexts, so statement order here defines the page layout.
with gr.Blocks(css=".gradio-container {background:honeydew;}", title="WideResNet - CIFAR10 Classification"
) as demo:
    # Page heading.
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">WideResNet - CIFAR10 Classification</div>""")
    with gr.Row():
        # Input arrives in predict() as an RGB PIL image. shape=(32, 32) is
        # presumably meant to bring uploads down to CIFAR-10 resolution.
        # NOTE(review): `shape` is a Gradio 3.x parameter that was removed in
        # Gradio 4; confirm the pinned Gradio version.
        input_image = gr.Image(type="pil", image_mode="RGB", shape=(32, 32))
        # Shows the 3 highest-probability entries of the dict predict() returns.
        output_label=gr.Label(label="Probabilities", num_top_classes=3)
    send_btn = gr.Button("Infer")
    with gr.Row():
        # Clickable sample images that fill `input_image` when selected.
        gr.Examples(['./examples/cifar10_test00.png'], label='dog', inputs=input_image)
        gr.Examples(['./examples/cifar10_test01.png'], label='ship', inputs=input_image)
        # Alternative: offer every PNG in examples_dir instead:
        #gr.Examples(example_files, inputs=input_image)
    # Run predict() on the current image when "Infer" is clicked.
    send_btn.click(fn=predict, inputs=input_image, outputs=output_label)
# Request queueing is currently disabled:
# demo.queue(concurrency_count=3)
demo.launch()
### EOF ###