"""Gradio demo: classify CIFAR-10 images with a pretrained Wide ResNet.

Builds a WideResNet with depth=40 and widen_factor=10, loads its weights from
``weights/cifar10_wide_resnet_model.pt`` onto the CPU, and serves a Gradio UI
that shows the top-3 class probabilities for an input image.
"""

# common
import glob
import math
import os
import sys

# torch
import torch
import torch.nn.functional as F
from torch import nn

# torchvision
import torchvision
from torchvision import transforms

# UI
import gradio as gr


# Define model
class WideBasic(nn.Module):
    """Residual block: (BN -> ReLU -> Conv3x3) twice, with dropout in between.

    The shortcut passes the input through unchanged unless the channel count
    or the stride changes; then a 1x1 convolution projects it to the output
    shape so the two branches can be summed.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        stride: Stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # NOTE: module order defines the state_dict keys (residual.0, residual.2,
        # ...); keep it unchanged so the saved weights still load.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=stride, padding=1
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # Default p=0.5; a no-op at inference because the model is in eval mode.
            nn.Dropout(),
            nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            ),
        )

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride)
            )

    def forward(self, x):
        """Return residual(x) + shortcut(x)."""
        residual = self.residual(x)
        shortcut = self.shortcut(x)
        return residual + shortcut


class WideResNet(nn.Module):
    """Wide ResNet for 3-channel images.

    Structure: stem conv (16 channels) -> three groups of ``block`` with
    16k / 32k / 64k channels (the last two downsample by stride 2)
    -> BN -> ReLU -> global average pool -> linear classifier.

    Args:
        num_classes: Number of output logits.
        block: Residual block class, called as ``block(in_ch, out_ch, stride)``.
        depth: Network depth; each group gets ``(depth - 4) // 6`` blocks.
        widen_factor: Channel multiplier ``k`` applied to every group.
    """

    def __init__(self, num_classes, block, depth=50, widen_factor=1):
        super().__init__()
        self.depth = depth
        k = widen_factor
        # Integer division: number of residual blocks in each of the three groups.
        blocks_per_group = (depth - 4) // 6
        # Running channel count, advanced by _make_layer as groups are built.
        self.in_channels = 16

        self.init_conv = nn.Conv2d(3, self.in_channels, 3, 1, padding=1)
        self.conv2 = self._make_layer(block, 16 * k, blocks_per_group, 1)
        self.conv3 = self._make_layer(block, 32 * k, blocks_per_group, 2)
        self.conv4 = self._make_layer(block, 64 * k, blocks_per_group, 2)
        self.bn = nn.BatchNorm2d(64 * k)
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(64 * k, num_classes)

    def forward(self, x):
        """Map an image batch to class logits of shape (batch, num_classes)."""
        x = self.init_conv(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.linear(x)
        return x

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one group of residual blocks.

        Only the first block uses ``stride``; the rest use stride 1.
        Updates ``self.in_channels`` to ``out_channels`` as a side effect.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)


model = WideResNet(10, WideBasic, depth=40, widen_factor=10)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts loading to tensors).
model.load_state_dict(
    torch.load("weights/cifar10_wide_resnet_model.pt", map_location=torch.device('cpu'))
)
model.eval()


examples_dir = './examples'
# Only used by the commented-out gr.Examples(example_files, ...) line below.
example_files = glob.glob(os.path.join(examples_dir, '*.png'))

# Per-channel mean/std constants applied to every input image.
normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2470, 0.2435, 0.2616],
)
transform = transforms.Compose([
    # Bring inputs to 32x32 even when the UI has not resized them; a PIL image
    # that is already 32x32 is left unchanged.
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    normalize,
])

# Output index -> label, in the order of the model's logits.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def predict(image):
    """Classify one image and return per-class probabilities.

    Args:
        image: A PIL image (the UI's input component uses ``type="pil"``), or
            ``None`` when "Infer" is pressed before an image is provided.

    Returns:
        A ``{class_name: probability}`` dict over all classes (suitable for
        ``gr.Label``), or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to classify; avoid raising inside the transform.
        return {}

    tsr_image = transform(image).unsqueeze(dim=0)  # add batch dimension
    model.eval()
    with torch.no_grad():
        logits = model(tsr_image)
        prob = F.softmax(logits[0], dim=0)
    return {name: float(p) for name, p in zip(classes, prob)}


with gr.Blocks(css=".gradio-container {background:honeydew;}",
               title="WideResNet - CIFAR10 Classification"
               ) as demo:

    gr.HTML("""
WideResNet - CIFAR10 Classification
""")

    with gr.Row():
        # NOTE(review): `shape=` is a Gradio 3.x argument that resizes uploads;
        # it was removed in Gradio 4 -- confirm against the pinned version.
        input_image = gr.Image(type="pil", image_mode="RGB", shape=(32, 32))
        output_label = gr.Label(label="Probabilities", num_top_classes=3)

    send_btn = gr.Button("Infer")

    with gr.Row():
        gr.Examples(['./examples/cifar10_test00.png'], label='dog', inputs=input_image)
        gr.Examples(['./examples/cifar10_test01.png'], label='ship', inputs=input_image)
        #gr.Examples(example_files, inputs=input_image)

    send_btn.click(fn=predict, inputs=input_image, outputs=output_label)

# demo.queue(concurrency_count=3)
demo.launch()

### EOF ###