File size: 3,718 Bytes
ffcac11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95817ca
 
ffcac11
 
 
 
95817ca
 
ffcac11
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Import useful model
import torch
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models

# Square side length, in pixels, that every input image is resized to.
INPUT_SIZE = 256
# Number of species outputs produced by the classification head.
N_CLASSES = 8
# Number of edibility outputs (non-poisonous / poisonous).
N_LABELS = 2

# Per-channel normalization statistics (the standard ImageNet mean/std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each image before it reaches the model:
#   1. collapse to grayscale, replicated back to 3 identical channels so the
#      3-channel ResNet input layer still accepts it;
#   2. resize to INPUT_SIZE x INPUT_SIZE (aspect ratio is not preserved);
#   3. convert to a float tensor;
#   4. normalize each channel with the statistics above.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=(INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


# Display names for the edibility head, keyed by output index (as a string).
labels = {
    str(index): name
    for index, name in enumerate(('Non-Poisonous', 'Poisonous'))
}

# Display names for the species head, keyed by output index (as a string).
# NOTE(review): the order presumably has to match the class order used when
# the checkpoint was trained — confirm before reordering or renaming.
classes = {
    str(index): name
    for index, name in enumerate((
        'Angel',
        'Death',
        'Elder',
        'Misletoe',
        'Cherrys',
        'CloudBerrys',
        'Lion',
        'Oyster',
    ))
}


# Define Model architecture
class MultiTaskModel(nn.Module):
    """ResNet-50 backbone with two linear heads: species and edibility.

    Args:
        n_classes: Number of species outputs.
        n_labels: Number of edibility outputs (poisonous / non-poisonous).
        pretrained: Forwarded to ``torchvision.models.resnet50``. When True,
            ImageNet weights are loaded (downloaded if not cached). Pass False
            when a full checkpoint is loaded right afterwards, because every
            backbone weight is then overwritten anyway.
        hidden_dim: Width of the shared dense layer that feeds both heads.
            The default of 128 matches the original architecture.
    """

    def __init__(self, n_classes, n_labels, pretrained=True, hidden_dim=128):
        super().__init__()

        backbone = models.resnet50(pretrained=pretrained)
        num_features = backbone.fc.in_features

        # Freeze the backbone. Only the layers created below stay trainable.
        for param in backbone.parameters():
            param.requires_grad = False

        # Drop ResNet's final fc layer. The attribute name `base_model` and the
        # nn.Sequential wrapper are kept so that checkpoint keys
        # (`base_model.<idx>.*`) keep matching.
        self.base_model = nn.Sequential(*list(backbone.children())[:-1])

        # Shared dense layer, then one linear head per task.
        self.fc = nn.Linear(num_features, hidden_dim)
        self.fc_classes = nn.Linear(hidden_dim, n_classes)
        self.fc_labels = nn.Linear(hidden_dim, n_labels)

    def forward(self, x):
        """Return ``(species_logits, poisonous_logits)`` for a batch ``x``.

        Both outputs are raw (unnormalized) scores, shaped
        (batch, n_classes) and (batch, n_labels) respectively.
        """
        # Flatten everything after the batch dimension into one feature vector.
        features = torch.flatten(self.base_model(x), 1)
        features = F.relu(self.fc(features))

        species_output = self.fc_classes(features)
        poisonous_output = self.fc_labels(features)
        return species_output, poisonous_output


# Function that takes image and returns the predictions
def predict(image):
    """Classify one image into edibility and species confidences.

    Args:
        image: PIL image from the Gradio image input (``type="pil"``), or
            None when no image was provided.

    Returns:
        ``(label_confidence, class_confidence)``: two dicts mapping the display
        names in ``labels`` / ``classes`` to softmax probabilities, in the order
        the two Gradio Label outputs expect. Both are None when ``image`` is
        None, which leaves the outputs empty instead of raising inside
        ``transform``.
    """
    if image is None:
        return None, None

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)
    # Put the input on whatever device the model's weights live on (CPU or GPU).
    batch = batch.to(next(model.parameters()).device)

    with torch.no_grad():
        class_logits, label_logits = model(batch)

    # Softmax across each head's outputs, then take the single batch row.
    class_probs = F.softmax(class_logits, dim=1)[0].tolist()
    label_probs = F.softmax(label_logits, dim=1)[0].tolist()

    class_confidence = {classes[str(i)]: p for i, p in enumerate(class_probs)}
    label_confidence = {labels[str(i)]: p for i, p in enumerate(label_probs)}

    return label_confidence, class_confidence


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and restore the trained weights. Checkpoint tensors are
# mapped onto `device`. load_state_dict then copies them into the model's own
# parameters, which stay wherever the model was created (it is not moved here).
model = MultiTaskModel(N_CLASSES, N_LABELS)
model.load_state_dict(torch.load('best_model_nikita.pt', map_location=device))

# Evaluation mode: BatchNorm layers in the ResNet backbone use their stored
# running statistics instead of per-batch statistics.
model.eval()

title = "BaMI Classifier"
description = "A plant based classifier to classify the particular species of Mushroom and Berries, along with its edibility"

# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4.
# The top-level gr.Image / gr.Label components work in both versions.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1),  # edibility: show only the most likely label
        gr.Label(num_top_classes=3),  # species: show the three most likely classes
    ],
    examples=['image3.jpg', 'image9.jpg', 'image12.jpg'],
    title=title,
    description=description,
)

demo.launch(debug=True)