= commited on
Commit
2426537
·
1 Parent(s): 13a4f3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -19
app.py CHANGED
@@ -3,32 +3,94 @@ import torch
3
  from PIL import Image
4
  import os
5
 
6
- from read import classify
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  st.title("Pizza & Not Pizza")
9
 
10
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
- checkpoint = torch.load(os.path.join(os.getcwd(), "best.pth.tar"))
12
  model = checkpoint["model"]
13
  classes = checkpoint["classes"]
14
  tran = checkpoint["transform"]
15
 
16
  # upload image
17
- while True:
18
- uploaded_file = st.file_uploader("Choose an image...", type="jpg")
19
- taking_picture = st.camera_input("Take a picture...")
20
-
21
- if uploaded_file is not None:
22
- img = Image.open(uploaded_file)
23
- st.image(img, caption="Uploaded Image.", use_column_width=True)
24
- label = classify(model, img, tran, classes, device)
25
- st.write(label)
26
-
27
- elif taking_picture is not None:
28
- img = Image.open(taking_picture)
29
- st.image(img, caption="Uploaded Image.", use_column_width=True)
30
- label = classify(model, img, tran, classes, device)
31
- st.write(label)
32
 
33
- else:
34
- pass
 
 
 
 
 
 
 
3
  from PIL import Image
4
  import os
5
 
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ class Net(nn.Module):
11
+
12
+ def __init__(self):
13
+ super(Net, self).__init__()
14
+ self.conv1 = nn.Conv2d(3, 32, 5)
15
+ self.conv2 = nn.Conv2d(32, 64, 5)
16
+ self.conv3 = nn.Conv2d(64, 128, 5)
17
+ self.conv4 = nn.Conv2d(128, 256, 5)
18
+ self.conv5 = nn.Conv2d(256, 512, 5)
19
+
20
+ self.fc1 = None
21
+ self.fc2 = nn.Linear(512, 128)
22
+ self.fc3 = nn.Linear(128, 64)
23
+ self.fc4 = nn.Linear(64, 2)
24
+
25
+ def forward(self, x):
26
+ x = x.float()
27
+ """ x = F.relu(self.conv1(x))
28
+ x = F.relu(self.conv2(x))
29
+ x = F.max_pool2d(x, 2)
30
+ x = F.relu(self.conv3(x))
31
+ x = F.relu(self.conv4(x))
32
+ x = F.max_pool2d(x, 2)
33
+ x = F.relu(self.conv5(x))
34
+ x = F.max_pool2d(x, 2) """
35
+
36
+ x = F.max_pool2d(F.relu(self.conv1(x)), 2)
37
+ x = F.max_pool2d(F.relu(self.conv2(x)), 2)
38
+ x = F.max_pool2d(F.relu(self.conv3(x)), 2)
39
+ x = F.max_pool2d(F.relu(self.conv4(x)), 2)
40
+ x = F.max_pool2d(F.relu(self.conv5(x)), 2)
41
+
42
+ #x = x.view(x.size(0), -1)
43
+ x = torch.flatten(x, 1)
44
+
45
+ if self.fc1 is None:
46
+ self.fc1 = nn.Linear(x.shape[1], 512).to(x.device)
47
+
48
+ x = F.relu(self.fc1(x))
49
+ x = F.relu(self.fc2(x))
50
+ x = F.relu(self.fc3(x))
51
+ x = self.fc4(x)
52
+ return x
53
+
54
+
55
+ def classify(model, img, trans=None, classes=[], device=torch.device("cpu")):
56
+ try:
57
+ model = model.eval()
58
+ img = img.convert("RGB")
59
+ img = trans(img)
60
+ img = img.unsqueeze(0)
61
+ img = img.to(device)
62
+
63
+ output = model(img)
64
+ _, pred = torch.max(output, 1)
65
+ procent = torch.sigmoid(output)
66
+
67
+ return f"It {classes[pred.item()].replace('_', ' ')}, I'm {procent[0][pred[0]]*100:.2f}% sure"
68
+ except Exception:
69
+ return "Something went wrong😕, please notify the developer with the following message: " + str(Exception)
70
 
71
  st.title("Pizza & Not Pizza")
72
 
73
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
74
+ checkpoint = torch.load("best.pth.tar", map_location=device)
75
  model = checkpoint["model"]
76
  classes = checkpoint["classes"]
77
  tran = checkpoint["transform"]
78
 
79
  # upload image
80
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
81
+ taking_picture = st.camera_input("Take a picture...")
82
+
83
+ if uploaded_file is not None:
84
+ img = Image.open(uploaded_file)
85
+ st.image(img, caption="Uploaded Image.", use_column_width=True)
86
+ label = classify(model, img, tran, classes, device)
87
+ st.write(label)
 
 
 
 
 
 
 
88
 
89
+ elif taking_picture is not None:
90
+ img = Image.open(taking_picture)
91
+ st.image(img, caption="Uploaded Image.", use_column_width=True)
92
+ label = classify(model, img, tran, classes, device)
93
+ st.write(label)
94
+
95
+ else:
96
+ pass