# --- Page header captured with the file (not part of the program) ---
# oguzakif's picture
# checkpoint paths converted to abs ones
# e7aba9f
# raw
# history blame
# 10.1 kB
from PIL import Image
import gradio as gr
from FGT_codes.tool.video_inpainting import video_inpainting
from SiamMask.tools.test import *
from SiamMask.experiments.siammask_sharp.custom import Custom
from types import SimpleNamespace
import torch
import numpy as np
import torchvision
import cv2
import sys
from os.path import exists, join, basename, splitext
import os
import argparse
# Root of the checkout. While it is the empty string, every path derived from
# it below is relative to the current working directory.
project_name = ''

# Extend the import search path: first the project root itself, then the
# FGT_codes tree.  Each tuple is the sub-path below FGT_codes; the entries are
# appended in exactly this order (the last one is the RAFT checkpoint file).
sys.path.append(project_name)
sys.path.extend(
    os.path.abspath(join(project_name, 'FGT_codes', *sub_path))
    for sub_path in (
        (),
        ('tool',),
        ('tool', 'configs'),
        ('LAFC', 'flowCheckPoint'),
        ('LAFC', 'checkpoint'),
        ('FGT', 'checkpoint'),
        ('LAFC', 'flowCheckPoint', 'raft-things.pth'),
    )
)

# The SiamMask directories are deliberately not appended (disabled):
# sys.path.append(join(project_name, 'SiamMask',
# 'experiments', 'siammask_sharp'))
# sys.path.append(join(project_name, 'SiamMask', 'models'))
# sys.path.append(join(project_name, 'SiamMask'))

# SiamMask experiment directory and the pretrained DAVIS weights inside it.
exp_path = join(project_name, 'SiamMask/experiments/siammask_sharp')
pretrained_path1 = join(exp_path, 'SiamMask_DAVIS.pth')

print(sys.path)
# The script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# init SiamMask
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cfg = load_config(SimpleNamespace(config=join(exp_path, 'config_davis.json')))
# Build the network from the configured anchors, load the pretrained weights,
# switch to evaluation mode and move the model to the selected device.
siammask = load_pretrain(
    Custom(anchors=cfg['anchors']),
    pretrained_path1,
).eval().to(device)
# Module-level shared state (despite the original "constants" label, the two
# lists below are mutated at runtime).
# Object rectangle placeholders; not referenced by the functions visible in
# this part of the file.
object_x = object_y = object_width = object_height = 0
# Frame and mask buffers: appended to by track_and_mask and handed to
# video_inpainting by inpaint_video.
original_frame_list, mask_list = [], []
def _str2bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


# Command-line options forwarded to video_inpainting.  The config and
# checkpoint defaults are built from project_name and made absolute at start-up.
parser = argparse.ArgumentParser()
parser.add_argument('--opt', default=os.path.abspath(join(project_name, 'FGT_codes', 'tool', 'configs', 'object_removal.yaml')),
                    help='Please select your config file for inference')
# video completion
parser.add_argument('--mode', default='object_removal', choices=[
                    'object_removal', 'watermark_removal', 'video_extrapolation'], help="modes: object_removal / video_extrapolation")
parser.add_argument(
    '--path', default='/myData/davis_resized/walking', help="dataset for evaluation")
parser.add_argument(
    '--path_mask', default='/myData/dilateAnnotations_4/walking', help="mask for object removal")
parser.add_argument(
    '--outroot', default='quick_start/walking3', help="output directory")
parser.add_argument('--consistencyThres', dest='consistencyThres', default=5, type=float,
                    help='flow consistency error threshold')
parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
# Explicit converter instead of type=bool so that "--Nonlocal False" is False.
parser.add_argument('--Nonlocal', dest='Nonlocal',
                    default=False, type=_str2bool)
# RAFT
parser.add_argument(
    '--raft_model', default=os.path.abspath(join(project_name, 'FGT_codes', 'LAFC', 'flowCheckPoint', 'raft-things.pth')), help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision',
                    action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true',
                    help='use efficent correlation implementation')
# LAFC
parser.add_argument('--lafc_ckpts', type=str, default=os.path.abspath(join(project_name, 'FGT_codes', 'LAFC', 'checkpoint')))
# FGT
parser.add_argument('--fgt_ckpts', type=str, default=os.path.abspath(join(project_name, 'FGT_codes', 'FGT', 'checkpoint')))
# extrapolation
parser.add_argument('--H_scale', dest='H_scale', default=2,
                    type=float, help='H extrapolation scale')
parser.add_argument('--W_scale', dest='W_scale', default=2,
                    type=float, help='W extrapolation scale')
