"""Beauty-score inference with a ResNet-18-style regressor.

At import time this module downloads ``resnet18.pth`` from the Hugging Face
Hub repo ``M4869/beauty_prediction_fpb5k``, builds the network on the CPU,
loads the checkpoint weights, and switches the network to eval mode.
:func:`fun` then scores a single image.
"""
import pdb  # noqa: F401  (currently unused in this module)
import warnings

import joblib  # noqa: F401  (currently unused in this module)
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

import Nets


def load_model(pretrained_dict, new):
    """Copy matching weights from a loaded checkpoint into ``new`` in place.

    Args:
        pretrained_dict: The object returned by ``torch.load``. Either a
            mapping that stores the weights under ``'state_dict'`` (the
            layout this module's checkpoint uses) or a bare state dict.
        new: The ``torch.nn.Module`` whose parameters/buffers are updated.

    Only checkpoint entries whose key already exists in ``new.state_dict()``
    are copied; other checkpoint entries are ignored. Model entries that the
    checkpoint does not provide keep their current values, and a
    ``UserWarning`` is emitted listing how many there were.

    Raises:
        Whatever ``new.load_state_dict`` raises, e.g. ``RuntimeError`` when a
        matching key has a tensor of a different shape.
    """
    model_dict = new.state_dict()
    # Accept both {'state_dict': {...}} and a bare state dict.
    source = pretrained_dict.get('state_dict', pretrained_dict)

    # 1. Filter out checkpoint entries the target model does not have.
    matched = {k: v for k, v in source.items() if k in model_dict}

    # Without this check a key mismatch (for example a prefix such as
    # 'module.' on every checkpoint key) would silently leave the network
    # at its initial, untrained weights.
    missing = [k for k in model_dict if k not in matched]
    if missing:
        warnings.warn(
            f"load_model: {len(missing)} of {len(model_dict)} model entries "
            f"were not found in the checkpoint and keep their current values "
            f"(e.g. {missing[:3]})",
            stacklevel=2,
        )

    # 2. Overwrite matching entries in the existing state dict, load it back.
    model_dict.update(matched)
    new.load_state_dict(model_dict)


# CUDA selection is intentionally disabled; inference always runs on the CPU.
device = torch.device("cpu")
print("use cpu")

# NOTE(review): these two names are not used below; the download call spells
# out the full repo id instead. Kept because they are module-level names.
REPO_ID = "M4869"
FILENAME = "beauty_prediction_fpb5k"

# Downloads (or reuses the local Hugging Face cache copy of) the checkpoint.
model_ckpt_path = hf_hub_download(
    repo_id="M4869/beauty_prediction_fpb5k", filename="resnet18.pth"
)

# BasicBlock with layers [2, 2, 2, 2] is the ResNet-18 layout; a single
# output unit makes it a scalar regressor.
net = Nets.ResNet(
    block=Nets.BasicBlock, layers=[2, 2, 2, 2], num_classes=1
).to(device)

# map_location remaps any tensors saved on a GPU onto `device`, so loading
# works on CPU-only hosts.
# encoding='latin1' is presumably needed because the checkpoint was pickled
# under Python 2 -- TODO confirm.
# SECURITY: torch.load unpickles the file; only load checkpoints from a
# trusted repository.
load_model(
    torch.load(model_ckpt_path, map_location=device, encoding='latin1'),
    net,
)
net.eval()

# Resize -> center-crop to 224x224 -> tensor in [0, 1] -> normalize with the
# commonly used ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def fun(img_path):
    """Return the network's score for a single image.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (a path or a binary
            file object), or an already-loaded ``PIL.Image.Image``.

    Returns:
        The first element of the network output after ``squeeze(1)``, as a
        NumPy scalar. Assumes the network returns shape (batch, 1), consistent
        with ``num_classes=1`` above -- confirm against ``Nets.ResNet``.
        NOTE(review): an earlier demo in this file printed ``output * 20``;
        the raw value's scale is not established here.
    """
    if isinstance(img_path, Image.Image):
        img = img_path.convert('RGB')
    else:
        # The context manager closes the underlying file promptly instead of
        # leaving the handle open until garbage collection; convert() loads
        # the pixels into a new image before the file is closed.
        with Image.open(img_path) as src:
            img = src.convert('RGB')

    img = transform(img)
    with torch.no_grad():
        # Add the batch dimension expected by the network.
        img = img.unsqueeze(0).to(device)
        output = net(img).squeeze(1).cpu().numpy()[0]
    return output