|
import torch |
|
import argparse |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import glob |
|
import cv2 |
|
from tqdm import tqdm |
|
|
|
import time |
|
import os |
|
from model.deep_lab_model.deeplab import * |
|
from MBD import mask_base_dewarper |
|
import time |
|
|
|
from utils import cvimg2torch,torch2cvimg |
|
|
|
|
|
|
|
def net1_net2_infer(model, img_paths, args):
    """Run document segmentation + mask-based dewarping over a list of images.

    For each ``*_origin.*`` image path, predicts a document mask with
    ``model``, dewarps the image with :func:`mask_base_dewarper`, and writes
    three sibling files next to the input (by substring replacement on the
    path): the dewarped image (``*_capture``), the binary mask
    (``*_mask_new``) and a 128x128 backward-mapping grid (``*_grid1``,
    saved via ``np.save``).

    Args:
        model: segmentation network (DeepLab head with the mask in channel 0),
            already moved to GPU.
        img_paths: iterable of image paths containing the substring
            ``'_origin'``.
        args: parsed CLI namespace (unused here; kept for interface parity).
    """
    seg_model = model
    seg_model.eval()
    for img_path in tqdm(img_paths):
        # Resumable runs: skip images whose dewarped output already exists.
        if os.path.exists(img_path.replace('_origin', '_capture')):
            continue

        img_org = cv2.imread(img_path)
        if img_org is None:
            # Unreadable/corrupt image: skip it instead of crashing the batch.
            print('skip unreadable image: %s' % img_path)
            continue
        h_org, w_org = img_org.shape[:2]
        # Network expects a fixed 448x448 RGB input; the blur suppresses
        # high-frequency texture before segmentation.
        img = cv2.resize(img_org, (448, 448))
        img = cv2.GaussianBlur(img, (15, 15), 0, 0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cvimg2torch(img)

        with torch.no_grad():
            pred = seg_model(img.cuda())
            mask_pred = pred[:, 0, :, :].unsqueeze(1)
            # Upsample the mask back to the original resolution.
            mask_pred = F.interpolate(mask_pred, (h_org, w_org))
            mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
            mask_pred = (mask_pred * 255).astype(np.uint8)

        # Morphological close (dilate then erode) to fill small holes,
        # followed by a hard binarization.  np.where also binarizes pixels
        # exactly equal to 100, which the previous pair of masked
        # assignments (>100 -> 255, <100 -> 0) left untouched.
        kernel = np.ones((3, 3))
        mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
        mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
        mask_pred = np.where(mask_pred > 100, 255, 0).astype(np.uint8)

        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception:
            # Dewarping failed (e.g. no usable contour): fall back to the
            # identity sampling grid so the downstream files still exist.
            # Narrowed from a bare `except:` so Ctrl-C still interrupts.
            print('fail')
            grid = np.stack(np.meshgrid(np.arange(w_org), np.arange(h_org)))
            grid = grid / np.array([w_org, h_org]).reshape(2, 1, 1)
            # grid_sample expects coordinates in [-1, 1], layout (N, H, W, 2).
            grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
            grid = grid[0].numpy()

        cv2.imwrite(img_path.replace('_origin', '_capture'), dewarp)
        cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)

        # Store a compact 128x128 version of the backward mapping grid.
        grid0 = cv2.resize(grid[:, :, 0], (128, 128))
        grid1 = cv2.resize(grid[:, :, 1], (128, 128))
        grid = np.stack((grid0, grid1), axis=-1)
        np.save(img_path.replace('_origin', '_grid1'), grid)
|
|
|
|
|
def net1_net2_infer_single_im(img, model_path):
    """Predict a binary document mask for a single BGR image.

    Builds a DeepLab segmentation model, restores its weights from
    ``model_path`` and returns a mask with the same height/width as ``img``.

    Args:
        img: BGR image array of shape (H, W, 3), as returned by
            ``cv2.imread``.
        model_path: checkpoint path; its ``'model_state'`` entry must hold
            the ``DataParallel`` state dict.

    Returns:
        np.ndarray: (H, W) uint8 mask with values in {0, 255}.
    """
    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()

    img_org = img
    h_org, w_org = img_org.shape[:2]
    # Network expects a fixed 448x448 RGB input; the blur suppresses
    # high-frequency texture before segmentation.
    img = cv2.resize(img_org, (448, 448))
    img = cv2.GaussianBlur(img, (15, 15), 0, 0)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cvimg2torch(img)

    with torch.no_grad():
        pred = seg_model(img.cuda())
        mask_pred = pred[:, 0, :, :].unsqueeze(1)
        # Upsample the mask back to the original resolution.
        mask_pred = F.interpolate(mask_pred, (h_org, w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
        mask_pred = (mask_pred * 255).astype(np.uint8)

    # Morphological close (dilate then erode) to fill small holes, then a
    # hard binarization.  np.where also binarizes pixels exactly equal to
    # 100, which the previous pair of masked assignments (>100 -> 255,
    # <100 -> 0) left untouched.
    kernel = np.ones((3, 3))
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    mask_pred = np.where(mask_pred > 100, 255, 0).astype(np.uint8)

    return mask_pred
|
|
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: parse options, restore the segmentation model from a
    # checkpoint, then batch-dewarp every *_origin.* image in the data folder.
    cli = argparse.ArgumentParser(description='Hyperparams')
    cli.add_argument('--img_folder', nargs='?', type=str, default='./all_data',
                     help='Data path to load data')
    cli.add_argument('--img_rows', nargs='?', type=int, default=448,
                     help='Height of the input image')
    cli.add_argument('--img_cols', nargs='?', type=int, default=448,
                     help='Width of the input image')
    cli.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
                     help='Path to previous saved model to restart from')
    options = cli.parse_args()

    # Single-channel DeepLab head: channel 0 is the document mask.
    net = DeepLab(num_classes=1,
                  backbone='resnet',
                  output_stride=16,
                  sync_bn=None,
                  freeze_bn=False)
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    net.cuda()
    state = torch.load(options.seg_model_path)
    net.load_state_dict(state['model_state'])

    target_paths = glob.glob(os.path.join(options.img_folder, '*_origin.*'))

    net1_net2_infer(net, target_paths, options)
|
|
|
|