Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
import numpy | |
import os | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import torch | |
import skimage | |
from models.hr_net import hr_w32 | |
from tool_utils import heatmaps_to_coords,draw_joints | |
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def list_examples(directory="./examples"):
    """Return Gradio example rows for the files in *directory*.

    Each row is a one-element list holding a file path, matching the
    ``gr.Interface(examples=...)`` format for a single input component.

    - Entries are sorted by name, so the example order is stable across platforms
      (``os.listdir`` order is arbitrary).
    - Hidden files (e.g. ``.gitkeep``) and sub-directories are skipped because they
      cannot be loaded as images.
    - A missing directory yields an empty list instead of crashing the app at import.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(names)
        if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
    ]


# Create the example list from the 'examples/' directory.
example_list = list_examples()
import functools

# Checkpoint file with the trained HRNet-W32 weights (stored under the 'model' key).
WEIGHTS_PATH = './weights/HRNet_epoch20_loss0.000474.pth'
# (height, width) the input image is resized to before it is fed to the model.
MODEL_INPUT_SIZE = (256, 256)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build ``hr_w32`` on ``device``, load the checkpoint, and return it in eval mode.

    Cached so the network is constructed and the weights are read from disk once per
    process, rather than on every request handled by ``predict``.
    """
    model = hr_w32().to(device)
    # Map the checkpoint to the CPU first so it also loads on CPU-only hosts;
    # load_state_dict then copies the tensors into the model's (device) parameters.
    checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


def predict(numpy_img):
    """Detect joints in an image and draw them on it.

    Args:
        numpy_img: image array from the Gradio ``gr.Image`` input. It is indexed as
            (height, width, channels): the permute below assumes channels-last.
            ``None`` when the user submitted no image.

    Returns:
        A tuple ``(img_rgb, inference_time_text)``:
        - ``img_rgb``: the output of ``draw_joints`` on the original-resolution image.
        - ``inference_time_text``: a string with the time spent on the forward pass
          plus heatmap decoding.
        Returns ``(None, "no input image")`` when *numpy_img* is ``None``.
    """
    if numpy_img is None:
        return None, "no input image"
    # Import the submodule explicitly: some scikit-image releases do not expose
    # ``skimage.transform`` after a bare ``import skimage``.
    from skimage.transform import resize

    # With the default preserve_range=False, resize returns floats (integer input is
    # rescaled to [0, 1]).
    img_np = resize(numpy_img, MODEL_INPUT_SIZE)
    # HWC -> NCHW float32 batch of one, on the model's device.
    img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)
    model = _load_model()

    start_time = time.time()
    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        heatmaps_pred = model(img).double()
    heatmaps_pred_np = heatmaps_pred.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Decode the per-joint heatmaps into coordinates at the original image resolution.
    coord_joints = heatmaps_to_coords(
        heatmaps_pred_np,
        resolu_out=[numpy_img.shape[0], numpy_img.shape[1]],
        prob_threshold=0.1,
    )
    inference_time = time.time() - start_time
    inference_time_text = "model inference time:{:.4f}s".format(inference_time)

    img_rgb = draw_joints(numpy_img, coord_joints)
    return img_rgb, inference_time_text
# Gradio UI: one image in; the annotated image and the timing string out.
# Components are created in the same order as before (input first, then outputs).
input_image = gr.Image()
output_components = [gr.Image(type='numpy'), "text"]
demo = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=output_components,
    examples=example_list,
)

if __name__ == "__main__":
    # show_api=False: do not show the "Use via API" link in the launched app.
    demo.launch(show_api=False)