File size: 1,415 Bytes
25ec020
 
3f481af
475ac6e
 
 
52e8824
475ac6e
 
 
 
 
 
 
 
 
 
 
d42ee1a
 
475ac6e
cc314be
475ac6e
25ec020
cc314be
4afdb1d
475ac6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5de1a86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import pdb

import torch
import torchvision.transforms as transforms
from PIL import Image
import Nets
from huggingface_hub import hf_hub_download


def load_model(pretrained_dict, new):
    """Copy matching weights from a checkpoint into ``new`` in place.

    Only entries whose key also appears in ``new.state_dict()`` are copied;
    other checkpoint entries are ignored, and parameters/buffers of ``new``
    without a counterpart in the checkpoint keep their current values.

    Args:
        pretrained_dict: Either a checkpoint mapping that wraps the weights
            under a ``'state_dict'`` key (the format this script loads), or a
            bare state dict mapping parameter names to tensors.
        new: The ``torch.nn.Module`` whose weights are updated in place.

    Returns:
        ``new`` itself, for convenience (existing callers may ignore it).

    Raises:
        Propagates any error from ``new.load_state_dict`` (e.g. a tensor
        shape mismatch on a key present in both).
    """
    import warnings

    model_dict = new.state_dict()
    # Accept both wrapped checkpoints ({'state_dict': ...}) and bare state dicts.
    if "state_dict" in pretrained_dict:
        source = pretrained_dict["state_dict"]
    else:
        source = pretrained_dict
    # 1. keep only the keys the model actually has
    matched = {k: v for k, v in source.items() if k in model_dict}
    if not matched:
        # Loading would be a silent no-op. NOTE(review): a common cause is a
        # key-prefix mismatch (e.g. 'module.' from DataParallel) -- verify.
        warnings.warn(
            "load_model: no checkpoint keys match the model's state_dict; "
            "weights were left unchanged",
            stacklevel=2,
        )
    # 2. overlay the checkpoint values on the model's current ones, then load
    model_dict.update(matched)
    new.load_state_dict(model_dict)
    return new


# Inference runs on the CPU only.
device = torch.device("cpu")
print("use cpu")

# BasicBlock with [2, 2, 2, 2] blocks per stage (the ResNet-18 layout, matching
# the resnet18.pth checkpoint below) and a single output unit (num_classes=1).
net = Nets.ResNet(block=Nets.BasicBlock, layers=[2, 2, 2, 2], num_classes=1).to(device)

# Fetch the checkpoint from the Hugging Face Hub (served from the local cache
# after the first download) and return its local file path.
model_ckpt_path = hf_hub_download(repo_id="M4869/beauty_prediction_fpb5k", filename="resnet18.pth")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a repo you trust. encoding='latin1' allows reading
# pickles written under Python 2; map_location places tensors on `device`.
load_model(torch.load(model_ckpt_path, encoding='latin1', map_location=device), net)
# Put the network in evaluation mode (affects layers such as BatchNorm/Dropout).
net.eval()

# Per-channel mean/std used for normalisation (the standard ImageNet
# statistics used by torchvision's pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Evaluation preprocessing: resize the shorter side to 256 px (aspect ratio
# kept), take the central 224x224 patch, convert to a float tensor with values
# in [0, 1], then normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


def fun(img_path):
    """Score a single image with the loaded network.

    Args:
        img_path: Path (or file object) of the image -- anything
            ``PIL.Image.open`` accepts.

    Returns:
        The network's output multiplied by 20, formatted with one decimal
        place, as a string.
    """
    # Open inside ``with`` so the file handle is always released; PIL keeps
    # the file open after loading for multi-frame formats otherwise.
    with Image.open(img_path) as src:
        rgb = src.convert('RGB')
    # Preprocess and add a leading batch dimension of size 1.
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        # NOTE(review): with num_classes=1 the output is presumably (1, 1);
        # squeeze the class axis and take the single scalar.
        output = net(batch).squeeze(1).cpu().numpy()[0]
    # NOTE(review): the x20 factor presumably maps the raw model output onto
    # the displayed score scale -- confirm against the training setup.
    return "%.1f" % (output * 20)