Spaces:
Runtime error
Runtime error
import pdb | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import Nets | |
from huggingface_hub import hf_hub_download | |
def load_model(pretrained_dict, new):
    """Copy weights whose names match *new*'s parameters from a checkpoint, in place.

    Only checkpoint entries whose key exists in ``new.state_dict()`` are
    copied. Every other entry of *new* keeps its current value, and checkpoint
    entries with unknown keys are ignored. This allows partial loading.

    Args:
        pretrained_dict: The loaded checkpoint. It is either a mapping with a
            ``'state_dict'`` entry (the layout this function originally
            required) or a bare state dict.
        new: The model to load into. It must provide ``state_dict()`` and
            ``load_state_dict()``. If it is a torch ``nn.Module``, a
            shape mismatch on a matched key raises ``RuntimeError`` from
            ``load_state_dict``.

    Warns:
        RuntimeWarning: If no checkpoint key matched. In that case the
            model's weights are left unchanged.
    """
    import warnings

    model_dict = new.state_dict()

    # Accept both a wrapped checkpoint and a bare state dict.
    if 'state_dict' in pretrained_dict:
        state = pretrained_dict['state_dict']
    else:
        state = pretrained_dict

    # Checkpoints saved from an nn.DataParallel-wrapped model prefix every key
    # with 'module.'. Without stripping it, nothing would match and the model
    # would silently keep its initial weights.
    prefix = 'module.'
    matched = {}
    for key, value in state.items():
        if (key not in model_dict
                and key.startswith(prefix)
                and key[len(prefix):] not in state):
            # Only strip when the unprefixed key is not also present, so keys
            # that already matched keep the original behavior.
            key = key[len(prefix):]
        if key in model_dict:
            matched[key] = value

    if not matched:
        warnings.warn(
            "load_model: no checkpoint keys matched the model; "
            "weights were left unchanged",
            RuntimeWarning,
            stacklevel=2,
        )

    # Overwrite the matched entries in the model's own state dict, then load
    # the merged dict so unmatched parameters keep their current values.
    model_dict.update(matched)
    new.load_state_dict(model_dict)
# All inference in this module runs on the CPU.
device = torch.device("cpu")
print("use cpu")

# BasicBlock with layers [2, 2, 2, 2] matches the resnet18.pth checkpoint
# loaded below. num_classes=1 gives a single output value per image.
net = Nets.ResNet(block=Nets.BasicBlock, layers=[2, 2, 2, 2], num_classes=1).to(device)

# Download the checkpoint from the Hugging Face Hub (hf_hub_download caches it
# locally) and get its on-disk path.
model_ckpt_path = hf_hub_download(repo_id="M4869/beauty_prediction_fpb5k", filename="resnet18.pth")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# trusted repos. encoding='latin1' is presumably there because the checkpoint
# was pickled under Python 2 — confirm.
# map_location reuses `device` so the load target cannot drift from the model's device.
load_model(torch.load(model_ckpt_path, encoding='latin1', map_location=device), net)

# Put layers such as BatchNorm/Dropout into inference mode.
net.eval()

# Preprocessing applied to every input image: resize, center-crop to 224x224,
# convert to a tensor, and normalize with the standard ImageNet channel stats.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def fun(img_path):
    """Run the model on one image and return a score capped at 100.0.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through ``net`` on ``device`` without gradient
    tracking. The model's single output is scaled by 20 and rounded to one
    decimal place. Then 15.0 is added and the result is capped at 100.0.
    There is no lower clamp.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        float: The score. It is always a float, including when capped.
    """
    # Context manager releases the file handle deterministically. Pillow can
    # otherwise keep it open, e.g. for multi-frame images.
    with Image.open(img_path) as im:
        img = im.convert('RGB')
    img = transform(img)
    with torch.no_grad():
        batch = img.unsqueeze(0).to(device)  # add a batch dimension of 1
        # Single-output head: drop the class dim and take the only batch item.
        output = net(batch).squeeze(1).cpu().numpy()[0]
    # NOTE(review): the *20 / +15 mapping presumably converts the model's
    # rating scale to a 0-100-style score — confirm against training labels.
    v = float("%.1f" % (output * 20))
    # 100.0 (not 100) so the return type is float on both sides of the cap.
    return min(100.0, v + 15.0)