# Image basic information
parser.add_argument('--imgH', type=int, default=256)
parser.add_argument('--imgW', type=int, default=432)
parser.add_argument('--flow_mask_dilates', type=int, default=8)
parser.add_argument('--frame_dilates', type=int, default=0)
parser.add_argument('--gpu', type=int, default=0)
# FGT inference parameters
parser.add_argument('--step', type=int, default=10)
parser.add_argument('--num_ref', type=int, default=-1)
parser.add_argument('--neighbor_stride', type=int, default=5)
# visualization
parser.add_argument('--vis_flows', action='store_true',
                    help='Visualize the initialized flows')
parser.add_argument('--vis_completed_flows',
                    action='store_true', help='Visualize the completed flows')
parser.add_argument('--vis_prop', action='store_true',
                    help='Visualize the frames after stage-I filling (flow guided content propagation)')
parser.add_argument('--vis_frame', action='store_true',
                    help='Visualize frames')
# Parsed once at import time; the resulting namespace is passed to
# video_inpainting by inpaint_video.
args = parser.parse_args()
def getBoundaries(mask):
    """Return the bounding box of the pure-white pixels of *mask*.

    Args:
        mask: image array indexed as ``[row, column, channel]`` whose pixels
            are compared against ``[255, 255, 255]``, or ``None``.

    Returns:
        A tuple ``(x, y, width, height)`` where ``x``/``y`` are the smallest
        column/row index of a white pixel and ``width``/``height`` are the
        differences between the largest and smallest column/row index.
        ``(0, 0, 0, 0)`` when *mask* is ``None`` or has no white pixel.
    """
    if mask is None:
        return 0, 0, 0, 0
    rows, cols = np.where((mask == [255, 255, 255]).all(axis=2))
    # Nothing was painted: min()/max() of an empty selection would raise.
    if rows.size == 0:
        return 0, 0, 0, 0
    x1, x2 = cols.min(), cols.max()
    y1, y2 = rows.min(), rows.max()
    return x1, y1, (x2 - x1), (y2 - y1)
def track_and_mask(vid, original_frame, masked_frame):
    """Track the masked object through a video and collect frames and masks.

    The bounding box of the white area of *masked_frame* seeds the SiamMask
    tracker on the first frame.  Every following frame is tracked; its
    segmentation mask is painted into the frame's third channel, the tracked
    polygon is outlined, and the frame and its 0/255 mask are appended to the
    module-level ``original_frame_list`` / ``mask_list`` and written to
    ``output.avi`` / ``output_mask.avi``.  The first frame only initialises
    the tracker; it is neither stored nor written.

    Args:
        vid: path of the input video, passed to ``cv2.VideoCapture.open``.
        original_frame: unused; kept so the UI wiring stays unchanged.
        masked_frame: mask image handed to ``getBoundaries``.

    Returns:
        The string ``"output.avi"`` (also when the video could not be opened).
    """
    x, y, w, h = getBoundaries(masked_frame)

    video_capture = cv2.VideoCapture()
    if not video_capture.open(vid):
        print("can't open the given input video file!")
        return "output.avi"

    # Start each run with empty buffers; otherwise a second run would append
    # to the frames and masks of the previous video.  Cleared in place because
    # inpaint_video reads the same list objects.
    original_frame_list.clear()
    mask_list.clear()

    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    # can't write out mp4, so try to write into an AVI file
    fourcc = cv2.VideoWriter_fourcc(*'MP42')
    video_writer = cv2.VideoWriter(
        "output.avi", fourcc, fps, (width, height))
    video_writer2 = cv2.VideoWriter(
        "output_mask.avi", fourcc, fps, (width, height))

    state = None
    f = 0
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break

            if f == 0:
                # init tracker on the centre and size of the selected box
                target_pos = np.array([x + w / 2, y + h / 2])
                target_sz = np.array([w, h])
                state = siamese_init(
                    frame, target_pos, target_sz, siammask, cfg['hp'], device=device)
            else:
                # track
                state = siamese_track(
                    state, frame, mask_enable=True, refine_enable=True, device=device)
                location = state['ploygon'].flatten()
                mask = state['mask'] > state['p'].seg_thr
                # Set the third channel to 255 wherever the mask is set.
                frame[:, :, 2] = (mask > 0) * 255 + \
                    (mask == 0) * frame[:, :, 2]

                mask = mask.astype(np.uint8)  # convert to an unsigned byte
                mask = mask * 255
                # np.int0 no longer exists in NumPy 2.0; cast explicitly.
                cv2.polylines(frame, [location.astype(np.int32).reshape(
                    (-1, 1, 2))], True, (0, 255, 0), 3)
                # NOTE(review): the stored frame already carries the overlay
                # and the polygon outline - confirm this is what
                # video_inpainting expects.
                original_frame_list.append(frame)
                # Exactly one mask per stored frame (it used to be appended
                # twice, leaving mask_list twice as long as the frame list).
                mask_list.append(mask)
                video_writer.write(frame)
                # NOTE(review): mask is single-channel while the writer was
                # opened with default settings - confirm the mask video is
                # actually written.
                video_writer2.write(mask)
            f = f + 1
    finally:
        # Release the capture and both writers even if tracking raises.
        video_capture.release()
        video_writer.release()
        video_writer2.release()
    return "output.avi"
