|
from trainer import Trainer |
|
from importlib import import_module |
|
import math |
|
import torch |
|
from torch import optim |
|
from torch.optim import lr_scheduler |
|
import numpy as np |
|
import os |
|
from shutil import copyfile |
|
import glob |
|
from models.utils.flow_losses import smoothness_loss, second_order_loss |
|
from models.utils.fbConsistencyCheck import image_warp |
|
from models.utils.fbConsistencyCheck import ternary_loss2 |
|
import torch.nn.functional as F |
|
import cv2 |
|
import cvbase |
|
from data.util.flow_utils import region_fill as rf |
|
import imageio |
|
import torch.nn as nn |
|
from skimage.feature import canny |
|
from skimage.metrics import peak_signal_noise_ratio as psnr |
|
from skimage.metrics import structural_similarity as ssim |
|
from models.utils.bce_edge_loss import edgeLoss, EdgeAcc |
|
|
|
|
|
class Network(Trainer): |
|
def init_model(self):
    """Build the flow-completion model, its Adam optimizer and StepLR scheduler.

    The model class is loaded dynamically from ``models.<opt['model']>``.

    Returns:
        tuple: ``(model, optimizer, scheduler)`` consumed by the Trainer base class.
    """
    self.edgeMeasure = EdgeAcc()
    # Dynamically import the configured model package, e.g. models.<name>.Model.
    model_package = import_module('models.{}'.format(self.opt['model']))
    model = model_package.Model(self.opt)
    # Convert once; the original applied float() to BETA2 twice redundantly.
    beta1 = float(self.opt['train']['BETA1'])
    beta2 = float(self.opt['train']['BETA2'])
    optimizer = optim.Adam(model.parameters(), lr=float(self.opt['train']['lr']),
                           betas=(beta1, beta2))
    if self.rank <= 0:  # only the master process logs
        self.logger.info('Optimizer is Adam, BETA1: {}, BETA2: {}'.format(beta1, beta2))
    # UPDATE_INTERVAL is given in iterations; convert to whole epochs.
    step_size = int(math.ceil(self.opt['train']['UPDATE_INTERVAL'] / self.trainSize))
    if self.rank <= 0:
        self.logger.info('Step size for optimizer is {} epoch'.format(step_size))
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=self.opt['train']['lr_decay'])
    return model, optimizer, scheduler
|
|
|
def resume_training(self):
    """Load generator and optimizer checkpoints and resume (or finetune).

    When ``opt['finetune']`` is falsy, training resumes at the stored
    epoch/iteration and the optimizer/scheduler states are restored; otherwise
    the counters restart at zero and only the model weights are loaded.

    Returns:
        tuple: ``(start_epoch, current_step)``.
    """
    # Map checkpoint tensors straight onto the configured CUDA device.
    gen_state = torch.load(self.opt['path']['gen_state'],
                           map_location=lambda storage, loc: storage.cuda(self.opt['device']))
    opt_state = torch.load(self.opt['path']['opt_state'],
                           map_location=lambda storage, loc: storage.cuda(self.opt['device']))
    if self.rank <= 0:
        self.logger.info('Resume state is activated')
        self.logger.info('Resume training from epoch: {}, iter: {}'.format(
            opt_state['epoch'], opt_state['iteration']
        ))
    # Idiomatic truthiness test instead of `== False` (assumes the config
    # stores a proper bool here, as the rest of the file does).
    if not self.opt['finetune']:
        start_epoch = opt_state['epoch']
        current_step = opt_state['iteration']
        self.optimizer.load_state_dict(opt_state['optimizer_state_dict'])
        self.scheduler.load_state_dict(opt_state['scheduler_state_dict'])
    else:
        # Finetuning: keep the weights but restart the schedule from scratch.
        start_epoch = 0
        current_step = 0
    self.model.load_state_dict(gen_state['model_state_dict'])
    if self.rank <= 0:
        self.logger.info('Resume training mode, optimizer, scheduler and model have been uploaded')
    return start_epoch, current_step
