Nvd commited on
Commit
1e98bba
·
1 Parent(s): ccd736a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -4
app.py CHANGED
@@ -1,7 +1,42 @@
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
  import gradio as gr
4
+ from PIL import Image
5
+ from model import SimpleCNN
6
 
7
+ def preprocess_image(image):
8
+ transform = transforms.Compose([
9
+ transforms.Resize((32, 32)),
10
+ transforms.ToTensor(),
11
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
12
+ ])
13
+ image = Image.fromarray(image)
14
+ image = transform(image)
15
+ image = image.unsqueeze(0)
16
+ return image
17
 
18
+ def predict_image(model, image):
19
+ model.eval()
20
+ with torch.no_grad():
21
+ output = model(image)
22
+
23
+ _, predicted = torch.max(output.data, 1)
24
+ return predicted.item()
25
+
26
+ def main():
27
+
28
+ model = SimpleCNN()
29
+ model.load_state_dict(torch.load('cifar10_model.pth'))
30
+ model.eval()
31
+
32
+ iface = gr.Interface(
33
+ fn=lambda img: predict_image(model, preprocess_image(img)),
34
+ inputs=gr.Image(),
35
+ outputs="label",
36
+ live=True,
37
+ )
38
+
39
+ iface.launch(share=True)
40
+
41
+ if __name__ == "__main__":
42
+ main()