"""Gradio demo that labels an uploaded image as "CAT" or "DOG".

A ResNet-50 with a two-output classification head is restored from a local
checkpoint file and exposed through a ``gradio.Interface``.
"""
import os

# NOTE(review): several of these imports (np, pd, os, torchvision, datasets, F,
# DataLoader, Dataset, ImageFolder, Image, tqdm) are not used by the visible
# inference code; they are kept in case other code relies on them.
import gradio as gr
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.nn.functional as F  # Stateless functional ops (activations, losses, ...) without their own parameters
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset  # Gives easier dataset management and creates mini batches
from torchvision import models
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the network: a ResNet-50 backbone whose final fully connected layer is
# replaced by a 2-output head. ImageNet weights are NOT downloaded: the strict
# ``load_state_dict`` below overwrites every parameter and buffer anyway, so
# fetching them only slowed start-up. ``resnet50()`` with no arguments builds
# an uninitialised model on both old (``pretrained``) and new (``weights``)
# torchvision APIs.
model = models.resnet50()
# Freeze the backbone. Parameters created after this loop (the new ``fc`` head)
# keep requires_grad=True. This only matters if training is resumed; inference
# does not need gradients at all.
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.to(device)

# Loss and optimizer: unused for inference. They are kept so the optimizer
# state stored in the checkpoint can be restored if training is resumed.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Restore the trained weights. ``map_location`` lets a checkpoint saved on a GPU
# be opened on a CPU-only host; ``load_state_dict`` then copies the tensors into
# the parameters that already live on ``device``.
# NOTE(review): the file name is spelled "checpoint" - confirm it matches the
# checkpoint actually shipped with the app.
checkpoint = torch.load("checpoint_epoch_4.pt", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Switch BatchNorm layers to their running statistics once at start-up.
model.eval()
# Per-request preprocessing, built once instead of on every call.
# ToTensor runs first because the input may be a numpy array, which Resize does
# not accept; after ToTensor it is a CHW float tensor in [0, 1].
# NOTE(review): the 0.5 mean/std normalisation presumably mirrors the training
# pipeline - confirm before changing.
_DATA_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def image_classifier(inp):
    """Classify one image as ``"CAT"`` or ``"DOG"``.

    Args:
        inp: The image supplied by Gradio's ``"image"`` input. Presumably an
            RGB numpy array of shape (H, W, 3) - TODO confirm for the deployed
            Gradio version. Anything ``transforms.ToTensor`` accepts works.
            ``None`` (the user submitted without an image) is tolerated.

    Returns:
        ``"DOG"`` when the model's top class index is 1, otherwise ``"CAT"``.
        ``"No image provided"`` when ``inp`` is ``None``.
    """
    if inp is None:
        # ToTensor would raise an opaque TypeError on None.
        return "No image provided"

    model.eval()
    # Add a batch dimension, (C, H, W) -> (1, C, H, W), and move to the model's device.
    img = _DATA_TRANSFORMS(inp).unsqueeze(dim=0).to(device)
    # Inference only: without no_grad, autograd would record a graph for the
    # trainable fc head on every request.
    with torch.no_grad():
        pred = model(img)
    _, preds = torch.max(pred, 1)
    print(f"class : {preds}")

    # NOTE(review): index 1 -> DOG presumably matches the alphabetical class
    # order (cat=0, dog=1) used when training - confirm.
    if preds[0].item() == 1:
        print("predicted ----> Dog")
        return "DOG"
    print("predicted ----> Cat")
    return "CAT"
# Web UI: an image upload box in, the predicted label as text out.
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="text")

# Start the server only when executed as a script (how the app is normally
# run), so importing this module - e.g. from tests or Gradio's reload mode -
# does not block inside launch().
if __name__ == "__main__":
    demo.launch()