Siyun He commited on
Commit
5144969
·
1 Parent(s): 5b77f0f

upload code and model

Browse files
Files changed (2) hide show
  1. app.py +54 -0
  2. grass_wood_classification_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from torchvision.models import ResNet18_Weights
6
+
7
+ # Load the saved model
8
+ model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
9
+ model.fc = nn.Linear(model.fc.in_features, 1000) # Adjust to match the original model's output units
10
+ model.load_state_dict(torch.load('grass_wood_classification_model.pth'))
11
+ model.eval()
12
+
13
+ # Create a new model with the correct final layer
14
+ new_model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
15
+ new_model.fc = nn.Linear(new_model.fc.in_features, 2) # Adjust to match the desired output units
16
+
17
+ # Copy the weights and biases from the loaded model to the new model
18
+ new_model.fc.weight.data = model.fc.weight.data[0:2] # Copy only the first 2 output units
19
+ new_model.fc.bias.data = model.fc.bias.data[0:2]
20
+
21
+ # Define the preprocessing function
22
+ def preprocess_image(image):
23
+ preprocess = transforms.Compose([
24
+ transforms.Resize(256),
25
+ transforms.CenterCrop(224),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28
+ ])
29
+ input_tensor = preprocess(image)
30
+ input_batch = input_tensor.unsqueeze(0) # Add a batch dimension
31
+ return input_batch
32
+
33
+ # Define the prediction function
34
+ def predict(image):
35
+ input_batch = preprocess_image(image)
36
+ new_model.eval()
37
+ with torch.no_grad():
38
+ output = new_model(input_batch)
39
+ _, predicted_class = output.max(1)
40
+ class_names = ['grass', 'wood']
41
+ predicted_class_name = class_names[predicted_class.item()]
42
+ return predicted_class_name
43
+
44
+ # Create the Gradio interface
45
+ demo = gr.Interface(
46
+ fn=predict,
47
+ inputs=gr.Image(type='pil', label="Upload an Image"), # Use 'pil' to match the input type
48
+ outputs="text",
49
+ title="Grass or Wood Classifier Using ResNet18",
50
+ description="Upload an image to classify it as either grass or wood."
51
+ )
52
+
53
+ # Launch the interface
54
+ demo.launch(share=True)
grass_wood_classification_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52c0507644bb63b668eb8ef00c0fa55a6a5e20811088274be049551fac2aa78f
3
+ size 46838286