OmAlve commited on
Commit
d012df7
·
verified ·
1 Parent(s): 42b6482

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ from PIL import Image
3
+ from torchvision import transforms as T
4
+ import gradio as gr
5
+ import torch
6
+
7
+ model = timm.create_model("hf_hub:OmAlve/swin_s3_base_224-Foods-101", pretrained=True)
8
+ image_size = (224,224)
9
+
10
+ test_tf = T.Compose([
11
+ T.Resize(image_size),
12
+ T.ToTensor(),
13
+ T.Normalize(
14
+ mean = (0.5,0.5,0.5),
15
+ std = (0.5,0.5,0.5)
16
+ )
17
+ ])
18
+
19
+ labels = [
20
+ "apple_pie",
21
+ "baby_back_ribs",
22
+ "baklava",
23
+ "beef_carpaccio",
24
+ "beef_tartare",
25
+ "beet_salad",
26
+ "beignets",
27
+ "bibimbap",
28
+ "bread_pudding",
29
+ "breakfast_burrito",
30
+ "bruschetta",
31
+ "caesar_salad",
32
+ "cannoli",
33
+ "caprese_salad",
34
+ "carrot_cake",
35
+ "ceviche",
36
+ "cheesecake",
37
+ "cheese_plate",
38
+ "chicken_curry",
39
+ "chicken_quesadilla",
40
+ "chicken_wings",
41
+ "chocolate_cake",
42
+ "chocolate_mousse",
43
+ "churros",
44
+ "clam_chowder",
45
+ "club_sandwich",
46
+ "crab_cakes",
47
+ "creme_brulee",
48
+ "croque_madame",
49
+ "cup_cakes",
50
+ "deviled_eggs",
51
+ "donuts",
52
+ "dumplings",
53
+ "edamame",
54
+ "eggs_benedict",
55
+ "escargots",
56
+ "falafel",
57
+ "filet_mignon",
58
+ "fish_and_chips",
59
+ "foie_gras",
60
+ "french_fries",
61
+ "french_onion_soup",
62
+ "french_toast",
63
+ "fried_calamari",
64
+ "fried_rice",
65
+ "frozen_yogurt",
66
+ "garlic_bread",
67
+ "gnocchi",
68
+ "greek_salad",
69
+ "grilled_cheese_sandwich",
70
+ "grilled_salmon",
71
+ "guacamole",
72
+ "gyoza",
73
+ "hamburger",
74
+ "hot_and_sour_soup",
75
+ "hot_dog",
76
+ "huevos_rancheros",
77
+ "hummus",
78
+ "ice_cream",
79
+ "lasagna",
80
+ "lobster_bisque",
81
+ "lobster_roll_sandwich",
82
+ "macaroni_and_cheese",
83
+ "macarons",
84
+ "miso_soup",
85
+ "mussels",
86
+ "nachos",
87
+ "omelette",
88
+ "onion_rings",
89
+ "oysters",
90
+ "pad_thai",
91
+ "paella",
92
+ "pancakes",
93
+ "panna_cotta",
94
+ "peking_duck",
95
+ "pho",
96
+ "pizza",
97
+ "pork_chop",
98
+ "poutine",
99
+ "prime_rib",
100
+ "pulled_pork_sandwich",
101
+ "ramen",
102
+ "ravioli",
103
+ "red_velvet_cake",
104
+ "risotto",
105
+ "samosa",
106
+ "sashimi",
107
+ "scallops",
108
+ "seaweed_salad",
109
+ "shrimp_and_grits",
110
+ "spaghetti_bolognese",
111
+ "spaghetti_carbonara",
112
+ "spring_rolls",
113
+ "steak",
114
+ "strawberry_shortcake",
115
+ "sushi",
116
+ "tacos",
117
+ "takoyaki",
118
+ "tiramisu",
119
+ "tuna_tartare",
120
+ "waffles"
121
+ ]
122
+
123
+ def predict(img):
124
+ inp = test_tf(img).unsqueeze(0)
125
+ with torch.no_grad():
126
+ predictions = torch.nn.functional.softmax(model(inp)[0], dim=0)
127
+ toplabels = predictions.argsort(descending=True)[:5]
128
+ results = {labels[label] : float(predictions[label]) for label in toplabels}
129
+ return results
130
+
131
+ gr.Interface(fn=predict,
132
+ inputs=gr.Image(type="pil"),
133
+ outputs="label",
134
+ examples=['./miso soup.jpg','./cupcake.jpg','./pasta.jpg'],
135
+ live=True).launch()
136
+