|
|
|
|
|
|
|
|
|
|
|
import argparse
import itertools
import logging
from os import makedirs, remove
from os.path import isfile, isdir, join

import cv2
import numpy as np
from PIL import Image

from utils.log_helper import init_log, add_file_handler
from utils.bbox_helper import cxy_wh_2_rect
from utils.load_helper import load_pretrain
from utils.benchmark_helper import load_dataset
from utils.average_meter_helper import IouMeter

import models as models
from tools.test import siamese_init, siamese_track
from utils.config_helper import load_config
|
|
|
# Threshold grid handed to IouMeter in tune(): 0.30, 0.35, ..., 0.80.
# The stop value is 0.81 so that 0.80 itself is included.
thrs = np.arange(0.3, 0.81, 0.05)

# Names of every callable attribute of the `models` package that is not a
# dunder, in sorted order; these become the accepted --arch choices.
model_zoo = sorted(
    name
    for name, obj in vars(models).items()
    if callable(obj) and not name.startswith("__")
)
|
|
|
|
|
def parse_range(arg):
    """Turn a comma-separated ``[start,]stop[,step]`` string into a float grid.

    Intended as an argparse ``type=`` callable: the parsed numbers are
    forwarded positionally to ``np.arange``, so ``"0,3,1"`` gives
    ``[0., 1., 2.]``.

    Args:
        arg: Text holding one to three comma-separated numbers.

    Returns:
        numpy.ndarray of floats as produced by ``np.arange``.

    Raises:
        ValueError: if a field is not a valid float, if more than three
            fields are given, or if the step is zero. argparse converts a
            ValueError from a ``type=`` callable into a usage error.
    """
    fields = [float(token) for token in arg.split(',')]
    # np.arange would interpret a 4th positional value as `dtype`; reject it
    # here with a message that names the real problem.
    if len(fields) > 3:
        raise ValueError(
            'expected [start,]stop[,step], got {} values: {!r}'.format(len(fields), arg))
    if len(fields) == 3 and fields[2] == 0:
        raise ValueError('step must be non-zero: {!r}'.format(arg))
    return np.arange(*fields)
|
|
|
|
|
def parse_range_int(arg):
    """Turn a comma-separated ``[start,]stop[,step]`` string into an integer grid.

    Integer counterpart of ``parse_range``; intended as an argparse ``type=``
    callable. The parsed numbers are forwarded positionally to ``np.arange``,
    so ``"255,256,8"`` gives the single value ``[255]``.

    Args:
        arg: Text holding one to three comma-separated integers.

    Returns:
        numpy.ndarray of integers as produced by ``np.arange``.

    Raises:
        ValueError: if a field is not a valid integer, if more than three
            fields are given, or if the step is zero. argparse converts a
            ValueError from a ``type=`` callable into a usage error.
    """
    fields = [int(token) for token in arg.split(',')]
    # np.arange would interpret a 4th positional value as `dtype`; reject it
    # here with a message that names the real problem.
    if len(fields) > 3:
        raise ValueError(
            'expected [start,]stop[,step], got {} values: {!r}'.format(len(fields), arg))
    # A zero step would otherwise fail inside NumPy with an error argparse
    # does not translate into a usage message.
    if len(fields) == 3 and fields[2] == 0:
        raise ValueError('step must be non-zero: {!r}'.format(arg))
    return np.arange(*fields)
|
|
|
|
|
# Command-line interface. The four grid options take "start,stop,step" text
# that parse_range / parse_range_int expand with np.arange; argparse also
# runs `type` on string defaults, so the defaults below become arrays too.
parser = argparse.ArgumentParser(
    description='Finetune parameters for SiamMask tracker on DAVIS')
parser.add_argument(
    '--arch', dest='arch', default='SiamRPNA',
    choices=model_zoo + ['Custom'],
    help='architecture of pretrained model')
parser.add_argument(
    '--resume', default='', type=str, required=True, metavar='PATH',
    help='path to latest checkpoint (default: none)')
parser.add_argument(
    '--config', dest='config',
    help='hyperparameter of SiamMask in json format')
parser.add_argument(
    '--mask', action='store_true',
    help='whether use mask output')
parser.add_argument(
    '--refine', action='store_true',
    help='whether use mask refine output')
parser.add_argument(
    '--dataset', default='DAVIS2016', type=str, metavar='DATASET',
    help='dataset')
parser.add_argument(
    '-l', '--log', default="log_tune_davis.txt", type=str,
    help='log file')
parser.add_argument(
    '--penalty-k', default='0.0,0.1,0.03', type=parse_range,
    help='penalty_k range')
parser.add_argument(
    '--lr', default='0.8,1.01,0.05', type=parse_range,
    help='lr range')