def inpaint_video():
    """Run video_inpainting on the collected frames and masks.

    Passes the parsed ``args`` together with the module-level
    ``original_frame_list`` and ``mask_list`` to ``video_inpainting``.

    Returns:
        The string ``"result.mp4"``.
    """
    result_path = "result.mp4"
    video_inpainting(args, original_frame_list, mask_list)
    return result_path
def get_first_frame(video):
    """Return the first frame of *video* as an RGB image.

    Args:
        video: path of the video file, or ``None`` when no video is selected.

    Returns:
        The first frame converted from BGR to RGB, or ``None`` when *video*
        is ``None``, cannot be opened, or yields no frame.
    """
    # Nothing selected: there is no frame to show.
    if video is None:
        return None

    video_capture = cv2.VideoCapture()
    try:
        if not video_capture.open(video):
            return None
        ret, frame = video_capture.read()
        # A failed read returns no frame; converting it would raise.
        if not ret or frame is None:
            return None
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        # Always free the capture handle.
        video_capture.release()
def drawRectangle(frame, mask):
    """Draw the bounding box of the white area of *mask* onto *frame*.

    Args:
        frame: image passed to ``cv2.rectangle``.
        mask: mask image handed to ``getBoundaries``.

    Returns:
        The result of ``cv2.rectangle`` (green outline, thickness 2).
    """
    # getBoundaries returns (x, y, width, height), not two corner points, so
    # the opposite corner has to be derived from the size.
    x, y, w, h = getBoundaries(mask)
    top_left = (int(x), int(y))
    bottom_right = (int(x + w), int(y + h))
    return cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 2)
def getStartEndPoints(mask):
    """Return the corner coordinates of the pure-white area of *mask*.

    Args:
        mask: image array indexed as ``[row, column, channel]`` whose pixels
            are compared against ``[255, 255, 255]``, or ``None``.

    Returns:
        A tuple ``(x1, y1, x2, y2)``: the smallest column/row index and the
        largest column/row index of a white pixel.  ``(0, 0, 0, 0)`` when
        *mask* is ``None`` or has no white pixel.
    """
    if mask is None:
        return 0, 0, 0, 0
    rows, cols = np.where((mask == [255, 255, 255]).all(axis=2))
    # Nothing was painted: min()/max() of an empty selection would raise.
    if rows.size == 0:
        return 0, 0, 0, 0
    return cols.min(), rows.min(), cols.max(), rows.max()
# Gradio UI.  NOTE(review): the nesting below is reconstructed from the order
# of the statements (three columns inside one row) - confirm against the
# intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: input video, first frame to paint the mask on, and the
        # button that accepts the painted mask.
        with gr.Column(scale=2):
            with gr.Row():
                in_video = gr.PlayableVideo()
            with gr.Row():
                first_frame = gr.ImageMask()
            with gr.Row():
                approve_mask = gr.Button(value="Approve Mask")
        # Middle column: read-only previews of the image and the mask.
        with gr.Column(scale=1):
            with gr.Row():
                original_image = gr.Image(interactive=False)
            with gr.Row():
                masked_image = gr.Image(interactive=False)
        # Right column: result videos and the two action buttons.
        with gr.Column(scale=2):
            out_video = gr.Video()
            out_video_inpaint = gr.Video()
            track_mask = gr.Button(value="Track and Mask")
            inpaint = gr.Button(value="Inpaint")

    # A new input video refreshes the frame shown in the mask editor.
    in_video.change(
        fn=get_first_frame,
        inputs=[in_video],
        outputs=[first_frame],
    )
    # Split the editor value into its 'image' and 'mask' entries.
    approve_mask.click(
        fn=lambda x: [x['image'], x['mask']],
        inputs=first_frame,
        outputs=[original_image, masked_image],
    )
    track_mask.click(
        fn=track_and_mask,
        inputs=[in_video, original_image, masked_image],
        outputs=[out_video],
    )
    inpaint.click(fn=inpaint_video, outputs=[out_video_inpaint])

demo.launch(debug=True)