hcs
Add application file
25ec020
raw
history blame
2.3 kB
import pdb
import torch
import torchvision.transforms as transforms
from PIL import Image
import Nets
def load_model(pretrained_dict, new):
    """Copy matching weights from a loaded checkpoint into ``new`` in place.

    Only checkpoint entries whose key also exists in ``new.state_dict()`` are
    copied; every other parameter/buffer of ``new`` keeps its current value.

    Args:
        pretrained_dict: Object returned by ``torch.load``. Either a mapping
            that wraps the weights under a ``'state_dict'`` key (as the
            original code expected), or a plain state dict.
        new: ``torch.nn.Module`` whose state is overwritten.

    Returns:
        List of keys of ``new``'s state dict that had no counterpart in the
        checkpoint (empty when every entry was loaded). A warning is also
        emitted in that case.

    Raises:
        RuntimeError: From ``load_state_dict`` when a matching key holds a
            tensor of a different shape.
    """
    import warnings

    model_dict = new.state_dict()
    # Checkpoints are often saved as {'state_dict': ..., <other metadata>};
    # otherwise treat the loaded object itself as the state dict.
    if 'state_dict' in pretrained_dict:
        source = pretrained_dict['state_dict']
    else:
        source = pretrained_dict
    # 1. filter out keys the model does not have
    matched = {k: v for k, v in source.items() if k in model_dict}
    # Report model entries that will keep their current (e.g. random-init)
    # values, instead of silently running inference with them.
    missing = [k for k in model_dict if k not in matched]
    if missing:
        ignored = [k for k in source if k not in model_dict]
        warnings.warn(
            f"load_model: {len(missing)} model key(s) not found in checkpoint "
            f"(e.g. {missing[:5]}); {len(ignored)} checkpoint key(s) ignored "
            f"(e.g. {ignored[:5]})",
            stacklevel=2,
        )
    # 2. overwrite the matched entries in the model's own dict and load it
    model_dict.update(matched)
    new.load_state_dict(model_dict)
    return missing
# Inference is pinned to the CPU. (CUDA auto-detection is disabled; to restore
# it use torch.device("cuda" if torch.cuda.is_available() else "cpu").)
device = torch.device("cpu")
print("use cpu")

from huggingface_hub import hf_hub_download
import joblib  # NOTE(review): not used in the visible code

# NOTE(review): REPO_ID and FILENAME are not used below -- the download call
# passes the full repo id and checkpoint file name explicitly. Left unchanged
# in case other modules import them.
REPO_ID = "M4869"
FILENAME = "beauty_prediction_fpb5k"

# Fetch the checkpoint from the Hugging Face Hub; hf_hub_download returns the
# local path of the downloaded (or already cached) file.
model_ckpt_path = hf_hub_download(repo_id="M4869/beauty_prediction_fpb5k", filename="resnet18.pth")

# ResNet with [2, 2, 2, 2] BasicBlocks (the ResNet-18 layout) and a single
# output unit, i.e. one scalar score per image.
net = Nets.ResNet(block=Nets.BasicBlock, layers=[2, 2, 2, 2], num_classes=1).to(device)
# map_location=device remaps tensors that were saved on a GPU onto the CPU;
# without it torch.load fails for such checkpoints on hosts with no CUDA device.
# encoding='latin1' is the torch.load option for unpickling Python 2 checkpoints.
load_model(torch.load(model_ckpt_path, map_location=device, encoding='latin1'), net)
# Switch to evaluation mode (affects layers such as dropout / batch-norm).
net.eval()

# Preprocessing: resize the shorter side to 256, center-crop to 224x224,
# convert to a float tensor and normalize with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def fun(img_path):
    """Return the model's score for one image.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through ``net`` on ``device`` as a batch of one.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (a file path or a
            binary file object).

    Returns:
        The network's single output for the image, as a NumPy scalar.
    """
    # The ``with`` block releases the file handle Image.open acquires;
    # convert() loads the pixel data first, so ``img`` stays valid afterwards.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    batch = transform(img).unsqueeze(0).to(device)  # add a leading batch dim of 1
    with torch.no_grad():
        output = net(batch).squeeze(1).cpu().numpy()[0]
    return output
# def main():
# for i in range(6, 7):
# img = Image.open("./data2/%d.jpg" % i).convert('RGB')
# img = transform(img)
#
# with torch.no_grad():
# img = img.unsqueeze(0).to(device)
# output = net(img).squeeze(1).cpu().numpy()[0]
# print(i, output * 20)
# if __name__ == '__main__':
# main()