parser.add_argument(
    '--window-influence', default='0.3,0.5,0.04', type=parse_range,
    help='window influence range')
parser.add_argument(
    '--search-region', default='255,256,8', type=parse_range_int,
    help='search region size')
parser.add_argument(
    '-v', '--visualization', dest='visualization', action='store_true',
    help='whether visualize result')

# Parsed at import time: tune() and main() read `args` as a module-level global.
args = parser.parse_args()
|
|
|
|
|
def _draw_frame(im, frame_idx, gt_box, location, mask, title):
    """Display one frame with the ground truth (green) and the prediction (yellow).

    Args:
        im: Frame as loaded by ``cv2.imread``; it is copied, not modified.
        frame_idx: Index of the frame, drawn in the top-left corner.
        gt_box: Ground-truth row; 8 values are drawn as a polygon, anything
            else as an ``x, y, w, h`` rectangle.
        location: Tracker output, interpreted the same way as ``gt_box``.
        mask: Predicted mask blended into the red channel when ``location``
            is a polygon, or ``None`` when no mask exists yet.
        title: Window title passed to ``cv2.imshow``.
    """
    im_show = im.copy()
    if frame_idx == 0:
        cv2.destroyAllWindows()

    if len(gt_box) == 8:
        # np.int was removed in NumPy 1.24; int32 is the dtype OpenCV draws with.
        gt_poly = np.array(gt_box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(im_show, [gt_poly], True, (0, 255, 0), 3)
    else:
        # OpenCV expects integer pixel coordinates; gt may hold floats.
        top_left = (int(gt_box[0]), int(gt_box[1]))
        bottom_right = (int(gt_box[0] + gt_box[2]), int(gt_box[1] + gt_box[3]))
        cv2.rectangle(im_show, top_left, bottom_right, (0, 255, 0), 3)

    if len(location) == 8:
        if mask is not None:
            im_show[:, :, 2] = mask * 255 + (1 - mask) * im_show[:, :, 2]
        # np.int0 was removed in NumPy 2.0; cast explicitly instead.
        pred_poly = np.asarray(location).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(im_show, [pred_poly], True, (0, 255, 255), 3)
    else:
        box = [int(v) for v in location]
        cv2.rectangle(im_show, (box[0], box[1]),
                      (box[0] + box[2], box[1] + box[3]), (0, 255, 255), 3)

    cv2.putText(im_show, str(frame_idx), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
    cv2.imshow(title, im_show)
    cv2.waitKey(1)


def _track_video(param):
    """Track the current video once and return its mean-IoU values from IouMeter.

    Reads the module-level ``gt``, ``image_files`` and ``anno_files`` set by
    ``main()`` and lazily fills the ``ims`` / ``annos`` caches on first use.
    """
    global ims, gt, annos, image_files, anno_files
    if ims is None:
        print(param['video'] + ' Only load image once and if needed')
        ims = [cv2.imread(x) for x in image_files]
        annos = [np.array(Image.open(x)) for x in anno_files]

    # The first frame initialises the tracker and the last one is excluded
    # below, hence len(ims) - 2 scored frames.
    iou = IouMeter(thrs, len(ims) - 2)
    start_frame, end_frame, toc = 0, len(ims) - 1, 0
    mask = None
    for frame_idx, (im, anno) in enumerate(zip(ims, annos)):
        tic = cv2.getTickCount()
        if frame_idx == start_frame:
            # gt rows are x, y, w, h here; the tracker wants centre + size.
            target_pos = np.array([gt[frame_idx, 0] + gt[frame_idx, 2] / 2,
                                   gt[frame_idx, 1] + gt[frame_idx, 3] / 2])
            target_sz = np.array([gt[frame_idx, 2], gt[frame_idx, 3]])
            state = siamese_init(im, target_pos, target_sz, param['network'], param['hp'])
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
        elif frame_idx > start_frame:
            state = siamese_track(state, im, args.mask, args.refine)
            # NOTE(review): 'ploygon' (sic) is the key this script has always
            # read; presumably it matches siamese_track's output - do not "fix".
            location = state['ploygon'].flatten()
            mask = state['mask']

        if start_frame < frame_idx < end_frame:
            iou.add(mask, anno)

        toc += cv2.getTickCount() - tic

        if args.visualization and frame_idx >= start_frame:
            _draw_frame(im, frame_idx, gt[frame_idx], location, mask, param['video'])

    toc /= cv2.getTickFrequency()
    iou_list = iou.value('mean')
    print('Video: {:12s} Time: {:2.1f}s Speed: {:3.1f}fps IOU: {:.3f}'.format(
        param['video'], toc, frame_idx / toc, iou_list.max()))
    return iou_list


def tune(param):
    """Evaluate one hyper-parameter set on one video and store the IoU scores.

    The result file lives at
    ``result/<dataset>/<network>[_refine]_r.._penalty_k_.._window_influence_.._lr_../<video>.txt``
    (dots in the tracker directory name are replaced by underscores). If that
    file already exists the call is a no-op. Otherwise a placeholder
    containing ``Occ`` is written first - presumably so that other workers
    sharing the result directory skip this configuration - and is replaced by
    the comma-separated IoU values once tracking finishes.

    Args:
        param: dict with keys ``dataset``, ``network_name``, ``network``,
            ``video`` and ``hp`` (which must contain ``instance_size``,
            ``penalty_k``, ``window_influence`` and ``lr``).

    Returns:
        The value of ``IouMeter.value('mean')`` for this run, or ``None`` when
        the result file already existed.

    Raises:
        Whatever tracking raises; in that case the placeholder file is removed
        first so the configuration is retried on the next run instead of
        being skipped forever.
    """
    hp = param['hp']
    benchmark_result_path = join('result', param['dataset'])
    tracker_name = (param['network_name'] + ('_refine' if args.refine else '') +
                    '_r{}'.format(hp['instance_size']) +
                    '_penalty_k_{:.3f}'.format(hp['penalty_k']) +
                    '_window_influence_{:.3f}'.format(hp['window_influence']) +
                    '_lr_{:.3f}'.format(hp['lr'])).replace('.', '_')
    tracker_path = join(benchmark_result_path, tracker_name)
    result_path = join(tracker_path, param['video'] + '.txt')

    if isfile(result_path):
        return

    try:
        if not isdir(tracker_path):
            makedirs(tracker_path)
    except OSError as err:
        # Most likely another process created the directory in between.
        print(err)

    with open(result_path, 'w') as placeholder:
        placeholder.write('Occ')

    try:
        iou_list = _track_video(param)
    except BaseException:
        # Do not leave a stale 'Occ' claim behind: it would make every later
        # run treat this configuration as already done.
        if isfile(result_path):
            remove(result_path)
        raise

    with open(result_path, 'w') as result_file:
        result_file.write(','.join(["%.5f" % i for i in iou_list]) + '\n')

    return iou_list
|
|
|
|
|
def main():
    """Grid-search the tracker hyper-parameters over every video of the dataset.

    Builds the model selected by ``--arch``, loads the ``--resume``
    checkpoint, then calls ``tune()`` for each video and each combination of
    penalty_k, window_influence, lr and instance_size. Videos and the first
    three grids are shuffled in place before use, so the visiting order
    differs between runs.

    Raises:
        IOError: if ``--resume`` is given but does not point to a file.
    """
    global ims, gt, annos, image_files, anno_files

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    params = {'penalty_k': args.penalty_k,
              'window_influence': args.window_influence,
              'lr': args.lr,
              'instance_size': args.search_region}

    num_search = (len(params['penalty_k']) * len(params['window_influence']) *
                  len(params['lr']) * len(params['instance_size']))

    print(params)
    print(num_search)

    cfg = load_config(args)
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        model = models.__dict__[args.arch](anchors=cfg['anchors'])

    if args.resume:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isfile(args.resume):
            raise IOError('{} is not a valid file'.format(args.resume))
        model = load_pretrain(model, args.resume)
    model.eval()
    model = model.cuda()

    default_hp = cfg.get('hp', {})

    p = {
        'network': model,
        'network_name': args.arch + '_mask_' + args.resume.split('/')[-1].split('.')[0],
        'dataset': args.dataset,
    }

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)
    for video in videos:
        print(video)
        p['video'] = video
        # Reset the per-video caches; tune() loads the frames on first use.
        ims = None
        annos = None
        image_files = dataset_info[video]['image_files']
        anno_files = dataset_info[video]['anno_files']
        gt = dataset_info[video]['gt']

        np.random.shuffle(params['penalty_k'])
        np.random.shuffle(params['window_influence'])
        np.random.shuffle(params['lr'])
        # product() iterates with the last grid varying fastest, i.e. in the
        # same order as four nested loops over these grids.
        grid = itertools.product(params['penalty_k'],
                                 params['window_influence'],
                                 params['lr'],
                                 params['instance_size'])
        for penalty_k, window_influence, lr, instance_size in grid:
            p['hp'] = default_hp.copy()
            p['hp'].update({'penalty_k': penalty_k,
                            'window_influence': window_influence,
                            'lr': lr,
                            'instance_size': instance_size,
                            })
            tune(p)


if __name__ == '__main__':
    main()
|
|