# pose_experiment / app.py
# Author: yijiu
# Commit 383ca9f: "feat:recover image resolution"
import functools
import os
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy
import skimage
import torch
from PIL import Image

from models.hr_net import hr_w32
from tool_utils import heatmaps_to_coords,draw_joints
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#Create example list from 'examples/'directory
example_list=[["./examples/"+example] for example in os.listdir("examples")]
# Checkpoint file for the HRNet-W32 pose model (path relative to the working directory).
WEIGHTS_PATH = './weights/HRNet_epoch20_loss0.000474.pth'
# Square side length the input image is resized to before inference.
MODEL_INPUT_SIZE = 256


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build hr_w32, load the checkpoint's 'model' state dict, and return it in eval mode.

    Cached so the checkpoint is read from disk and the network is built only
    once, on the first request, instead of on every call to ``predict``.
    """
    model = hr_w32().to(device)
    # Load tensors onto the CPU first; load_state_dict copies them into the
    # model's parameters, which already live on ``device``.
    checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


def predict(numpy_img):
    """Predict body-joint locations in an image and draw them on it.

    Args:
        numpy_img: Image array delivered by the ``gr.Image`` input; presumably
            an H x W x 3 RGB array (TODO confirm), since it is permuted to
            channels-first before inference. ``None`` when no image was given.

    Returns:
        A tuple ``(img_rgb, inference_time_text)``: the image returned by
        ``draw_joints`` with the predicted joints drawn on it, and a string
        reporting the elapsed time of the forward pass plus heatmap decoding.

    Raises:
        gr.Error: If no image was provided.
    """
    if numpy_img is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    # Resize to the network's square input size (skimage returns a float array).
    img_np = skimage.transform.resize(numpy_img, [MODEL_INPUT_SIZE, MODEL_INPUT_SIZE])
    # HWC -> NCHW float32 tensor on the model's device.
    img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

    model = _load_model()

    start_time = time.perf_counter()
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        heatmaps_pred = model(img)
    # Drop the batch axis and move axis 0 (presumably one heatmap per joint)
    # to the end, giving a float64 numpy array.
    heatmaps_pred_np = heatmaps_pred.double().squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Decode heatmaps into joint coordinates; resolu_out is the original
    # image's (height, width), presumably so coordinates map back to the input
    # resolution. 0.1 is the probability cutoff passed to heatmaps_to_coords.
    coord_joints = heatmaps_to_coords(
        heatmaps_pred_np,
        resolu_out=[numpy_img.shape[0], numpy_img.shape[1]],
        prob_threshold=0.1,
    )
    inference_time = time.perf_counter() - start_time
    inference_time_text = "model inference time:{:.4f}s".format(inference_time)

    img_rgb = draw_joints(numpy_img, coord_joints)
    return img_rgb, inference_time_text
# Wire predict() into a Gradio UI: one image in, annotated image plus timing text out.
_input_component = gr.Image()
_output_components = [gr.Image(type="numpy"), "text"]
demo = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    examples=example_list,
)

if __name__ == "__main__":
    # Start the app with the API docs link hidden (show_api=False).
    demo.launch(show_api=False)