File size: 6,054 Bytes
c509e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import argparse
import numpy as np
import torch.nn.functional as F
import glob
import cv2
from tqdm import tqdm
import time
import os
from model.deep_lab_model.deeplab import *
from MBD import mask_base_dewarper
import time
from utils import cvimg2torch,torch2cvimg
def net1_net2_infer(model, img_paths, args):
    """Run segmentation + mask-based dewarping over a list of image paths.

    For each ``*_origin*`` image, predicts a document mask with ``model``,
    dewarps the image via ``mask_base_dewarper``, and writes three siblings
    of the input file: the dewarped image (``_capture``), the binarized
    mask (``_mask_new``) and a 128x128 backward-mapping grid (``_grid1``).

    Args:
        model: segmentation network (DeepLab wrapper); put in eval mode here.
        img_paths: iterable of image file paths whose names contain '_origin'.
        args: parsed CLI namespace (not read inside the loop; kept for
            interface compatibility).
    """
    seg_model = model
    seg_model.eval()
    # Structuring element for the morphological cleanup below; hoisted out
    # of the loop since it never changes. Explicit uint8 matches OpenCV's
    # conventional kernel dtype.
    kernel = np.ones((3, 3), np.uint8)
    for img_path in tqdm(img_paths):
        # Skip images already processed by a previous run.
        if os.path.exists(img_path.replace('_origin', '_capture')):
            continue
        ### segmentation mask predict
        img_org = cv2.imread(img_path)
        h_org, w_org = img_org.shape[:2]
        img = cv2.resize(img_org, (448, 448))
        img = cv2.GaussianBlur(img, (15, 15), 0, 0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cvimg2torch(img)
        with torch.no_grad():
            pred = seg_model(img.cuda())
            mask_pred = pred[:, 0, :, :].unsqueeze(1)
            # Upsample the mask back to the original image resolution.
            mask_pred = F.interpolate(mask_pred, (h_org, w_org))
            mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
        mask_pred = (mask_pred * 255).astype(np.uint8)
        # Morphological close (dilate then erode) to fill small holes.
        mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
        mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
        # Binarize. '>=' also covers the value 100 itself, which the original
        # pair of strict comparisons ('>100' / '<100') left untouched.
        mask_pred[mask_pred >= 100] = 255
        mask_pred[mask_pred < 100] = 0
        ### tps transform based on the mask
        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception:
            # Dewarping failed (e.g. no usable document contour). Fall back
            # to an identity sampling grid so the outputs are still written.
            # NOTE: narrowed from a bare 'except:' so Ctrl-C still works.
            print('fail')
            grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
            # Map [0,1) coordinates to grid_sample's [-1,1] convention.
            grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
            grid = grid[0].numpy()
        cv2.imwrite(img_path.replace('_origin', '_capture'), dewarp)
        cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)
        # Store a compact 128x128 version of the sampling grid.
        grid0 = cv2.resize(grid[:, :, 0], (128, 128))
        grid1 = cv2.resize(grid[:, :, 1], (128, 128))
        grid = np.stack((grid0, grid1), axis=-1)
        np.save(img_path.replace('_origin', '_grid1'), grid)
def net1_net2_infer_single_im(img, model_path):
    """Predict a binary document mask for a single BGR image.

    Builds a DeepLab segmentation model, loads weights from ``model_path``,
    and returns a uint8 mask at the input image's resolution.

    NOTE: the model is constructed and the checkpoint loaded on every call;
    cache the model at the call site if this runs in a loop (see
    ``net1_net2_infer`` for the batched variant).

    Args:
        img: BGR image as an HxWx3 uint8 numpy array (cv2 convention).
        model_path: path to a checkpoint dict containing 'model_state'.

    Returns:
        np.ndarray: HxW uint8 mask with values in {0, 255}.
    """
    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()
    ### segmentation mask predict
    img_org = img
    h_org, w_org = img_org.shape[:2]
    img = cv2.resize(img_org, (448, 448))
    img = cv2.GaussianBlur(img, (15, 15), 0, 0)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cvimg2torch(img)
    with torch.no_grad():
        pred = seg_model(img.cuda())
        mask_pred = pred[:, 0, :, :].unsqueeze(1)
        # Upsample the mask back to the original image resolution.
        mask_pred = F.interpolate(mask_pred, (h_org, w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    mask_pred = (mask_pred * 255).astype(np.uint8)
    # Morphological close (dilate then erode) to fill small holes/noise.
    kernel = np.ones((3, 3), np.uint8)
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize. '>=' also covers the value 100 itself, which the original
    # pair of strict comparisons ('>100' / '<100') left untouched.
    mask_pred[mask_pred >= 100] = 255
    mask_pred[mask_pred < 100] = 0
    return mask_pred
if __name__ == '__main__':
    # CLI entry point: batch-dewarp every '*_origin.*' image found under
    # --img_folder using the checkpoint at --seg_model_path.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',
                        help='Data path to load data')
    parser.add_argument('--img_rows', nargs='?', type=int, default=448,
                        help='Height of the input image')
    parser.add_argument('--img_cols', nargs='?', type=int, default=448,
                        help='Width of the input image')
    parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
                        help='Path to previous saved model to restart from')
    cli_args = parser.parse_args()

    # Build the segmentation network and wrap it for (multi-)GPU execution.
    network = DeepLab(num_classes=1,
                      backbone='resnet',
                      output_stride=16,
                      sync_bn=None,
                      freeze_bn=False)
    network = torch.nn.DataParallel(network, device_ids=range(torch.cuda.device_count()))
    network.cuda()

    # Restore trained weights from the checkpoint.
    state = torch.load(cli_args.seg_model_path)
    network.load_state_dict(state['model_state'])

    # Gather inputs and run the inference loop.
    origin_paths = glob.glob(os.path.join(cli_args.img_folder, '*_origin.*'))
    net1_net2_infer(network, origin_paths, cli_args)
|