# Hugging Face Space (status when this copy was captured: Sleeping)
import timm | |
from PIL import Image | |
from torchvision import transforms as T | |
import gradio as gr | |
import torch | |
# Swin-S3 classifier fine-tuned on Food-101, downloaded from the Hugging Face Hub.
model = timm.create_model("hf_hub:OmAlve/swin_s3_base_224-Foods-101", pretrained=True)
# nn.Module instances start in training mode, and torch.no_grad() does not
# change that. Switch to eval mode so any dropout / drop-path layers are
# disabled and repeated predictions on the same image are deterministic.
model.eval()
# Model input resolution as (height, width); matches the "_224" model variant.
image_size = (224, 224)

# Inference-time preprocessing:
#   1. resize to image_size,
#   2. convert the PIL image to a float CxHxW tensor in [0, 1],
#   3. map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): mean/std of 0.5 presumably match the fine-tuning config;
# confirm against the model's pretrained_cfg.
test_tf = T.Compose(
    [
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# Food-101 class names. predict() indexes this list by the model's output
# position, so the order must match the class order used when the model was
# fine-tuned. NOTE(review): confirm against the model card. This order lists
# "cheesecake" before "cheese_plate", so it is not strict ASCII sorting.
labels = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad
    beignets bibimbap bread_pudding breakfast_burrito bruschetta caesar_salad
    cannoli caprese_salad carrot_cake ceviche cheesecake cheese_plate
    chicken_curry chicken_quesadilla chicken_wings chocolate_cake chocolate_mousse churros
    clam_chowder club_sandwich crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict escargots
    falafel filet_mignon fish_and_chips foie_gras french_fries french_onion_soup
    french_toast fried_calamari fried_rice frozen_yogurt garlic_bread gnocchi
    greek_salad grilled_cheese_sandwich grilled_salmon guacamole gyoza hamburger
    hot_and_sour_soup hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup mussels
    nachos omelette onion_rings oysters pad_thai paella
    pancakes panna_cotta peking_duck pho pizza pork_chop
    poutine prime_rib pulled_pork_sandwich ramen ravioli red_velvet_cake
    risotto samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak strawberry_shortcake sushi
    tacos takoyaki tiramisu tuna_tartare waffles
""".split()
def predict(img):
    """Classify a food photo and return its five most likely Food-101 classes.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the image input is empty or has been cleared.

    Returns:
        dict mapping label name -> softmax probability (Python float) for the
        top five classes, highest probability first; ``{}`` when ``img`` is
        ``None``.
    """
    if img is None:
        # live=True makes Gradio re-run on every input change, including
        # clearing the image, which passes None here.
        return {}
    # Uploads may be RGBA (PNG), grayscale or palette images. The 3-channel
    # Normalize in test_tf needs exactly three channels, so force RGB.
    inp = test_tf(img.convert("RGB")).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(inp)[0], dim=0)
    # topk returns values sorted in descending order; bound k by the class count.
    top = torch.topk(probs, k=min(5, probs.numel()))
    return {
        labels[idx]: prob
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    }
# Web UI: one PIL image in, a label/probability dict (rendered by the "label"
# output) out. live=True re-runs predict as the input changes instead of
# waiting for a Submit click. The example images are resolved relative to the
# working directory.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    examples=["./miso soup.jpg", "./cupcake.jpg", "./pasta.jpg"],
    live=True,
)
demo.launch()