# Import the required libraries
import torch
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
INPUT_SIZE = 256  # Size of input images
N_CLASSES = 8     # Number of species
N_LABELS = 2      # Poisonous / non-poisonous
# Preprocessing pipeline matching what the pretrained backbone expects
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
        transforms.ToTensor(),
        # Standard ImageNet statistics, used when ResNet-50 was pretrained
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
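# Illustrative sketch (not executed by the app): applying the pipeline to one
# of the example images below yields a (3, 256, 256) float tensor ready for the
# backbone, e.g.
#   from PIL import Image
#   t = transform(Image.open('image3.jpg'))  # t.shape == torch.Size([3, 256, 256])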
# Map label indices to human-readable edibility names
labels = {
    '0': 'Non-Poisonous',
    '1': 'Poisonous'
}
# Map class indices to species names
classes = {
    '0': 'Angel',
    '1': 'Death',
    '2': 'Elder',
    '3': 'Mistletoe',
    '4': 'Cherries',
    '5': 'Cloudberries',
    '6': 'Lion',
    '7': 'Oyster'
}
# Define the model architecture
class MultiTaskModel(nn.Module):
    # n_classes : number of species (8)
    # n_labels  : poisonous / non-poisonous (2)
    def __init__(self, n_classes, n_labels):
        super().__init__()
        # Initialize the base model (ResNet-50 pretrained on ImageNet)
        self.base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        num_features = self.base_model.fc.in_features
        # Freeze the base model so only the new heads are trainable
        for param in self.base_model.parameters():
            param.requires_grad = False
        # Remove the final fully connected layer, keeping the pooled features
        self.base_model = nn.Sequential(*list(self.base_model.children())[:-1])
        # Define new layers:
        # a shared dense layer (num_features -> 128) feeding two task heads
        self.fc = nn.Linear(num_features, 128)
        self.fc_classes = nn.Linear(128, n_classes)  # species head
        self.fc_labels = nn.Linear(128, n_labels)    # edibility head
    def forward(self, x):
        x = self.base_model(x)
        x = x.view(x.size(0), -1)  # Flatten the pooled feature map
        x = F.relu(self.fc(x))
        species_output = self.fc_classes(x)
        poisonous_output = self.fc_labels(x)
        return species_output, poisonous_output
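# Illustrative shape check (a sketch, not part of the app): a single
# (1, 3, 256, 256) batch should yield species logits of shape (1, 8) and
# edibility logits of shape (1, 2):
#   m = MultiTaskModel(N_CLASSES, N_LABELS)
#   species, poison = m(torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE))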
# Take a PIL image and return the prediction confidences for both tasks
def predict(image):
    # 5. Apply the preprocessing transform to the image # https://colab.research.google.com/drive/1V4Y6w8WwpN7jwjLKItx3kf-RbQijdMqp#scrollTo=Ajc6lqtPY8WG&line=26&uniqifier=1
    input_tensor = transform(image)
    # 6. Add an extra dimension for the batch and move to the model's device
    input_tensor = input_tensor.unsqueeze(0).to(device)
    # 7. Send the image through the model # https://colab.research.google.com/drive/1V4Y6w8WwpN7jwjLKItx3kf-RbQijdMqp#scrollTo=thZBTz8gNHUv&line=13&uniqifier=1
    with torch.no_grad():
        classes_outputs, labels_outputs = model(input_tensor)
    # 8. Convert the logits into class and label probabilities # https://colab.research.google.com/drive/1V4Y6w8WwpN7jwjLKItx3kf-RbQijdMqp#scrollTo=thZBTz8gNHUv&line=17&uniqifier=1
    class_probs = F.softmax(classes_outputs, dim=1)[0]
    label_probs = F.softmax(labels_outputs, dim=1)[0]
    class_confidence = {classes[str(i)]: float(class_probs[i]) for i in range(len(class_probs))}
    label_confidence = {labels[str(i)]: float(label_probs[i]) for i in range(len(label_probs))}
    return label_confidence, class_confidence
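# The returned dicts map display names to confidences, which is the format
# Gradio's Label component renders directly, e.g. (illustrative values only):
#   ({'Poisonous': 0.93, 'Non-Poisonous': 0.07}, {'Death': 0.81, ...})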
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define our model
model = MultiTaskModel(N_CLASSES, N_LABELS)
# 3. Load the trained weights
model.load_state_dict(torch.load('best_model_nikita.pt', map_location=device))
model.to(device)
# 4. Set the model to evaluation mode
model.eval()
title = "BaMI Classifier"
description = "A classifier for mushroom and berry images that predicts the particular species along with its edibility"
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=[gr.Label(num_top_classes=1, label="Edibility"),
                             gr.Label(num_top_classes=3, label="Species")],
                    examples=['image3.jpg', 'image9.jpg', 'image12.jpg'],
                    title=title, description=description)
demo.launch(debug=True)
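# When run locally (`python app.py`), Gradio serves the demo at
# http://127.0.0.1:7860 by default; on Hugging Face Spaces the app starts
# automatically from this script.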