cats_dogs / app.py
agueroooooooooo's picture
Update app.py
27dc96e
raw
history blame
2.31 kB
import gradio as gr
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
import torch.nn.functional as F # All functions that don't have any parameters
from torch.utils.data import DataLoader, Dataset # Gives easier dataset managment and creates mini batches
from torchvision.datasets import ImageFolder
import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
from PIL import Image
# Select the compute device once at import time: the CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from tqdm import tqdm
from torchvision import models
# Build the ResNet-50 architecture without fetching ImageNet weights.
# Every parameter and buffer is overwritten below by the strict
# ``load_state_dict`` call, so downloading pretrained weights first would
# only add start-up time and a network dependency.
model = models.resnet50()
# Freeze the backbone: these parameters never need gradients here.
# (The replacement ``fc`` layer created below is a fresh module, so its
# parameters keep the default ``requires_grad=True``.)
for param in model.parameters():
    param.requires_grad = False
# Swap the 1000-way ImageNet head for a 2-way head; per
# ``image_classifier`` index 1 is treated as "dog" and anything else as "cat".
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.to(device)
# Loss and optimizer.
# NOTE(review): neither is used by the inference path visible in this file;
# they are kept so the module-level names and the restored optimizer state
# remain available exactly as before.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Deserialize the checkpoint onto the CPU first so loading works on machines
# without a GPU; ``load_state_dict`` then copies the tensors into the
# parameters that already live on ``device``.
checkpoint = torch.load(
    "checpoint_epoch_4.pt",
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
def image_classifier(inp):
    """Classify an image as a cat or a dog with the module-level model.

    Args:
        inp: The image handed over by the Gradio ``"image"`` input. It is
            passed straight to ``transforms.ToTensor``, so it must be a type
            that transform accepts (presumably an ``H x W x C`` NumPy array
            or a PIL image -- confirm against the Gradio version in use).
            ``None`` (no image supplied) is handled explicitly.

    Returns:
        ``"DOG"`` when the predicted class index is 1, ``"CAT"`` otherwise.
        When ``inp`` is ``None`` a short prompt asking for an image is
        returned instead of raising from inside the transform pipeline.
    """
    # Guard against a submit with no image; ToTensor would otherwise fail
    # with an opaque TypeError.
    if inp is None:
        return "Please upload an image."

    # Evaluation mode: batch-norm layers use their stored running statistics.
    model.eval()
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    # Add the leading batch dimension the model expects and move the tensor
    # to the same device as the model.
    img = data_transforms(inp).unsqueeze(dim=0).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = model(img)

    # Index of the highest-scoring class for each item in the batch.
    _, preds = torch.max(pred, 1)
    print(f"class : {preds}")

    if preds[0].item() == 1:
        print("predicted ----> Dog")
        return "DOG"
    print("predicted ----> Cat")
    return "CAT"
# Wire the classifier into a Gradio interface: one image input, and the
# predicted label shown as plain text.
demo = gr.Interface(
    fn=image_classifier,
    inputs="image",
    outputs="text",
)
# Start serving the demo.
demo.launch()