import sys
import os
import warnings

# Make the repo root and the FGT/LAFC packages importable when this script
# is run from inside the tool directory.
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "FGT")))
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..", "LAFC")))
warnings.filterwarnings("ignore")

import argparse
import copy
import glob
from importlib import import_module

import cv2
import cvbase
import imageio
import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F2
import torchvision.transforms.functional as F
import yaml
from PIL import Image

from RAFT import RAFT
from RAFT import utils
from tool import get_flowNN_gradient
from tool.utils.Poisson_blend_img import Poisson_blend_img
from tool.utils.region_fill import regionfill


def to_tensor(img):
    """Converts a numpy image (H, W, C) to a float tensor (C, H, W)."""
    img = Image.fromarray(img)
    img_t = F.to_tensor(img).float()
    return img_t


def diffusion(flows, masks):
    """Fills the masked region of each flow field by diffusion (regionfill)."""
    flows_filled = []
    for i in range(flows.shape[0]):
        flow, mask = flows[i], masks[i]
        flow_filled = np.zeros(flow.shape)
        flow_filled[:, :, 0] = regionfill(flow[:, :, 0], mask[:, :, 0])
        flow_filled[:, :, 1] = regionfill(flow[:, :, 1], mask[:, :, 0])
        flows_filled.append(flow_filled)
    return flows_filled


def np2tensor(array, near="c"):
    """Converts a (t, h, w, c) array (or a list of (h, w, c) arrays) to a tensor.

    near="c" yields shape (1, c, t, h, w); near="t" yields (1, t, c, h, w).
    """
    if isinstance(array, list):
        array = np.stack(array, axis=0)
    if near == "c":
        array = (
            torch.from_numpy(np.transpose(array, (3, 0, 1, 2))).unsqueeze(0).float()
        )
    elif near == "t":
        array = torch.from_numpy(np.transpose(array, (0, 3, 1, 2))).unsqueeze(0).float()
    else:
        raise ValueError(f"Unknown near type: {near}")
    return array


def tensor2np(array):
    """Stacks a list of (1, c, h, w) tensors into an (h, w, c, t) numpy array."""
    array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
    return array


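# Shape sanity check for the two converters above (the frame count, height,
# and width are arbitrary example values, not pipeline constants):
#
#   frames = [np.zeros((64, 48, 2), dtype=np.float32) for _ in range(5)]
#   np2tensor(frames, near="c").shape   # torch.Size([1, 2, 5, 64, 48])
#   np2tensor(frames, near="t").shape   # torch.Size([1, 5, 2, 64, 48])
#
#   flows = [torch.zeros(1, 2, 64, 48) for _ in range(5)]
#   tensor2np(flows).shape              # (64, 48, 2, 5)

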
def gradient_mask(mask):
    """Expands a boolean mask by one pixel along +x and +y.

    A pixel is marked if it, its lower neighbour, or its right neighbour is
    masked, so forward-difference gradients touching the hole are covered.
    """
    gradient_mask = np.logical_or.reduce(
        (
            mask,
            np.concatenate(
                (mask[1:, :], np.zeros((1, mask.shape[1]), dtype=bool)), axis=0
            ),
            np.concatenate(
                (mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=bool)), axis=1
            ),
        )
    )

    return gradient_mask


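# Worked example (hypothetical 3x3 input, not pipeline data): a single True
# at (1, 1) also marks (0, 1) and (1, 0), i.e. the pixels whose forward
# differences reach into the hole:
#
#   m = np.zeros((3, 3), dtype=bool); m[1, 1] = True
#   gradient_mask(m)
#   # [[False,  True, False],
#   #  [ True,  True, False],
#   #  [False, False, False]]

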
def indicesGen(pivot, interval, frames, t):
    """Samples `frames` indices centred on `pivot`, spaced by `interval`.

    Indices falling outside [0, t - 1] are reflected back into range.
    """
    singleSide = frames // 2
    results = []
    for i in range(-singleSide, singleSide + 1):
        index = pivot + interval * i
        if index < 0:
            index = abs(index)
        if index > t - 1:
            index = 2 * (t - 1) - index
        results.append(index)
    return results


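# Example of the reflection behaviour (arbitrary arguments, for illustration
# only): near the start of a 10-frame clip, negative offsets fold back into
# valid range:
#
#   indicesGen(pivot=1, interval=2, frames=5, t=10)  # -> [3, 1, 1, 3, 5]

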
def get_ref_index(f, neighbor_ids, length, ref_length, num_ref):
    """Selects non-local reference frames for the transformer stage.

    With num_ref == -1 every ref_length-th frame outside the local
    neighbourhood is used; otherwise at most num_ref frames around f.
    """
    ref_index = []
    if num_ref == -1:
        for i in range(0, length, ref_length):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, f - ref_length * (num_ref // 2))
        end_idx = min(length, f + ref_length * (num_ref // 2))
        for i in range(start_idx, end_idx + 1, ref_length):
            if i not in neighbor_ids:
                if len(ref_index) >= num_ref:
                    break
                ref_index.append(i)
    return ref_index


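# Illustration (made-up values): with length=50, ref_length=10, and a local
# window [8..12], the global reference set skips the one stride that falls
# inside the window:
#
#   get_ref_index(10, list(range(8, 13)), 50, 10, -1)  # -> [0, 20, 30, 40]

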
def save_flows(output, videoFlowF, videoFlowB):
    """Writes the completed flows as .flo files plus RGB visualizations."""
    create_dir(os.path.join(output, "completed_flow", "forward_flo"))
    create_dir(os.path.join(output, "completed_flow", "backward_flo"))
    create_dir(os.path.join(output, "completed_flow", "forward_png"))
    create_dir(os.path.join(output, "completed_flow", "backward_png"))
    N = videoFlowF.shape[-1]
    for i in range(N):
        forward_flow = videoFlowF[..., i]
        backward_flow = videoFlowB[..., i]
        forward_flow_vis = cvbase.flow2rgb(forward_flow)
        backward_flow_vis = cvbase.flow2rgb(backward_flow)
        cvbase.write_flow(
            forward_flow,
            os.path.join(
                output, "completed_flow", "forward_flo", "{:05d}.flo".format(i)
            ),
        )
        cvbase.write_flow(
            backward_flow,
            os.path.join(
                output, "completed_flow", "backward_flo", "{:05d}.flo".format(i)
            ),
        )
        imageio.imwrite(
            os.path.join(
                output, "completed_flow", "forward_png", "{:05d}.png".format(i)
            ),
            forward_flow_vis,
        )
        imageio.imwrite(
            os.path.join(
                output, "completed_flow", "backward_png", "{:05d}.png".format(i)
            ),
            backward_flow_vis,
        )


def save_fgcp(output, frames, masks):
    """Saves stage-I propagated frames and the masks that remain unfilled."""
    create_dir(os.path.join(output, "prop_frames"))
    create_dir(os.path.join(output, "masks_left"))
    create_dir(os.path.join(output, "prop_frames_npy"))
    create_dir(os.path.join(output, "masks_left_npy"))

    assert len(frames) == masks.shape[2]
    for i in range(len(frames)):
        cv2.imwrite(
            os.path.join(output, "prop_frames", "%05d.png" % i), frames[i] * 255.0
        )
        cv2.imwrite(
            os.path.join(output, "masks_left", "%05d.png" % i), masks[:, :, i] * 255.0
        )
        np.save(
            os.path.join(output, "prop_frames_npy", "%05d.npy" % i), frames[i] * 255.0
        )
        np.save(
            os.path.join(output, "masks_left_npy", "%05d.npy" % i),
            masks[:, :, i] * 255.0,
        )


def create_dir(dir):
    """Creates a directory if it does not exist."""
    if not os.path.exists(dir):
        os.makedirs(dir)


def initialize_RAFT(args, device):
    """Initializes the RAFT model from args.raft_model."""
    model = torch.nn.DataParallel(RAFT(args))
    if not torch.cuda.is_available():
        # Map CUDA-trained weights onto the CPU.
        model.load_state_dict(
            torch.load(args.raft_model, map_location=lambda storage, loc: storage)
        )
    else:
        model.load_state_dict(torch.load(args.raft_model))

    model = model.module
    model.to(device)
    model.eval()

    return model


def initialize_LAFC(args, device):
    """Loads the LAFC flow-completion model and its config from args.lafc_ckpts."""
    print(args.lafc_ckpts)
    assert len(os.listdir(args.lafc_ckpts)) == 2
    checkpoint, config_file = (
        glob.glob(os.path.join(args.lafc_ckpts, "*.tar"))[0],
        glob.glob(os.path.join(args.lafc_ckpts, "*.yaml"))[0],
    )
    with open(config_file, "r") as f:
        configs = yaml.full_load(f)
    model = configs["model"]
    pkg = import_module("LAFC.models.{}".format(model))
    model = pkg.Model(configs)
    if not torch.cuda.is_available():
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    else:
        state = torch.load(
            checkpoint, map_location=lambda storage, loc: storage.cuda(device)
        )
    model.load_state_dict(state["model_state_dict"])
    model = model.to(device)
    return model, configs


def initialize_FGT(args, device):
    """Loads the FGT inpainting transformer and its config from args.fgt_ckpts."""
    assert len(os.listdir(args.fgt_ckpts)) == 2
    checkpoint, config_file = (
        glob.glob(os.path.join(args.fgt_ckpts, "*.tar"))[0],
        glob.glob(os.path.join(args.fgt_ckpts, "*.yaml"))[0],
    )
    with open(config_file, "r") as f:
        configs = yaml.full_load(f)
    model = configs["model"]
    net = import_module("FGT.models.{}".format(model))
    model = net.Model(configs).to(device)
    if not torch.cuda.is_available():
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    else:
        state = torch.load(
            checkpoint, map_location=lambda storage, loc: storage.cuda(device)
        )
    model.load_state_dict(state["model_state_dict"])
    return model, configs


def calculate_flow(args, model, video, mode):
    """Estimates forward or backward optical flow between consecutive frames."""
    if mode not in ["forward", "backward"]:
        raise NotImplementedError

    imgH, imgW = args.imgH, args.imgW
    Flow = np.empty((imgH, imgW, 2, 0), dtype=np.float32)

    if args.vis_flows:
        create_dir(os.path.join(args.outroot, "flow", mode + "_flo"))
        create_dir(os.path.join(args.outroot, "flow", mode + "_png"))

    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print(
                "Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1),
                "\r",
                end="",
            )
            if mode == "forward":
                # Flow i -> i + 1
                image1 = video[i, None]
                image2 = video[i + 1, None]
            else:
                # Flow i + 1 -> i
                image1 = video[i + 1, None]
                image2 = video[i, None]

            _, flow = model(image1, image2, iters=20, test_mode=True)
            flow = flow[0].permute(1, 2, 0).cpu().numpy()

            # Resize to the working resolution and rescale the flow vectors
            # accordingly.
            h, w = flow.shape[:2]
            if h != imgH or w != imgW:
                flow = cv2.resize(flow, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
                flow[:, :, 0] *= float(imgW) / float(w)
                flow[:, :, 1] *= float(imgH) / float(h)

            Flow = np.concatenate((Flow, flow[..., None]), axis=-1)

            if args.vis_flows:
                flow_img = utils.flow_viz.flow_to_image(flow)
                flow_img = Image.fromarray(flow_img)

                flow_img.save(
                    os.path.join(args.outroot, "flow", mode + "_png", "%05d.png" % i)
                )
                utils.frame_utils.writeFlow(
                    os.path.join(args.outroot, "flow", mode + "_flo", "%05d.flo" % i),
                    flow,
                )

    return Flow


def extrapolation(args, video_ori, corrFlowF_ori, corrFlowB_ori):
    """Prepares the data for video extrapolation."""
    imgH, imgW, _, nFrame = video_ori.shape

    # Target resolution, rounded down to a multiple of 4.
    imgH_extr = int(args.H_scale * imgH)
    imgW_extr = int(args.W_scale * imgW)
    imgH_extr = imgH_extr - imgH_extr % 4
    imgW_extr = imgW_extr - imgW_extr % 4
    H_start = int((imgH_extr - imgH) / 2)
    W_start = int((imgW_extr - imgW) / 2)

    # Everything outside the centred original frame is treated as a hole.
    flow_mask = np.ones((imgH_extr, imgW_extr), dtype=bool)
    flow_mask[H_start : H_start + imgH, W_start : W_start + imgW] = 0

    mask_dilated = gradient_mask(flow_mask)

    # Paste the original frames into the enlarged canvas and pre-fill the
    # border with cv2 inpainting.
    video = np.zeros((imgH_extr, imgW_extr, 3, nFrame), dtype=np.float32)
    video[H_start : H_start + imgH, W_start : W_start + imgW, :, :] = video_ori

    for i in range(nFrame):
        print("Preparing frame {0}".format(i), "\r", end="")
        video[:, :, :, i] = (
            cv2.inpaint(
                (video[:, :, :, i] * 255).astype(np.uint8),
                flow_mask.astype(np.uint8),
                3,
                cv2.INPAINT_TELEA,
            ).astype(np.float32)
            / 255.0
        )

    # Pad the flows to the enlarged canvas as well.
    corrFlowF = np.zeros((imgH_extr, imgW_extr, 2, nFrame - 1), dtype=np.float32)
    corrFlowB = np.zeros((imgH_extr, imgW_extr, 2, nFrame - 1), dtype=np.float32)
    corrFlowF[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowF_ori
    corrFlowB[H_start : H_start + imgH, W_start : W_start + imgW, :] = corrFlowB_ori

    return (
        video,
        corrFlowF,
        corrFlowB,
        flow_mask,
        mask_dilated,
        (W_start, H_start),
        (W_start + imgW, H_start + imgH),
    )


def complete_flow(config, flow_model, flows, flow_masks, mode, device):
    """Completes the masked flows with the LAFC model, one pivot frame at a time."""
    if mode not in ["forward", "backward"]:
        raise NotImplementedError(f"Unknown flow mode: {mode}")
    flow_masks = np.moveaxis(flow_masks, -1, 0)
    flows = np.moveaxis(flows, -1, 0)
    if len(flow_masks.shape) == 3:
        flow_masks = flow_masks[:, :, :, np.newaxis]
    # There are nFrame masks but only nFrame - 1 flows per direction.
    if mode == "forward":
        flow_masks = flow_masks[0:-1]
    else:
        flow_masks = flow_masks[1:]

    num_flows, flow_interval = config["num_flows"], config["flow_interval"]

    diffused_flows = diffusion(flows, flow_masks)

    flows = np2tensor(flows)
    flow_masks = np2tensor(flow_masks)
    diffused_flows = np2tensor(diffused_flows)

    flows = flows.to(device)
    flow_masks = flow_masks.to(device)
    diffused_flows = diffused_flows.to(device)

    t = diffused_flows.shape[2]
    filled_flows = [None] * t
    pivot = num_flows // 2
    for i in range(t):
        indices = indicesGen(i, flow_interval, num_flows, t)
        print("Indices: ", indices, "\r", end="")
        cand_flows = flows[:, :, indices]
        cand_masks = flow_masks[:, :, indices]
        inputs = diffused_flows[:, :, indices]
        pivot_mask = cand_masks[:, :, pivot]
        pivot_flow = cand_flows[:, :, pivot]
        with torch.no_grad():
            output_flow = flow_model(inputs, cand_masks)
            if isinstance(output_flow, (tuple, list)):
                output_flow = output_flow[0]
        # Keep the original flow outside the mask, the prediction inside it.
        comp = output_flow * pivot_mask + pivot_flow * (1 - pivot_mask)
        if filled_flows[i] is None:
            filled_flows[i] = comp
    assert None not in filled_flows
    return filled_flows


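# Sketch of one completion step (config values are hypothetical examples):
# with num_flows=5 and flow_interval=3, frame i is completed from the
# candidate set indicesGen(i, 3, 5, t); the model prediction is used only
# inside the pivot frame's mask, and the original flow is kept elsewhere.

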
def read_flow(flow_dir, video):
    """Reads precomputed .flo files and resizes them to the video resolution."""
    nFrame, _, imgH, imgW = video.shape
    Flow = np.empty((imgH, imgW, 2, 0), dtype=np.float32)
    flows = sorted(glob.glob(os.path.join(flow_dir, "*.flo")))
    for flow in flows:
        flow_data = cvbase.read_flow(flow)
        h, w = flow_data.shape[:2]
        flow_data = cv2.resize(
            flow_data, (imgW, imgH), interpolation=cv2.INTER_LINEAR
        )
        flow_data[:, :, 0] *= float(imgW) / float(w)
        flow_data[:, :, 1] *= float(imgH) / float(h)
        Flow = np.concatenate((Flow, flow_data[..., None]), axis=-1)
    return Flow


def norm_flows(flows):
    """Normalizes each flow channel of each frame by its maximum component."""
    assert len(flows.shape) == 5, "Flow shape: {}".format(flows.shape)
    flattened_flows = flows.flatten(3)
    flow_max = torch.max(flattened_flows, dim=-1, keepdim=True)[0]
    flows = flows / flow_max.unsqueeze(-1)
    return flows


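# Shape walk-through of norm_flows (illustrative, for flows of shape
# (1, t, 2, H, W)):
#
#   flows.flatten(3)                          # -> (1, t, 2, H * W)
#   torch.max(..., dim=-1, keepdim=True)[0]   # -> (1, t, 2, 1)
#   .unsqueeze(-1)                            # -> (1, t, 2, 1, 1)
#
# so each channel of each frame is divided by its own maximum component.

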
def save_results(outdir, comp_frames):
    """Writes the composited frames to <outdir>/frames, reversing channels for cv2."""
    out_dir = os.path.join(outdir, "frames")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for i in range(len(comp_frames)):
        out_path = os.path.join(out_dir, "{:05d}.png".format(i))
        cv2.imwrite(out_path, comp_frames[i][:, :, ::-1])


def video_inpainting(args, imgArr, imgMaskArr):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(args)
    if args.opt is not None:
        with open(args.opt, "r") as f:
            opts = yaml.full_load(f)

        # Values from the config file override the command-line defaults.
        for k in opts.keys():
            if k in args:
                setattr(args, k, opts[k])

    print(args)

    # Flow estimation model.
    RAFT_model = initialize_RAFT(args, device)
    # Flow completion model.
    LAFC_model, LAFC_config = initialize_LAFC(args, device)
    # Frame inpainting model.
    FGT_model, FGT_config = initialize_FGT(args, device)

    imgH, imgW = args.imgH, args.imgW

    nFrame = len(imgArr)

    # RAFT is more reliable at a higher resolution for small frames.
    if imgH < 350:
        flowH, flowW = imgH * 2, imgW * 2
    else:
        flowH, flowW = imgH, imgW

    video, video_flow = [], []
    if args.mode == "watermark_removal":
        filename_list = glob.glob(os.path.join(args.path, "*.png")) + glob.glob(
            os.path.join(args.path, "*.jpg")
        )
        maskname_list = glob.glob(os.path.join(args.path_mask, "*.png")) + glob.glob(
            os.path.join(args.path_mask, "*.jpg")
        )
        assert len(filename_list) == len(maskname_list)
        for filename, maskname in zip(sorted(filename_list), sorted(maskname_list)):
            frame = (
                torch.from_numpy(np.array(Image.open(filename)).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            mask = (
                torch.from_numpy(np.array(Image.open(maskname)).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            mask[mask > 0] = 1
            # Blank out the watermark region before resizing.
            frame = frame * (1 - mask)
            frame = F2.interpolate(
                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
            )
            frame_flow = F2.interpolate(
                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
            )
            video.append(frame)
            video_flow.append(frame_flow)
    else:
        # Frames are taken from the in-memory imgArr rather than read from disk.
        for im in imgArr:
            frame = (
                torch.from_numpy(np.array(im).astype(np.uint8))
                .permute(2, 0, 1)
                .float()
                .unsqueeze(0)
            )
            frame = F2.interpolate(
                frame, size=(imgH, imgW), mode="bilinear", align_corners=False
            )
            frame_flow = F2.interpolate(
                frame, size=(flowH, flowW), mode="bilinear", align_corners=False
            )
            video.append(frame)
            video_flow.append(frame_flow)

    video = torch.cat(video, dim=0)
    video_flow = torch.cat(video_flow, dim=0)
    # Keep the frame count in sync in case frames were read from disk.
    nFrame = video.shape[0]
    gts = video.clone()
    video = video.to(device)
    video_flow = video_flow.to(device)

    # Raw flow estimation with RAFT.
    forward_flows = calculate_flow(args, RAFT_model, video_flow, "forward")
    backward_flows = calculate_flow(args, RAFT_model, video_flow, "backward")

    # (imgH, imgW, 3, nFrame), channels reversed for cv2, scaled to [0, 1].
    video = video.permute(2, 3, 1, 0).cpu().numpy()[:, :, ::-1, :] / 255.0

    if args.mode == "video_extrapolation":
        (
            video,
            forward_flows,
            backward_flows,
            flow_mask,
            mask_dilated,
            start_point,
            end_point,
        ) = extrapolation(args, video, forward_flows, backward_flows)
        imgH, imgW = video.shape[:2]

        # The extrapolation mask is shared by every frame.
        mask = np.tile(flow_mask[..., None], (1, 1, nFrame))
        flow_mask = np.tile(flow_mask[..., None], (1, 1, nFrame))
        mask_dilated = np.tile(mask_dilated[..., None], (1, 1, nFrame))

    else:
        # Masks are taken from the in-memory imgMaskArr rather than read
        # from args.path_mask.
        mask = []
        mask_dilated = []
        flow_mask = []
        for f_mask in imgMaskArr:
            mask_img = np.array(f_mask)
            mask_img = cv2.resize(
                mask_img, dsize=(imgW, imgH), interpolation=cv2.INTER_NEAREST
            )

            if args.flow_mask_dilates > 0:
                flow_mask_img = scipy.ndimage.binary_dilation(
                    mask_img, iterations=args.flow_mask_dilates
                )
            else:
                flow_mask_img = mask_img
            flow_mask.append(flow_mask_img)

            if args.frame_dilates > 0:
                mask_img = scipy.ndimage.binary_dilation(
                    mask_img, iterations=args.frame_dilates
                )
            mask.append(mask_img)
            mask_dilated.append(gradient_mask(mask_img))

        mask = np.stack(mask, -1).astype(bool)
        mask_dilated = np.stack(mask_dilated, -1).astype(bool)
        flow_mask = np.stack(flow_mask, -1).astype(bool)

    # Complete the masked flows with LAFC.
    videoFlowF = complete_flow(
        LAFC_config, LAFC_model, forward_flows, flow_mask, "forward", device
    )
    videoFlowB = complete_flow(
        LAFC_config, LAFC_model, backward_flows, flow_mask, "backward", device
    )
    videoFlowF = tensor2np(videoFlowF)
    videoFlowB = tensor2np(videoFlowB)
    print("\nFinish flow completion.")

    if args.vis_completed_flows:
        save_flows(args.outroot, videoFlowF, videoFlowB)

    # Stage-I filling: compute forward-difference gradients of each
    # (hole-inpainted) frame, zeroed inside the dilated mask, then propagate
    # them along the completed flows.
    gradient_x = np.empty((imgH, imgW, 3, 0), dtype=np.float32)
    gradient_y = np.empty((imgH, imgW, 3, 0), dtype=np.float32)

    for indFrame in range(nFrame):
        img = video[:, :, :, indFrame]
        img[mask[:, :, indFrame], :] = 0
        img = (
            cv2.inpaint(
                (img * 255).astype(np.uint8),
                mask[:, :, indFrame].astype(np.uint8),
                3,
                cv2.INPAINT_TELEA,
            ).astype(np.float32)
            / 255.0
        )

        gradient_x_ = np.concatenate(
            (np.diff(img, axis=1), np.zeros((imgH, 1, 3), dtype=np.float32)), axis=1
        )
        gradient_y_ = np.concatenate(
            (np.diff(img, axis=0), np.zeros((1, imgW, 3), dtype=np.float32)), axis=0
        )
        gradient_x = np.concatenate(
            (gradient_x, gradient_x_.reshape(imgH, imgW, 3, 1)), axis=-1
        )
        gradient_y = np.concatenate(
            (gradient_y, gradient_y_.reshape(imgH, imgW, 3, 1)), axis=-1
        )

        gradient_x[mask_dilated[:, :, indFrame], :, indFrame] = 0
        gradient_y[mask_dilated[:, :, indFrame], :, indFrame] = 0

    gradient_x_filled = gradient_x
    gradient_y_filled = gradient_y
    mask_gradient = mask_dilated
    video_comp = video

    # Propagate gradients along the completed forward/backward flows.
    gradient_x_filled, gradient_y_filled, mask_gradient = get_flowNN_gradient(
        args,
        gradient_x_filled,
        gradient_y_filled,
        mask,
        mask_gradient,
        videoFlowF,
        videoFlowB,
        None,
        None,
    )

    # Close interior holes in the remaining gradient mask.
    for indFrame in range(nFrame):
        mask_gradient[:, :, indFrame] = scipy.ndimage.binary_fill_holes(
            mask_gradient[:, :, indFrame]
        ).astype(bool)

    # Reconstruct each frame from the propagated gradients via Poisson
    # blending; fall back to the unblended frame if the solver fails.
    frameBlends = []
    for indFrame in range(nFrame):
        print("Poisson blending frame {0:3d}".format(indFrame))

        if mask[:, :, indFrame].sum() > 0:
            try:
                frameBlend, UnfilledMask = Poisson_blend_img(
                    video_comp[:, :, :, indFrame],
                    gradient_x_filled[:, 0 : imgW - 1, :, indFrame],
                    gradient_y_filled[0 : imgH - 1, :, :, indFrame],
                    mask[:, :, indFrame],
                    mask_gradient[:, :, indFrame],
                )
            except Exception:
                frameBlend, UnfilledMask = (
                    video_comp[:, :, :, indFrame],
                    mask[:, :, indFrame],
                )

            frameBlend = np.clip(frameBlend, 0, 1.0)
            tmp = (
                cv2.inpaint(
                    (frameBlend * 255).astype(np.uint8),
                    UnfilledMask.astype(np.uint8),
                    3,
                    cv2.INPAINT_TELEA,
                ).astype(np.float32)
                / 255.0
            )
            frameBlend[UnfilledMask, :] = tmp[UnfilledMask, :]

            video_comp[:, :, :, indFrame] = frameBlend
            mask[:, :, indFrame] = UnfilledMask

            frameBlend_ = copy.deepcopy(frameBlend)
            # Mark the still-unfilled region in green for visualization.
            frameBlend_[mask[:, :, indFrame], :] = [0, 1.0, 0]
        else:
            frameBlend_ = video_comp[:, :, :, indFrame]
        frameBlends.append(frameBlend_)

    if args.vis_prop:
        save_fgcp(args.outroot, frameBlends, mask)

    video_length = len(frameBlends)

    # Reverse channels back for the transformer stage.
    for i in range(len(frameBlends)):
        frameBlends[i] = frameBlends[i][:, :, ::-1]

    frames_first = np2tensor(frameBlends, near="t").to(device)
    mask = np.moveaxis(mask, -1, 0)
    mask = mask[:, :, :, np.newaxis]
    masks = np2tensor(mask, near="t").to(device)
    normed_frames = frames_first * 2 - 1
    comp_frames = [None] * video_length

    # Stage-II filling: transformer inpainting with FGT over local windows
    # plus sparse global reference frames.
    ref_length = args.step
    num_ref = args.num_ref
    neighbor_stride = args.neighbor_stride

    # Repeat the last forward flow so there is one flow per frame.
    videoFlowF = np.moveaxis(videoFlowF, -1, 0)
    videoFlowF = np.concatenate([videoFlowF, videoFlowF[-1:, ...]], axis=0)

    flows = np2tensor(videoFlowF, near="t")
    flows = norm_flows(flows).to(device)

    for f in range(0, video_length, neighbor_stride):
        neighbor_ids = [
            i
            for i in range(
                max(0, f - neighbor_stride), min(video_length, f + neighbor_stride + 1)
            )
        ]
        ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_length, num_ref)
        print(f, len(neighbor_ids), len(ref_ids))
        selected_frames = normed_frames[:, neighbor_ids + ref_ids]
        selected_masks = masks[:, neighbor_ids + ref_ids]
        masked_frames = selected_frames * (1 - selected_masks)
        selected_flows = flows[:, neighbor_ids + ref_ids]
        with torch.no_grad():
            filled_frames = FGT_model(masked_frames, selected_flows, selected_masks)
        filled_frames = (filled_frames + 1) / 2
        filled_frames = filled_frames.cpu().permute(0, 2, 3, 1).numpy() * 255
        for i in range(len(neighbor_ids)):
            idx = neighbor_ids[i]
            valid_frame = frames_first[0, idx].cpu().permute(1, 2, 0).numpy() * 255.0
            valid_mask = masks[0, idx].cpu().permute(1, 2, 0).numpy()
            comp = np.array(filled_frames[i]).astype(np.uint8) * valid_mask + np.array(
                valid_frame
            ).astype(np.uint8) * (1 - valid_mask)
            if comp_frames[idx] is None:
                comp_frames[idx] = comp
            else:
                # Overlapping windows are averaged.
                comp_frames[idx] = (
                    comp_frames[idx].astype(np.float32) * 0.5
                    + comp.astype(np.float32) * 0.5
                )
    if args.vis_frame:
        save_results(args.outroot, comp_frames)
    create_dir(args.outroot)
    for i in range(len(comp_frames)):
        comp_frames[i] = comp_frames[i].astype(np.uint8)
    imageio.mimwrite(
        os.path.join(args.outroot, "result.mp4"), comp_frames, fps=30, quality=8
    )
    print(f"Done, please check your result in {args.outroot} ")


def main(args):
    assert args.mode in (
        "object_removal",
        "video_extrapolation",
        "watermark_removal",
    ), (
        "Accepted modes: 'object_removal', 'video_extrapolation', and 'watermark_removal', but input is %s"
    ) % args.mode

    # Load the frames and masks from disk and hand them to the pipeline as
    # in-memory images (video_inpainting expects PIL images or arrays).
    filename_list = sorted(
        glob.glob(os.path.join(args.path, "*.png"))
        + glob.glob(os.path.join(args.path, "*.jpg"))
    )
    maskname_list = sorted(
        glob.glob(os.path.join(args.path_mask, "*.png"))
        + glob.glob(os.path.join(args.path_mask, "*.jpg"))
    )
    imgArr = [Image.open(filename) for filename in filename_list]
    imgMaskArr = [Image.open(maskname).convert("L") for maskname in maskname_list]
    video_inpainting(args, imgArr, imgMaskArr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--opt",
        default="configs/object_removal.yaml",
        help="Please select your config file for inference",
    )

    parser.add_argument(
        "--mode",
        default="object_removal",
        choices=["object_removal", "watermark_removal", "video_extrapolation"],
        help="modes: object_removal / watermark_removal / video_extrapolation",
    )
    parser.add_argument(
        "--path", default="/myData/davis_resized/walking", help="dataset for evaluation"
    )
    parser.add_argument(
        "--path_mask",
        default="/myData/dilateAnnotations_4/walking",
        help="mask for object removal",
    )
    parser.add_argument(
        "--outroot", default="quick_start/walking3", help="output directory"
    )
    parser.add_argument(
        "--consistencyThres",
        dest="consistencyThres",
        default=5,
        type=float,
        help="flow consistency error threshold",
    )
    parser.add_argument("--alpha", dest="alpha", default=0.1, type=float)
    parser.add_argument("--Nonlocal", dest="Nonlocal", action="store_true")

    # RAFT
    parser.add_argument(
        "--raft_model",
        default="../LAFC/flowCheckPoint/raft-things.pth",
        help="restore checkpoint",
    )
    parser.add_argument("--small", action="store_true", help="use small model")
    parser.add_argument(
        "--mixed_precision", action="store_true", help="use mixed precision"
    )
    parser.add_argument(
        "--alternate_corr",
        action="store_true",
        help="use efficient correlation implementation",
    )

    # LAFC
    parser.add_argument("--lafc_ckpts", type=str, default="../LAFC/checkpoint")

    # FGT
    parser.add_argument("--fgt_ckpts", type=str, default="../FGT/checkpoint")

    # Extrapolation
    parser.add_argument(
        "--H_scale", dest="H_scale", default=2, type=float, help="H extrapolation scale"
    )
    parser.add_argument(
        "--W_scale", dest="W_scale", default=2, type=float, help="W extrapolation scale"
    )

    # Image size and mask dilation
    parser.add_argument("--imgH", type=int, default=256)
    parser.add_argument("--imgW", type=int, default=432)
    parser.add_argument("--flow_mask_dilates", type=int, default=8)
    parser.add_argument("--frame_dilates", type=int, default=0)

    parser.add_argument("--gpu", type=int, default=0)

    # FGT inference parameters
    parser.add_argument("--step", type=int, default=10)
    parser.add_argument("--num_ref", type=int, default=-1)
    parser.add_argument("--neighbor_stride", type=int, default=5)

    # Visualization
    parser.add_argument(
        "--vis_flows", action="store_true", help="Visualize the initialized flows"
    )
    parser.add_argument(
        "--vis_completed_flows",
        action="store_true",
        help="Visualize the completed flows",
    )
    parser.add_argument(
        "--vis_prop",
        action="store_true",
        help="Visualize the frames after stage-I filling (flow guided content propagation)",
    )
    parser.add_argument("--vis_frame", action="store_true", help="Visualize frames")

    args = parser.parse_args()

    main(args)
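# Example invocation (paths are the argparse defaults above; the script name
# is whatever this file is saved as):
#
#   python video_inpainting.py \
#       --opt configs/object_removal.yaml \
#       --mode object_removal \
#       --path /myData/davis_resized/walking \
#       --path_mask /myData/dilateAnnotations_4/walking \
#       --outroot quick_start/walking3 \
#       --vis_frame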