# niki-stha's picture
# Update app.py
# 95817ca
# Import the required libraries
import torch
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
# Configuration shared by the preprocessing pipeline and the model.
INPUT_SIZE = 256  # Side length, in pixels, that input images are resized to
N_CLASSES = 8  # Number of species the model distinguishes
N_LABELS = 2  # Number of edibility labels (poisonous / non-poisonous)
# Preprocessing applied to every image before it is passed to the model.
# The steps run in the order they are listed.
_preprocessing_steps = [
    # Convert to grayscale while keeping three output channels.
    transforms.Grayscale(num_output_channels=3),
    # Resize to a fixed INPUT_SIZE x INPUT_SIZE square.
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    # Convert the image into a tensor.
    transforms.ToTensor(),
    # Per-channel normalization (values look like the usual ImageNet
    # statistics -- confirm they match what was used during training).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)
# Lookup tables that map a stringified output index ('0', '1', ...) to the
# display name for that index. The position in each tuple is the index.
_LABEL_NAMES = ('Non-Poisonous', 'Poisonous')
_CLASS_NAMES = (
    'Angel',
    'Death',
    'Elder',
    'Misletoe',
    'Cherrys',
    'CloudBerrys',
    'Lion',
    'Oyster',
)
labels = {str(index): name for index, name in enumerate(_LABEL_NAMES)}
classes = {str(index): name for index, name in enumerate(_CLASS_NAMES)}
# Define Model architecture
class MultiTaskModel(nn.Module):
    """Two-headed classifier built on a frozen ResNet-50 feature extractor.

    The ResNet-50 backbone (with its final layer removed) feeds a shared
    128-unit dense layer, which in turn feeds two independent linear heads:
    one for the species classes and one for the poisonous/non-poisonous
    labels.

    Args:
        n_classes: Number of species outputs (size of the first head).
        n_labels: Number of edibility outputs (size of the second head).
        pretrained: Forwarded to ``torchvision.models.resnet50``. Defaults to
            ``True`` to preserve the original behaviour. Pass ``False`` when
            the weights are going to be overwritten by a checkpoint anyway,
            which avoids fetching the pretrained weights.
            NOTE(review): newer torchvision releases deprecate ``pretrained``
            in favour of ``weights=``; migrate once the minimum torchvision
            version is known.
    """

    def __init__(self, n_classes, n_labels, *, pretrained=True):
        super().__init__()
        # Initialize the base model (ResNet-50).
        backbone = models.resnet50(pretrained=pretrained)
        # Width of the features that the backbone's own final layer consumes.
        num_features = backbone.fc.in_features
        # Freeze the base model so its parameters receive no gradients.
        for param in backbone.parameters():
            param.requires_grad = False
        # Drop the backbone's last child module. The attribute name and the
        # Sequential layout are kept so existing checkpoints still load.
        self.base_model = nn.Sequential(*list(backbone.children())[:-1])
        # New trainable layers: a shared dense layer followed by two heads.
        self.fc = nn.Linear(num_features, 128)
        self.fc_classes = nn.Linear(128, n_classes)
        self.fc_labels = nn.Linear(128, n_labels)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Batch of inputs accepted by the backbone.

        Returns:
            A ``(species_output, poisonous_output)`` tuple holding the raw
            (pre-softmax) outputs of the two heads.
        """
        x = self.base_model(x)
        # Flatten everything except the batch dimension. Unlike
        # ``x.view(x.size(0), -1)`` this also works for an empty batch.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc(x))
        species_output = self.fc_classes(x)
        poisonous_output = self.fc_labels(x)
        return species_output, poisonous_output
# Function that takes image and returns the predictions
def predict(image):
    """Classify one image and return per-name confidences for both heads.

    Args:
        image: The image to classify, in a form accepted by the module-level
            ``transform`` pipeline (the UI is configured to pass PIL images).

    Returns:
        A ``(label_confidence, class_confidence)`` tuple of dicts. Each dict
        maps a display name (from ``labels`` / ``classes``) to its softmax
        probability as a Python float.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, nothing was
            uploaded before submitting).
    """
    if image is None:
        raise ValueError("No image was provided for prediction.")

    # Apply the preprocessing pipeline and add a leading batch dimension.
    batch = transform(image).unsqueeze(0)

    # Run the model without tracking gradients.
    with torch.no_grad():
        classes_outputs, labels_outputs = model(batch)

    # Turn the raw outputs of the single batch item into probabilities.
    class_probs = F.softmax(classes_outputs, dim=1)[0].tolist()
    label_probs = F.softmax(labels_outputs, dim=1)[0].tolist()

    # Map each output index to its display name.
    class_confidence = {
        classes[str(index)]: prob for index, prob in enumerate(class_probs)
    }
    label_confidence = {
        labels[str(index)]: prob for index, prob in enumerate(label_probs)
    }
    return label_confidence, class_confidence
# Device used as the map location when reading the checkpoint.
# NOTE(review): the model itself is not moved to `device` in this script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network, restore the trained weights, and switch to evaluation mode.
model = MultiTaskModel(N_CLASSES, N_LABELS)
model.load_state_dict(torch.load('best_model_nikita.pt', map_location=device))
model.eval()

# Text shown in the web UI.
title = "BaMI Classifier"
description = "A plant based classifier to classify the particular species of Mushroom and Berries, along with it's edibality"

# UI components: one PIL image in, two label widgets out. The output order
# matches predict's return order (edibility labels first, then species).
image_input = gr.inputs.Image(type="pil")
label_outputs = [
    gr.outputs.Label(num_top_classes=1),
    gr.outputs.Label(num_top_classes=3),
]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_outputs,
    examples=['image3.jpg', 'image9.jpg', 'image12.jpg'],
    title=title,
    description=description,
)
demo.launch(debug=True)