# app.py by OmAlve (Hugging Face Hub), commit d012df7 "Create app.py" (verified), 2.78 kB.
import timm
from PIL import Image
from torchvision import transforms as T
import gradio as gr
import torch
# Load the Swin-S3 classifier (presumably fine-tuned on Food-101, judging by
# the repo id and the label list) from the Hugging Face Hub; timm fetches the
# pretrained weights on first use.
# nn.Module instances start in training mode, so any training-only layers
# (dropout, stochastic depth, ...) would stay active and make predictions
# noisy. .eval() switches to inference mode and returns the module itself.
model = timm.create_model(
    "hf_hub:OmAlve/swin_s3_base_224-Foods-101",
    pretrained=True,
).eval()
# Spatial size fed to the network (height, width).
image_size = (224, 224)

# Inference-time preprocessing applied to every input image.
# NOTE(review): the 0.5/0.5 statistics presumably match what the model was
# fine-tuned with — confirm against the training setup.
test_tf = T.Compose(
    [
        # Resize to exactly (224, 224); the aspect ratio is not preserved.
        T.Resize(image_size),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        T.ToTensor(),
        # (x - 0.5) / 0.5 per channel, mapping [0, 1] to [-1, 1].
        T.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
# Class names for the model's outputs: predict() reports labels[i] for output
# index i, so the ORDER of this list is load-bearing and must match the class
# order the model was trained with. 101 entries.
# NOTE(review): the list is alphabetical except that "cheesecake" precedes
# "cheese_plate" (a plain string sort puts "cheese_plate" first). This likely
# mirrors the training dataset's label order — confirm before "fixing" it.
labels = [
    "apple_pie",
    "baby_back_ribs",
    "baklava",
    "beef_carpaccio",
    "beef_tartare",
    "beet_salad",
    "beignets",
    "bibimbap",
    "bread_pudding",
    "breakfast_burrito",
    "bruschetta",
    "caesar_salad",
    "cannoli",
    "caprese_salad",
    "carrot_cake",
    "ceviche",
    "cheesecake",
    "cheese_plate",
    "chicken_curry",
    "chicken_quesadilla",
    "chicken_wings",
    "chocolate_cake",
    "chocolate_mousse",
    "churros",
    "clam_chowder",
    "club_sandwich",
    "crab_cakes",
    "creme_brulee",
    "croque_madame",
    "cup_cakes",
    "deviled_eggs",
    "donuts",
    "dumplings",
    "edamame",
    "eggs_benedict",
    "escargots",
    "falafel",
    "filet_mignon",
    "fish_and_chips",
    "foie_gras",
    "french_fries",
    "french_onion_soup",
    "french_toast",
    "fried_calamari",
    "fried_rice",
    "frozen_yogurt",
    "garlic_bread",
    "gnocchi",
    "greek_salad",
    "grilled_cheese_sandwich",
    "grilled_salmon",
    "guacamole",
    "gyoza",
    "hamburger",
    "hot_and_sour_soup",
    "hot_dog",
    "huevos_rancheros",
    "hummus",
    "ice_cream",
    "lasagna",
    "lobster_bisque",
    "lobster_roll_sandwich",
    "macaroni_and_cheese",
    "macarons",
    "miso_soup",
    "mussels",
    "nachos",
    "omelette",
    "onion_rings",
    "oysters",
    "pad_thai",
    "paella",
    "pancakes",
    "panna_cotta",
    "peking_duck",
    "pho",
    "pizza",
    "pork_chop",
    "poutine",
    "prime_rib",
    "pulled_pork_sandwich",
    "ramen",
    "ravioli",
    "red_velvet_cake",
    "risotto",
    "samosa",
    "sashimi",
    "scallops",
    "seaweed_salad",
    "shrimp_and_grits",
    "spaghetti_bolognese",
    "spaghetti_carbonara",
    "spring_rolls",
    "steak",
    "strawberry_shortcake",
    "sushi",
    "tacos",
    "takoyaki",
    "tiramisu",
    "tuna_tartare",
    "waffles"
]
def predict(img):
    """Classify a food photo and return its five most likely classes.

    Args:
        img: ``PIL.Image.Image`` from the ``gr.Image(type="pil")`` input, or
            ``None`` when there is no image (e.g. after the user clears the
            input while the interface runs with ``live=True``).

    Returns:
        dict[str, float]: up to five ``label -> softmax probability`` pairs,
        highest probability first; an empty dict when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to classify; returning an empty dict avoids raising inside
        # the transform and should leave the label output blank.
        return {}
    # Uploads may be RGBA, grayscale ("L") or palette ("P") images, but
    # test_tf's Normalize uses 3-channel statistics and the model takes RGB.
    rgb = img.convert("RGB")
    batch = test_tf(rgb).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(batch)[0]  # scores for the single image in the batch
        probs = torch.nn.functional.softmax(logits, dim=0)
    # topk requires k <= number of classes.
    top_p, top_i = probs.topk(min(5, probs.numel()))
    return {labels[i]: p for p, i in zip(top_p.tolist(), top_i.tolist())}
# Web UI: one PIL image input, a "label" output showing predict()'s top
# classes, and three example images referenced by paths relative to the
# working directory (presumably shipped alongside this file).
# live=True makes Gradio re-run predict() whenever the input changes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    examples=["./miso soup.jpg", "./cupcake.jpg", "./pasta.jpg"],
    live=True,
)
demo.launch()