|
|
|
def _trainEpoch(self, epoch):
    """Run one training epoch over self.trainLoader.

    For each batch: optional LR warm-up, forward pass of the flow-completion
    model (which also predicts an edge map), loss computation (masked/valid
    L1, first/second-order smoothness, ternary photometric loss, edge loss),
    backward pass with optional gradient clipping, and periodic
    TensorBoard/console logging.
    """
    for idx, train_data in enumerate(self.trainLoader):
        self.currentStep += 1

        # Stop once the global iteration budget is exhausted.
        if self.currentStep > self.totalIterations:
            if self.rank <= 0:
                self.logger.info('Train process has been finished')
            break
        # Linear LR warm-up: WARMUP is a global iteration count, divided by
        # world_size to get the per-process step threshold.
        if self.opt['train']['WARMUP'] is not None and self.currentStep <= self.opt['train']['WARMUP'] // self.opt[
                'world_size']:
            target_lr = self.opt['train']['lr'] * self.currentStep / (
                self.opt['train']['WARMUP'])
            self.adjust_learning_rate(self.optimizer, target_lr)

        # Move all batch tensors to the training device.
        flows = train_data['flows']
        diffused_flows = train_data['diffused_flows']
        target_edge = train_data['edges']
        current_frame = train_data['current_frame']
        current_frame = current_frame.to(self.opt['device'])
        shift_frame = train_data['shift_frame']
        shift_frame = shift_frame.to(self.opt['device'])
        masks = train_data['masks']
        flows = flows.to(self.opt['device'])
        masks = masks.to(self.opt['device'])
        diffused_flows = diffused_flows.to(self.opt['device'])
        target_edge = target_edge.to(self.opt['device'])

        # 5-D batches are temporal stacks (b, c, t, h, w): supervise only the
        # central frame. 4-D batches are already single-frame.
        if len(masks.shape) == 5:
            b, c, t, h, w = masks.shape
            target_flow = flows[:, :, t // 2]
            target_mask = masks[:, :, t // 2]
        else:
            assert len(masks.shape) == 4 and len(flows.shape) == 4
            target_flow = flows
            target_mask = masks

        filled_flow = self.model(diffused_flows, masks)

        # The model returns a (flow, edge) pair.
        filled_flow, filled_edge = filled_flow

        # Composite prediction: keep ground truth outside the hole,
        # the network prediction inside it (mask == 1 marks the hole).
        combined_flow = target_flow * (1 - target_mask) + filled_flow * target_mask
        combined_edge = target_edge * (1 - target_mask) + filled_edge * target_mask
        # The composited edge map is weighted 5x relative to the raw prediction.
        edge_loss = (edgeLoss(filled_edge, target_edge) + 5 * edgeLoss(combined_edge, target_edge))

        # L1 terms are normalised by the mean mask area so their scale does
        # not depend on the hole size.
        L1Loss_masked = self.maskedLoss(combined_flow * target_mask,
                                        target_flow * target_mask) / torch.mean(target_mask)
        L1Loss_valid = self.validLoss(filled_flow * (1 - target_mask),
                                      target_flow * (1 - target_mask)) / torch.mean(1 - target_mask)

        smoothLoss = smoothness_loss(combined_flow, target_mask)
        smoothLoss2 = second_order_loss(combined_flow, target_mask)
        ternary_loss = self.ternary_loss(combined_flow, target_flow, target_mask, current_frame, shift_frame,
                                         scale_factor=1)

        # Weight each term by its config coefficient.
        m_losses = (L1Loss_masked + L1Loss_valid) * self.opt['L1M']
        sm1_loss = smoothLoss * self.opt['sm']
        sm2_loss = smoothLoss2 * self.opt['sm2']
        t_loss = self.opt['ternary'] * ternary_loss
        e_loss = edge_loss * self.opt['edge_loss']

        loss = m_losses + sm1_loss + sm2_loss + t_loss + e_loss

        self.optimizer.zero_grad()
        loss.backward()
        if self.opt['gc']:  # optional gradient clipping
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10,
                                     norm_type=2)
        self.optimizer.step()

        # Lightweight scalar logging every 8 steps (master process only).
        if self.opt['use_tb_logger'] and self.rank <= 0 and self.currentStep % 8 == 0:
            print('Mask: {:.03f}, sm: {:.03f}, sm2: {:.03f}, ternary: {:.03f}, edge: {:03f}'.format(
                m_losses.item(),
                sm1_loss.item(),
                sm2_loss.item(),
                t_loss.item(),
                e_loss.item()
            ))
            self.tb_logger.add_scalar('{}/recon'.format('train'), m_losses.item(),
                                      self.currentStep)
            self.tb_logger.add_scalar('{}/sm'.format('train'), sm1_loss.item(), self.currentStep)
            self.tb_logger.add_scalar('{}/sm2'.format('train'), sm2_loss.item(),
                                      self.currentStep)
            self.tb_logger.add_scalar('{}/ternary'.format('train'),
                                      t_loss.item(),
                                      self.currentStep)
            self.tb_logger.add_scalar('{}/edge'.format('train'), e_loss.item(),
                                      self.currentStep)

        # Full metric computation (PSNR/SSIM/L1/L2 on the composited flow plus
        # edge precision/recall) at the configured print frequency.
        if self.currentStep % self.opt['logger']['PRINT_FREQ'] == 0 and self.rank <= 0:
            compLog = np.array(combined_flow.detach().permute(0, 2, 3, 1).cpu())
            flowsLog = np.array(target_flow.detach().permute(0, 2, 3, 1).cpu())
            logs = self.calculate_metrics(compLog, flowsLog)
            prec, recall = self.edgeMeasure(filled_edge.detach(), target_edge.detach())
            logs['prec'] = prec
            logs['recall'] = recall
            self._printLog(logs, epoch, loss)
|
|
|
def ternary_loss(self, comp, flow, mask, current_frame, shift_frame, scale_factor):
    """Census/ternary photometric loss between the current frame and the
    neighbouring frame warped by the completed flow.

    An exponential non-occlusion weight is derived from the warping error of
    the ground-truth flow, so occluded pixels contribute less to the loss.
    """
    if scale_factor != 1:
        # Bring both frames down to the working resolution of the flow.
        ratio = 1 / scale_factor
        current_frame = F.interpolate(current_frame, scale_factor=ratio, mode='bilinear')
        shift_frame = F.interpolate(shift_frame, scale_factor=ratio, mode='bilinear')
    # Warp with the GT flow to estimate where warping is trustworthy.
    gt_warped = image_warp(shift_frame, flow)
    photometric_err = torch.sum(torch.abs(current_frame - gt_warped), dim=1)
    noc_mask = torch.exp(-50. * photometric_err.pow(2)).unsqueeze(1)
    # Warp with the completed flow and score it against the current frame.
    comp_warped = image_warp(shift_frame, comp)
    return ternary_loss2(current_frame, comp_warped, noc_mask, mask)
|
|
|
def calculate_metrics(self, results, gts):
    """Compute batch-mean L1, L2, PSNR and SSIM between two flow batches.

    Args:
        results: numpy array of shape (B, H, W, C) — predicted flows.
        gts: numpy array of the same shape — ground-truth flows.

    Returns:
        dict with keys 'l1', 'l2', 'psnr', 'ssim' (means over the batch).
    """
    B, H, W, C = results.shape
    scores = {'psnr': [], 'ssim': [], 'l1': [], 'l2': []}
    for result, gt in zip(results, gts):
        # PSNR/SSIM are computed on the RGB visualisation of the flow.
        result_rgb = cvbase.flow2rgb(result)
        gt_rgb = cvbase.flow2rgb(gt)
        scores['psnr'].append(psnr(result_rgb, gt_rgb))
        scores['ssim'].append(ssim(result_rgb, gt_rgb, multichannel=True))
        # L1/L2 are computed on the raw flow values.
        residual = result - gt
        scores['l1'].append(np.mean(np.abs(residual)))
        scores['l2'].append(np.sum(residual ** 2) ** 0.5 / (H * W * C))
    return {key: np.mean(values) for key, values in scores.items()}
|
|
|
def _printLog(self, logs, epoch, loss):
    """Accumulate running metric averages and emit a console/TensorBoard line.

    The running totals are reset every ``opt['record_iter']`` calls (countDown
    cycles back to 0), so the reported values are means over the current
    window. Also triggers a periodic checkpoint save.
    """
    # Reset the accumulation window.
    # NOTE(review): assumes self.countDown was initialised elsewhere
    # (presumably by the Trainer base class) — confirm.
    if self.countDown % self.opt['record_iter'] == 0:
        self.total_psnr = 0
        self.total_ssim = 0
        self.total_l1 = 0
        self.total_l2 = 0
        self.total_loss = 0
        self.total_prec = 0
        self.total_recall = 0
        self.countDown = 0
    self.countDown += 1
    message = '[epoch:{:3d}, iter:{:7d}, lr:('.format(epoch, self.currentStep)
    # One learning rate per parameter group.
    for v in self.get_lr():
        message += '{:.3e}, '.format(v)
    message += ')] '
    # Fold the latest metrics into the window totals.
    self.total_psnr += logs['psnr']
    self.total_ssim += logs['ssim']
    self.total_l1 += logs['l1']
    self.total_l2 += logs['l2']
    self.total_prec += logs['prec'].item()
    self.total_recall += logs['recall'].item()
    self.total_loss += loss.item()
    mean_psnr = self.total_psnr / self.countDown
    mean_ssim = self.total_ssim / self.countDown
    mean_l1 = self.total_l1 / self.countDown
    mean_l2 = self.total_l2 / self.countDown
    mean_prec = self.total_prec / self.countDown
    mean_recall = self.total_recall / self.countDown
    mean_loss = self.total_loss / self.countDown

    message += '{:s}: {:.4e} '.format('mean_loss', mean_loss)
    message += '{:s}: {:} '.format('mean_psnr', mean_psnr)
    message += '{:s}: {:} '.format('mean_ssim', mean_ssim)
    message += '{:s}: {:} '.format('mean_l1', mean_l1)
    message += '{:s}: {:} '.format('mean_l2', mean_l2)
    message += '{:s}: {:} '.format('mean_prec', mean_prec)
    message += '{:s}: {:} '.format('mean_recall', mean_recall)

    if self.opt['use_tb_logger']:
        self.tb_logger.add_scalar('train/mean_psnr', mean_psnr, self.currentStep)
        self.tb_logger.add_scalar('train/mean_ssim', mean_ssim, self.currentStep)
        self.tb_logger.add_scalar('train/mean_l1', mean_l1, self.currentStep)
        self.tb_logger.add_scalar('train/mean_l2', mean_l2, self.currentStep)
        self.tb_logger.add_scalar('train/mean_loss', mean_loss, self.currentStep)
        self.tb_logger.add_scalar('train/mean_prec', mean_prec, self.currentStep)
        self.tb_logger.add_scalar('train/mean_recall', mean_recall, self.currentStep)
    self.logger.info(message)

    # Periodic checkpoint; the latest (non-averaged) L1 is passed along.
    if self.currentStep % self.opt['logger']['SAVE_CHECKPOINT_FREQ'] == 0:
        self.save_checkpoint(epoch, 'l1', logs['l1'])
|
|
|
def save_checkpoint(self, epoch, metric, number):
    """Persist model weights and optimizer/scheduler state to disk.

    Two files are written under ``opt['path']['TRAINING_STATE']``:
    ``gen_<epoch>_<step>.pth.tar`` (model weights only) and
    ``opt_<epoch>_<step>.pth.tar`` (epoch/iteration counters plus optimizer
    and scheduler state).

    Note: ``metric`` and ``number`` are accepted for interface compatibility
    but are not used in the saved state or the file names.
    """
    # Unwrap DataParallel/DDP so the checkpoint loads into a bare model.
    wrapped = isinstance(self.model, (torch.nn.DataParallel,
                                      torch.nn.parallel.DistributedDataParallel))
    module = self.model.module if wrapped else self.model
    gen_state = {'model_state_dict': module.state_dict()}

    opt_state = {
        'epoch': epoch,
        'iteration': self.currentStep,
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
    }

    state_dir = self.opt['path']['TRAINING_STATE']
    suffix = '{}_{}.pth.tar'.format(epoch, self.currentStep)
    torch.save(gen_state, os.path.join(state_dir, 'gen_' + suffix))
    torch.save(opt_state, os.path.join(state_dir, 'opt_' + suffix))
|
|
|
def _validate(self, epoch):
    """Validate on up to 10 videos: read flows/masks, diffuse-fill, run the
    model, and report L1/L2/PSNR/SSIM plus edge precision/recall.

    Both flow directions ('forward_flo'/'backward_flo') of each video are
    evaluated; per-video scores are keyed by video name (the backward pass
    overwrites the forward entry for the same video).
    """
    data_path = self.valInfo['data_root']
    mask_path = self.valInfo['mask_root']
    self.model.eval()
    test_list = os.listdir(data_path)
    test_list = test_list[:10]  # cap validation at 10 videos
    width, height = self.valInfo['flow_width'], self.valInfo['flow_height']
    flow_interval = self.opt['flow_interval']
    # NOTE: these local dicts shadow the skimage `psnr`/`ssim` functions
    # imported at module level (only within this method).
    psnr, ssim, l1, l2, prec, recall = {}, {}, {}, {}, {}, {}
    pivot, sequenceLen = 20, self.opt['num_flows']
    for i in range(len(test_list)):
        videoName = test_list[i]
        if self.rank <= 0:
            self.logger.info(f'Video {videoName} is being processed')
        for direction in ['forward_flo', 'backward_flo']:
            flow_dir = os.path.join(data_path, videoName, direction)
            mask_dir = os.path.join(mask_path, videoName)
            flows = self.read_flows(flow_dir, width, height, pivot, sequenceLen, flow_interval)
            masks = self.read_masks(mask_dir, width, height, pivot, sequenceLen, flow_interval)
            # Readers return [] when there is not enough data for this video.
            if flows == [] or masks == []:
                if self.rank <= 0:
                    print('Video {} doesn\'t have enough {} flows'.format(videoName, direction))
                continue
            if self.rank <= 0:
                self.logger.info('Flows have been read')
            # Initial fill of the holes by diffusion; the network refines it.
            diffused_flows = self.diffusion_filling(flows, masks)
            flows = np.stack(flows, axis=0)
            masks = np.stack(masks, axis=0)
            diffused_flows = np.stack(diffused_flows, axis=0)
            # Ground truth is the central frame of the temporal window.
            target_flow = flows[self.opt['num_flows'] // 2]
            target_edge = self.load_edge(target_flow)
            target_edge = target_edge[:, :, np.newaxis]
            # (t, h, w, c) -> (1, c, t, h, w) float tensors.
            diffused_flows = torch.from_numpy(np.transpose(diffused_flows, (3, 0, 1, 2))).unsqueeze(
                0).float()
            masks = torch.from_numpy(np.transpose(masks, (3, 0, 1, 2))).unsqueeze(0).float()
            target_flow = torch.from_numpy(np.transpose(target_flow, (2, 0, 1))).unsqueeze(
                0).float()
            target_edge = torch.from_numpy(np.transpose(target_edge, (2, 0, 1))).unsqueeze(0).float()
            diffused_flows = diffused_flows.to(self.opt['device'])
            masks = masks.to(self.opt['device'])
            target_flow = target_flow.to(self.opt['device'])
            target_edge = target_edge.to(self.opt['device'])
            target_mask = masks[:, :, sequenceLen // 2]
            # Single-frame windows collapse to 4-D inputs, matching training.
            if diffused_flows.shape[2] == 1 and len(diffused_flows.shape) == 5:
                assert masks.shape[2] == 1
                diffused_flows = diffused_flows.squeeze(2)
                masks = masks.squeeze(2)
            with torch.no_grad():
                # NOTE(review): three arguments here vs. two in _trainEpoch —
                # presumably the third is an optional prior input; confirm
                # against the model's forward() signature.
                filled_flow = self.model(diffused_flows, masks, None)
            filled_flow, filled_edge = filled_flow
            if len(diffused_flows.shape) == 5:
                target_diffused_flow = diffused_flows[:, :, sequenceLen // 2]
            else:
                target_diffused_flow = diffused_flows
            # Composite: GT outside the hole, prediction inside.
            combined_flow = target_flow * (1 - target_mask) + filled_flow * target_mask

            psnr_avg, ssim_avg, l1_avg, l2_avg = self.metrics_calc(combined_flow, target_flow)
            prec_avg, recall_avg = self.edgeMeasure(filled_edge, target_edge)
            psnr[videoName] = psnr_avg
            ssim[videoName] = ssim_avg
            l1[videoName] = l1_avg
            l2[videoName] = l2_avg
            prec[videoName] = prec_avg.item()
            recall[videoName] = recall_avg.item()

            if self.rank <= 0:
                if self.opt['use_tb_logger']:
                    self.tb_logger.add_scalar('test/{}/l1'.format(videoName), l1_avg,
                                              self.currentStep)
                    self.tb_logger.add_scalar('test/{}/l2'.format(videoName), l2_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/psnr'.format(videoName), psnr_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/ssim'.format(videoName), ssim_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/prec'.format(videoName), prec_avg, self.currentStep)
                    self.tb_logger.add_scalar('test/{}/recall'.format(videoName), recall_avg, self.currentStep)
                self.vis_flows(combined_flow, target_flow, target_diffused_flow, videoName,
                               epoch)
        # Per-video progress log with running means over all videos so far.
        mean_psnr = np.mean([psnr[k] for k in psnr.keys()])
        mean_ssim = np.mean([ssim[k] for k in ssim.keys()])
        mean_l1 = np.mean([l1[k] for k in l1.keys()])
        mean_l2 = np.mean([l2[k] for k in l2.keys()])
        mean_prec = np.mean([prec[k] for k in prec.keys()])
        mean_recall = np.mean([recall[k] for k in recall.keys()])
        self.logger.info(
            '[epoch:{:3d}, vid:{}/{}], mean_l1: {:.4e}, mean_l2: {:.4e}, mean_psnr: {:}, mean_ssim: {:}, prec: {:}, recall: {:}'.format(
                epoch, i, len(test_list), mean_l1, mean_l2, mean_psnr, mean_ssim, mean_prec, mean_recall))

    # Final summary and checkpoint (master process only).
    if self.rank <= 0:
        mean_psnr = np.mean([psnr[k] for k in psnr.keys()])
        mean_ssim = np.mean([ssim[k] for k in ssim.keys()])
        mean_l1 = np.mean([l1[k] for k in l1.keys()])
        mean_l2 = np.mean([l2[k] for k in l2.keys()])
        mean_prec = np.mean([prec[k] for k in prec.keys()])
        mean_recall = np.mean([recall[k] for k in recall.keys()])
        self.logger.info(
            '[epoch:{:3d}], mean_l1: {:.4e} mean_l2: {:.4e} mean_psnr: {:} mean_ssim: {:}, prec: {:}, recall: {:}'.format(
                epoch, mean_l1, mean_l2, mean_psnr, mean_ssim, mean_prec, mean_recall))
        # NOTE(review): the +100 offset is passed straight to save_checkpoint,
        # whose visible implementation ignores this argument.
        valid_l1 = mean_l1 + 100
        self.save_checkpoint(epoch, 'l1', valid_l1)

    self.model.train()  # restore training mode after evaluation
|
|
|
def load_edge(self, flow):
    """Extract a Canny edge map from a flow field.

    The flow is first visualised as RGB, converted to grayscale, and the edge
    map is returned as a float (0./1.) HxW array.
    """
    flow_rgb = cvbase.flow2rgb(flow)
    flow_gray = cv2.cvtColor(flow_rgb, cv2.COLOR_RGB2GRAY)
    # BUG FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin float is the documented equivalent.
    return canny(flow_gray, sigma=self.opt['datasets']['dataInfo']['edge']['sigma'], mask=None,
                 low_threshold=self.opt['datasets']['dataInfo']['edge']['low_threshold'],
                 high_threshold=self.opt['datasets']['dataInfo']['edge']['high_threshold']).astype(
        float)
|
|
|
def read_flows(self, flow_dir, width, height, pivot, sequenceLen, sample_interval):
    """Read a temporally-sampled window of .flo files centred on `pivot`.

    Indices outside the available range are clamped to the first/last flow.
    Flow vectors are rescaled to match the resized resolution.

    Returns:
        list of (height, width, 2) flow arrays, or [] when the directory
        contains no .flo files (callers treat [] as "not enough flows").
    """
    flow_paths = glob.glob(os.path.join(flow_dir, '*.flo'))
    flows = []
    # Guard against empty directories: callers check `flows == []`, but the
    # clamped index below would otherwise address a non-existent file.
    if not flow_paths:
        return flows
    half_seq = sequenceLen // 2
    for i in range(-half_seq, half_seq + 1):
        # Clamp the sampled index into [0, len-1].
        index = min(max(pivot + sample_interval * i, 0), len(flow_paths) - 1)
        flow_path = os.path.join(flow_dir, '{:05d}.flo'.format(index))
        flow = cvbase.read_flow(flow_path)
        pre_height, pre_width = flow.shape[:2]
        # BUG FIX: cv2.resize's third positional argument is `dst`, not the
        # interpolation flag, so INTER_LINEAR was never applied as intended;
        # pass it by keyword.
        flow = cv2.resize(flow, (width, height), interpolation=cv2.INTER_LINEAR)
        # Rescale the flow vectors to the new resolution.
        flow[:, :, 0] = flow[:, :, 0] / pre_width * width
        flow[:, :, 1] = flow[:, :, 1] / pre_height * height
        flows.append(flow)
    return flows
|
|
|
def metrics_calc(self, result, frames):
    """Compute PSNR/SSIM/L1/L2 between two (B, C, H, W) flow tensors.

    Returns:
        tuple: ``(psnr, ssim, l1, l2)`` batch means.
    """
    # Convert both tensors to (B, H, W, C) numpy arrays on the CPU.
    predicted = np.array(result.permute(0, 2, 3, 1).cpu())
    reference = np.array(frames.permute(0, 2, 3, 1).cpu())
    logs = self.calculate_metrics(predicted, reference)
    return logs['psnr'], logs['ssim'], logs['l1'], logs['l2']
|
|
|
def read_frames(self, frame_dir, width, height, pivot, sequenceLen):
    """Read `sequenceLen` consecutive jpg frames starting at `pivot`, resized.

    Returns:
        list of resized frames, or [] when the directory holds 30 or fewer
        frames (callers treat [] as "not enough data").
    """
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')))
    frames = []
    if len(frame_paths) <= 30:
        return frames
    for i in range(pivot, pivot + sequenceLen):
        frame_path = os.path.join(frame_dir, '{:05d}.jpg'.format(i))
        frame = imageio.imread(frame_path)
        # BUG FIX: cv2.resize's third positional argument is `dst`, not the
        # interpolation flag; pass the intended INTER_LINEAR by keyword.
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
        frames.append(frame)
    return frames
|
|
|
def load_edges(self, frames, width, height):
    """Compute Canny edge maps for a list of RGB frames.

    Each edge map is converted to a tensor via ``to_tensor`` with nearest
    interpolation (edges are binary, so bilinear resizing would blur them).

    Returns:
        list of 1x1xHxW float tensors, one per input frame.
    """
    edges = []
    for frame in frames:
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # BUG FIX: np.float was removed in NumPy 1.24; use the builtin float.
        edge = canny(frame_gray, sigma=self.valInfo['sigma'], mask=None,
                     low_threshold=self.valInfo['low_threshold'],
                     high_threshold=self.valInfo['high_threshold']).astype(float)
        edge_t = self.to_tensor(edge, width, height, mode='nearest')
        edges.append(edge_t)
    return edges
|
|
|
def to_tensor(self, frame, width, height, mode='bilinear'):
    """Convert an HxW or HxWxC numpy array to a 1xCxHxW float tensor.

    When both `width` and `height` are non-zero the tensor is resized to
    (height, width) with the given interpolation mode.
    """
    if len(frame.shape) == 2:
        # Promote grayscale input to a single-channel image.
        frame = frame[:, :, np.newaxis]
    tensor = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2).float()
    if width != 0 and height != 0:
        tensor = F.interpolate(tensor, size=(height, width), mode=mode)
    return tensor
|
|
|
def to_numpy(self, tensor):
    """Convert the first sample of a 1xCxHxW tensor to an HxWxC numpy copy."""
    first_sample = tensor.cpu()[0]
    return np.array(first_sample.permute(1, 2, 0))
|
|
|
def read_masks(self, mask_dir, width, height, pivot, sequenceLen, sample_interval):
    """Read a temporally-sampled window of binary masks centred on `pivot`.

    Indices outside the available range are clamped to the first/last mask.
    Masks are resized with nearest-neighbour interpolation and re-binarised.

    Returns:
        list of (height, width, 1) binary arrays, or [] when the directory
        contains no png masks (callers treat [] as "not enough data").
    """
    mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
    # Guard against empty directories: callers check `masks == []`, but
    # indexing mask_paths[-1] on an empty list raised IndexError instead.
    if not mask_paths:
        return []
    masks = []
    half_seq = sequenceLen // 2
    last = len(mask_paths) - 1
    for i in range(-half_seq, half_seq + 1):
        # Clamp the sampled index into [0, len-1].
        index = min(max(pivot + i * sample_interval, 0), last)
        mask = cv2.imread(mask_paths[index], 0)  # read as grayscale
        mask = mask / 255.
        # BUG FIX: cv2.resize's third positional argument is `dst`, not the
        # interpolation flag, so INTER_NEAREST was silently not applied; pass
        # it by keyword so binary masks are not linearly interpolated.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1  # re-binarise after resizing
        if len(mask.shape) == 2:
            mask = mask[:, :, np.newaxis]
        assert len(mask.shape) == 3, 'Invalid mask shape: {}'.format(mask.shape)
        masks.append(mask)
    return masks
|
|
|
def diffusion_filling(self, flows, masks):
    """Fill masked flow regions by diffusion (region fill), channel by channel.

    Args:
        flows: list of (H, W, 2) flow arrays.
        masks: list of (H, W, 1) hole masks aligned with `flows`.

    Returns:
        list of filled (H, W, 2) flow arrays.
    """
    filled_flows = []
    for flow, mask in zip(flows, masks):
        hole = mask[:, :, 0]
        filled = np.zeros(flow.shape)
        for channel in (0, 1):
            filled[:, :, channel] = rf.regionfill(flow[:, :, channel], hole)
        filled_flows.append(filled)
    return filled_flows
|
|
|
def vis_flows(self, result, target_flow, diffused_flow, video_name, epoch):
    """Save a side-by-side flow comparison image for one video.

    Layout (separated by black columns):

        | Ours | GT | diffused_flows |

    Args:
        result: generated flow tensor, shape [1, 2, h, w].
        target_flow: ground-truth flow tensor, shape [1, 2, h, w].
        diffused_flow: diffusion-filled flow tensor, shape [1, 2, h, w].
        video_name: name of the video (used in the output path).
        epoch: current epoch (used in the output path).

    Saves ``result_compare.png`` under VAL_IMAGES/<epoch>/<video_name>/.
    """
    out_root = self.opt['path']['VAL_IMAGES']
    out_dir = os.path.join(out_root, str(epoch), video_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    gap = 20  # width of each black separator column, in pixels
    # Convert every tensor to its RGB flow visualisation.
    ours, gt, diffused = (cvbase.flow2rgb(self.to_numpy(t))
                          for t in (result, target_flow, diffused_flow))
    height, width = ours.shape[:2]
    canvas = np.zeros((height, width * 3 + gap * 2, 3))
    canvas[:, 0:width, :] = ours
    canvas[:, width + gap: 2 * width + gap, :] = gt
    canvas[:, 2 * (width + gap):, :] = diffused
    imageio.imwrite(os.path.join(out_dir, 'result_compare.png'), canvas)
|
|