|
|
|
|
|
|
|
|
|
|
|
from __future__ import division |
|
from torch.utils.data import Dataset |
|
import numpy as np |
|
import json |
|
import random |
|
import logging |
|
from os.path import join |
|
from utils.bbox_helper import * |
|
from utils.anchors import Anchors |
|
import math |
|
import sys |
|
# First character of the interpreter version string, e.g. '3'.
pyv = sys.version[0]

import cv2

# On a Python 3 interpreter, switch OpenCV's OpenCL usage off.
if pyv == '3':
    cv2.ocl.setUseOpenCL(False)
|
|
|
# Module-wide logger, shared under the name 'global'.
logger = logging.getLogger('global')


# Dedicated random generator with a fixed seed; the shuffles in this module
# draw from it instead of the process-wide ``random`` state.
sample_random = random.Random(123456)
|
|
|
|
|
class SubDataSet(object):
    """A single annotated video dataset from which training frames are sampled.

    Annotations are read from the JSON file ``cfg['anno']`` and are laid out
    as ``{video: {track: {frame_key: bbox}}}``, where ``bbox`` is either
    ``[x1, y1, x2, y2]`` or ``[w, h]``.  After loading, every track also
    carries a ``'frames'`` entry: the sorted list of its integer frame keys.

    Every key of ``cfg`` is copied onto the instance and overrides the
    defaults assigned in ``__init__`` (``root``, ``start``, ``num``,
    ``num_use``, ``frame_range``, ``mark``, ``path_format``, ``pick``).
    """

    def __init__(self, cfg):
        """Load and clean the annotations, then draw the first index pick.

        Args:
            cfg: mapping that must contain ``'root'`` and ``'anno'``.

        Raises:
            Exception: if ``'root'`` or ``'anno'`` is missing from ``cfg``.
            ValueError: (from ``shuffle``) if samples are requested but no
                video survived the filtering.
        """
        for key in ('root', 'anno'):
            if key not in cfg:
                raise Exception('SubDataSet need "{}"'.format(key))

        with open(cfg['anno']) as fin:
            logger.info("loading " + cfg['anno'])
            annotations = json.load(fin)
        self.labels = self.filter_zero(annotations, cfg)

        # Attach the sorted integer frame list to every track and remember
        # the tracks that have no integer-keyed frame at all.
        empty_tracks = []
        for video, tracks in self.labels.items():
            for track, track_info in tracks.items():
                frames = sorted(int(key) for key in track_info if self._isint(key))
                track_info['frames'] = frames
                if not frames:
                    logger.info("warning {}/{} has no frames.".format(video, track))
                    empty_tracks.append((video, track))

        for video, track in empty_tracks:
            del self.labels[video][track]

        # Drop videos that are left without any track.
        empty_videos = [video for video, tracks in self.labels.items() if not tracks]
        for video in empty_videos:
            logger.info("warning {} has no tracks".format(video))
        for video in empty_videos:
            del self.labels[video]

        self.videos = list(self.labels.keys())

        logger.info(cfg['anno'] + " loaded.")

        # Defaults; any of them may be overridden by ``cfg`` below.
        self.root = "/"
        self.start = 0
        self.num = len(self.labels)
        self.num_use = self.num
        self.frame_range = 100
        self.mark = "vid"
        self.path_format = "{}.{}.{}.jpg"

        self.pick = []

        # User-supplied settings win over the defaults.
        self.__dict__.update(cfg)

        self.num_use = int(self.num_use)

        self.shuffle()

    @staticmethod
    def _isint(value):
        """Return True when ``int(value)`` succeeds."""
        try:
            int(value)
        except (ValueError, TypeError):
            return False
        return True

    def filter_zero(self, anno, cfg):
        """Return a copy of ``anno`` without degenerate boxes.

        A box is dropped (and logged) when its width or height is zero or
        negative.  Tracks left without frames and videos left without tracks
        are omitted from the result.

        Args:
            anno: ``{video: {track: {frame_key: bbox}}}``; ``bbox`` has four
                values ``(x1, y1, x2, y2)`` or two values ``(w, h)``.
            cfg: mapping; only ``cfg.get('mark', '')`` is read, for logging.

        Returns:
            A new nested dict with the same layout as ``anno``.
        """
        name = cfg.get('mark', '')

        out = {}
        for video, tracks in anno.items():
            new_tracks = {}
            for trk, frames in tracks.items():
                new_frames = {}
                for frm, bbox in frames.items():
                    if len(bbox) == 4:
                        x1, y1, x2, y2 = bbox
                        w, h = x2 - x1, y2 - y1
                    else:
                        w, h = bbox
                    # Negative extents are as unusable as zero ones: they
                    # would yield inverted boxes further down the pipeline.
                    if w <= 0 or h <= 0:
                        logger.info('Error, {name} {video} {trk} {bbox}'.format(
                            name=name, video=video, trk=trk, bbox=bbox))
                        continue
                    new_frames[frm] = bbox

                if new_frames:
                    new_tracks[trk] = new_frames

            if new_tracks:
                out[video] = new_tracks

        return out

    def log(self):
        """Log the dataset mark, start index, selection size and path format."""
        logger.info('SubDataSet {name} start-index {start} select [{select}/{num}] path {format}'.format(
            name=self.mark, start=self.start, select=self.num_use, num=self.num, format=self.path_format
        ))

    def shuffle(self):
        """Draw ``num_use`` indices from ``[start, start + num)``.

        The index range is reshuffled with ``sample_random`` and appended
        repeatedly until at least ``num_use`` indices are available; the
        result is truncated to ``num_use``, stored in ``self.pick`` and
        returned.

        Raises:
            ValueError: if ``num_use > 0`` while ``num <= 0``; without this
                guard the fill loop below would never terminate.
        """
        if self.num_use > 0 and self.num <= 0:
            raise ValueError(
                'SubDataSet "{}" has no videos but num_use is {}'.format(
                    getattr(self, 'mark', ''), self.num_use))

        lists = list(range(self.start, self.start + self.num))

        m = 0
        pick = []
        while m < self.num_use:
            sample_random.shuffle(lists)
            pick += lists
            m += self.num

        self.pick = pick[:self.num_use]
        return self.pick

    def get_image_anno(self, video, track, frame):
        """Return ``(image_path, annotation)`` for one frame of a track.

        ``frame`` is an integer; it is zero-padded to six digits both for
        the file name and for the annotation lookup.
        """
        frame = "{:06d}".format(frame)
        image_path = join(self.root, video, self.path_format.format(frame, track, 'x'))
        image_anno = self.labels[video][track][frame]

        return image_path, image_anno

    def get_positive_pair(self, index):
        """Sample a (template, search) pair from one random track of a video.

        Without a ``'hard'`` entry in the track, a template position is drawn
        uniformly from the frame list and the search frame is drawn from the
        entries at most ``frame_range`` list positions away from it.  With a
        ``'hard'`` entry, the search position is drawn from that list and the
        template frame from its neighbourhood instead.

        Returns:
            ``((template_path, template_anno), (search_path, search_anno))``.
        """
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]

        frames = track_info['frames']

        if 'hard' not in track_info:
            template_frame = random.randint(0, len(frames) - 1)

            left = max(template_frame - self.frame_range, 0)
            right = min(template_frame + self.frame_range, len(frames) - 1) + 1
            search_range = frames[left:right]
            template_frame = frames[template_frame]
            search_frame = random.choice(search_range)
        else:
            search_frame = random.choice(track_info['hard'])
            left = max(search_frame - self.frame_range, 0)
            right = min(search_frame + self.frame_range, len(frames) - 1) + 1
            template_range = frames[left:right]
            template_frame = random.choice(template_range)
            search_frame = frames[search_frame]

        return self.get_image_anno(video_name, track, template_frame), \
            self.get_image_anno(video_name, track, search_frame)

    def get_random_target(self, index=-1):
        """Return ``(image_path, annotation)`` of a random frame.

        Args:
            index: position in ``self.videos``; ``-1`` picks a random video.
        """
        if index == -1:
            index = random.randint(0, self.num - 1)
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]

        frames = track_info['frames']
        frame = random.choice(frames)

        return self.get_image_anno(video_name, track, frame)
|
|
|
|
|
def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)):
    """Crop ``bbox`` out of ``image`` and resample it to ``out_sz`` x ``out_sz``.

    The crop is done with a single affine warp that maps the box corner
    ``(bbox[0], bbox[1])`` to pixel ``(0, 0)`` and ``(bbox[2], bbox[3])`` to
    pixel ``(out_sz - 1, out_sz - 1)``.  Output pixels that fall outside the
    source image are filled with ``padding``.

    Args:
        image: source image passed to ``cv2.warpAffine``.
        bbox: sequence whose first four values are ``x1, y1, x2, y2``.
        out_sz: side length of the square output, in pixels.
        padding: constant border colour.

    Returns:
        The warped crop as returned by ``cv2.warpAffine``.
    """
    bbox = [float(x) for x in bbox]
    a = (out_sz - 1) / (bbox[2] - bbox[0])
    b = (out_sz - 1) / (bbox[3] - bbox[1])
    c = -a * bbox[0]
    d = -b * bbox[1]
    # ``np.float`` was only an alias of the builtin ``float`` and no longer
    # exists in current NumPy; request float64 explicitly.
    mapping = np.array([[a, 0, c],
                        [0, b, d]], dtype=np.float64)
    crop = cv2.warpAffine(image, mapping, (out_sz, out_sz),
                          borderMode=cv2.BORDER_CONSTANT, borderValue=padding)
    return crop
|
|
|
|
|
class Augmentation:
    """Random crop / colour / blur / resize / flip augmentation of one image.

    Instances are callable: ``aug(image, bbox, size, gray=False)`` returns the
    augmented ``size`` x ``size`` crop together with the box expressed in the
    crop's coordinates.  Every key of ``cfg`` is copied onto the instance and
    overrides the defaults set in ``__init__``.
    """

    def __init__(self, cfg):
        """Set the defaults (all augmentations off) and apply ``cfg``."""
        self.shift = 0
        self.scale = 0
        self.blur = 0
        self.resize = False
        # Matrix multiplied with a random normal vector to get the colour
        # offset subtracted in ``__call__``.
        self.rgbVar = np.array([[-0.55919361, 0.98062831, - 0.41940627],
                                [1.72091413, 0.19879334, - 1.82968581],
                                [4.64467907, 4.73710203, 4.88324118]], dtype=np.float32)
        self.flip = 0

        # NOTE(review): eig_vec / eig_val are not read anywhere in this class.
        self.eig_vec = np.array([
            [0.4009, 0.7192, -0.5675],
            [-0.8140, -0.0045, -0.5808],
            [0.4203, -0.6948, -0.5836],
        ], dtype=np.float32)

        self.eig_val = np.array([[0.2175, 0.0188, 0.0045]], np.float32)

        self.__dict__.update(cfg)

    @staticmethod
    def random():
        """Return a uniform random float in ``[-1.0, 1.0)``."""
        return random.random() * 2 - 1.0

    def blur_image(self, image):
        """Filter ``image`` with a random cross-shaped kernel, or return it as is.

        The kernel side is ``2 * round(n) + 1`` for a standard-normal draw
        ``n``, capped at 45.  No blur is applied when that size is negative
        or, otherwise, with probability one half.
        """
        def rand_kernel():
            # A scalar draw keeps ``int()`` away from a 1-element array,
            # a conversion that NumPy has deprecated.
            size = int(np.round(np.random.randn())) * 2 + 1
            if size < 0:
                return None
            if random.random() < 0.5:
                return None
            size = min(size, 45)
            kernel = np.zeros((size, size))
            c = int(size / 2)
            wx = random.random()
            # Split the unit weight between the centre column and centre row.
            kernel[:, c] += 1. / size * wx
            kernel[c, :] += 1. / size * (1 - wx)
            return kernel

        kernel = rand_kernel()

        if kernel is not None:
            image = cv2.filter2D(image, -1, kernel)
        return image

    def __call__(self, image, bbox, size, gray=False):
        """Augment ``image`` and return ``(crop, bbox_in_crop)``.

        Args:
            image: source image (BGR when ``gray`` is requested).
            bbox: box with ``x1, y1, x2, y2`` attributes, in ``image``
                coordinates.
            size: side length of the square output crop.
            gray: when true, the image is first converted to a three-channel
                grayscale image.

        Returns:
            The augmented crop and the box mapped into it.
        """
        if gray:
            grayed = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            image = np.zeros((grayed.shape[0], grayed.shape[1], 3), np.uint8)
            image[:, :, 0] = image[:, :, 1] = image[:, :, 2] = grayed

        shape = image.shape

        # Centre is (x, y): x comes from the width (shape[1]) and y from the
        # height (shape[0]).
        crop_bbox = center2corner((shape[1] // 2, shape[0] // 2, size - 1, size - 1))

        param = {}
        if self.shift:
            param['shift'] = (Augmentation.random() * self.shift, Augmentation.random() * self.shift)

        if self.scale:
            param['scale'] = ((1.0 + Augmentation.random() * self.scale), (1.0 + Augmentation.random() * self.scale))

        crop_bbox, _ = aug_apply(Corner(*crop_bbox), param, shape)

        x1 = crop_bbox.x1
        y1 = crop_bbox.y1

        # Express the target box relative to the crop's top-left corner.
        bbox = BBox(bbox.x1 - x1, bbox.y1 - y1,
                    bbox.x2 - x1, bbox.y2 - y1)

        if self.scale:
            scale_x, scale_y = param['scale']
            bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y, bbox.x2 / scale_x, bbox.y2 / scale_y)

        image = crop_hwc(image, crop_bbox, size)

        # Random colour offset; the row order is reversed before subtraction.
        offset = np.dot(self.rgbVar, np.random.randn(3, 1))
        offset = offset[::-1]
        offset = offset.reshape(3)
        image = image - offset

        if self.blur > random.random():
            image = self.blur_image(image)

        if self.resize:
            # Down-sample by a random ratio and scale back to the same size.
            height, width = image.shape[:2]
            ratio = max(math.pow(random.random(), 0.5), 0.2)
            rand_size = (int(round(ratio * width)), int(round(ratio * height)))
            image = cv2.resize(image, rand_size)
            image = cv2.resize(image, (width, height))

        # ``flip`` is a probability, so compare against a draw from [0, 1)
        # exactly as done for ``blur`` above.
        if self.flip and self.flip > random.random():
            image = cv2.flip(image, 1)
            width = image.shape[1]
            bbox = Corner(width - 1 - bbox.x2, bbox.y1, width - 1 - bbox.x1, bbox.y2)

        return image, bbox
|
|
|
|
|
class AnchorTargetLayer:
    """Builds per-anchor classification and regression targets.

    Calling an instance yields ``(cls, delta, delta_weight)`` and, when
    ``need_iou`` is set, additionally the overlap map.  ``cls`` holds ``1``
    for selected positives, ``0`` for selected negatives and ``-1`` for
    everything else.  Settings come from the defaults below, overridden by
    the keys of ``cfg``.
    """

    def __init__(self, cfg):
        """Install the default thresholds / sample counts, then apply ``cfg``."""
        defaults = (
            ('thr_high', 0.6),
            ('thr_low', 0.3),
            ('negative', 16),
            ('rpn_batch', 64),
            ('positive', 16),
        )
        for attr, value in defaults:
            setattr(self, attr, value)

        vars(self).update(cfg)

    @staticmethod
    def _select(position, keep_num=16):
        """Keep at most ``keep_num`` randomly chosen entries of ``position``.

        ``position`` is an index tuple as produced by ``np.where``.  Returns
        the (possibly thinned) tuple and the number of entries kept.
        """
        total = position[0].shape[0]
        if total <= keep_num:
            return position, total
        order = np.arange(total)
        np.random.shuffle(order)
        kept = order[:keep_num]
        return tuple(axis[kept] for axis in position), keep_num

    def __call__(self, anchor, target, size, neg=False, need_iou=False):
        """Compute the targets for a ``size`` x ``size`` response map.

        Args:
            anchor: object exposing ``anchors`` (its first dimension is the
                anchor count) and ``all_anchors`` (index 0 is unpacked as
                ``x1, y1, x2, y2``, index 1 as ``cx, cy, w, h``).
            target: ground-truth box handed to ``corner2center`` and ``IoU``;
                not read for negative pairs.
            size: spatial side of the output maps.
            neg: true for a negative pair; then only negatives around the
                map centre are labelled.
            need_iou: also return the overlap map.
        """
        anchor_num = anchor.anchors.shape[0]
        grid = (anchor_num, size, size)

        cls = np.full(grid, -1, dtype=np.int64)
        delta = np.zeros((4,) + grid, dtype=np.float32)
        delta_weight = np.zeros(grid, dtype=np.float32)

        if neg:
            # Negative pair: candidates are the 7x7 window around the centre.
            lo = size // 2 - 3
            hi = size // 2 + 3 + 1
            window = np.zeros(grid, dtype=bool)
            window[:, lo:hi, lo:hi] = True

            chosen, _ = self._select(np.where(window), self.negative)
            cls[chosen] = 0

            if need_iou:
                return cls, delta, delta_weight, np.zeros(grid, dtype=np.float32)
            return cls, delta, delta_weight

        tcx, tcy, tw, th = corner2center(target)

        anchor_box = anchor.all_anchors[0]
        anchor_center = anchor.all_anchors[1]
        x1, y1, x2, y2 = (anchor_box[i] for i in range(4))
        cx, cy, w, h = (anchor_center[i] for i in range(4))

        # Regression targets: centre offsets in anchor units, log size ratios.
        delta[0] = (tcx - cx) / w
        delta[1] = (tcy - cy) / h
        delta[2] = np.log(tw / w)
        delta[3] = np.log(th / h)

        overlap = IoU([x1, y1, x2, y2], target)

        pos_candidates = np.where(overlap > self.thr_high)
        neg_candidates = np.where(overlap < self.thr_low)

        pos_idx, pos_num = self._select(pos_candidates, self.positive)
        neg_idx, _ = self._select(neg_candidates, self.rpn_batch - pos_num)

        cls[pos_idx] = 1
        delta_weight[pos_idx] = 1. / (pos_num + 1e-6)

        cls[neg_idx] = 0

        if need_iou:
            return cls, delta, delta_weight, overlap
        return cls, delta, delta_weight
|
|
|
|
|
class DataSets(Dataset):
    """Torch dataset that yields (template, search) training samples.

    It combines several ``SubDataSet`` instances, augments the template and
    search images with two ``Augmentation`` objects and produces anchor
    targets with an ``AnchorTargetLayer``.
    """

    def __init__(self, cfg, anchor_cfg, num_epoch=1):
        """Build the sub-datasets, augmenters and the initial index pick.

        Args:
            cfg: configuration mapping.  ``'datasets'`` and ``'augmentation'``
                are required; ``'template_size'``, ``'origin_size'``,
                ``'search_size'``, ``'base_size'``, ``'size'``,
                ``'crop_size'``, ``'template_small'``, ``'anchor_target'``
                and ``'num'`` are optional.  Note that ``cfg`` is written
                to: a missing ``'anchor_target'`` is filled in and each
                dataset entry receives ``'mark'`` and ``'start'``.
            anchor_cfg: configuration handed to ``Anchors``.
            num_epoch: the dataset length is multiplied by this value.

        Raises:
            Exception: if the sizes are inconsistent or ``'datasets'`` is
                missing.
        """
        super(DataSets, self).__init__()
        global logger
        logger = logging.getLogger('global')

        self.anchors = Anchors(anchor_cfg)

        # Size defaults, each optionally overridden by ``cfg``.
        self.template_size = 127
        self.origin_size = 127
        self.search_size = 255
        self.size = 17
        self.base_size = 0
        self.crop_size = 0

        for key in ('template_size', 'origin_size', 'search_size', 'base_size', 'size'):
            if key in cfg:
                setattr(self, key, cfg[key])

        # The response-map size must follow from the image sizes and stride.
        if (self.search_size - self.template_size) / self.anchors.stride + 1 + self.base_size != self.size:
            raise Exception("size not match!")
        if 'crop_size' in cfg:
            self.crop_size = cfg['crop_size']
        self.template_small = bool('template_small' in cfg and cfg['template_small'])

        self.anchors.generate_all_anchors(im_c=self.search_size // 2, size=self.size)

        if 'anchor_target' not in cfg:
            cfg['anchor_target'] = {}
        self.anchor_target = AnchorTargetLayer(cfg['anchor_target'])

        if 'datasets' not in cfg:
            raise(Exception('DataSet need "{}"'.format('datasets')))

        # Sub-datasets occupy consecutive, non-overlapping global index
        # ranges that start at ``start``.
        self.all_data = []
        start = 0
        self.num = 0
        for name in cfg['datasets']:
            dataset = cfg['datasets'][name]
            dataset['mark'] = name
            dataset['start'] = start

            dataset = SubDataSet(dataset)
            dataset.log()
            self.all_data.append(dataset)

            start += dataset.num
            self.num += dataset.num_use

        aug_cfg = cfg['augmentation']
        self.template_aug = Augmentation(aug_cfg['template'])
        self.search_aug = Augmentation(aug_cfg['search'])
        self.gray = aug_cfg['gray']
        self.neg = aug_cfg['neg']
        self.inner_neg = 0 if 'inner_neg' not in aug_cfg else aug_cfg['inner_neg']

        self.pick = None
        if 'num' in cfg:
            self.num = int(cfg['num'])
        self.num *= num_epoch
        self.shuffle()

        self.infos = {
            'template': self.template_size,
            'search': self.search_size,
            'template_small': self.template_small,
            'gray': self.gray,
            'neg': self.neg,
            'inner_neg': self.inner_neg,
            'crop_size': self.crop_size,
            'anchor_target': self.anchor_target.__dict__,
            'num': self.num // num_epoch
        }
        logger.info('dataset informations: \n{}'.format(json.dumps(self.infos, indent=4)))

    def imread(self, path):
        """Read an image and rescale it when the stored size differs.

        Returns:
            ``(image, scale)``.  ``scale`` is ``1.0`` when ``origin_size``
            equals ``template_size``; otherwise the image is resized to a
            square and ``scale`` is the new side divided by the original
            width.

        Raises:
            IOError: if OpenCV could not read ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('failed to read image "{}"'.format(path))

        if self.origin_size == self.template_size:
            return img, 1.0

        def map_size(exe, size):
            return int(round(((exe + 1) / (self.origin_size + 1) * (size + 1) - 1)))

        # Remember the width before resizing; measuring it afterwards would
        # always give a scale of exactly 1.0.
        origin_width = img.shape[1]
        nsize = map_size(self.template_size, origin_width)

        img = cv2.resize(img, (nsize, nsize))

        return img, nsize / origin_width

    def shuffle(self):
        """Rebuild ``self.pick`` with at least ``self.num`` global indices.

        Each round reshuffles every sub-dataset, concatenates their picks and
        shuffles the concatenation with ``sample_random``.

        Raises:
            ValueError: if samples are requested but the sub-datasets yield
                none; the loop below could otherwise never finish.
        """
        pick = []
        m = 0
        while m < self.num:
            p = []
            for subset in self.all_data:
                p += subset.shuffle()

            if not p:
                raise ValueError(
                    'cannot pick {} samples: all sub-datasets are empty'.format(self.num))

            sample_random.shuffle(p)

            pick += p
            m = len(pick)
        self.pick = pick
        logger.info("shuffle done!")
        logger.info("dataset length {}".format(self.num))

    def __len__(self):
        """Return the number of samples per pass (``self.num``)."""
        return self.num

    def find_dataset(self, index):
        """Map a global index to ``(sub_dataset, local_index)``.

        Returns ``None`` implicitly when ``index`` lies beyond the last
        sub-dataset.
        """
        for dataset in self.all_data:
            if dataset.start + dataset.num > index:
                return dataset, index - dataset.start

    @staticmethod
    def _center_crop(img, size):
        """Return the central square of ``img``; unchanged if width == size."""
        shape = img.shape[1]
        if shape == size:
            return img
        c = shape // 2
        l = c - size // 2
        r = c + size // 2 + 1
        return img[l:r, l:r]

    def _to_bbox(self, image, shape):
        """Return the image-centred corner box of an annotated target.

        ``shape`` is ``(x1, y1, x2, y2)`` or ``(w, h)``.  The size is scaled
        so that the target plus half-perimeter context spans
        ``template_size`` pixels.
        """
        imh, imw = image.shape[:2]
        if len(shape) == 4:
            w, h = shape[2] - shape[0], shape[3] - shape[1]
        else:
            w, h = shape
        context_amount = 0.5
        exemplar_size = self.template_size
        wc_z = w + context_amount * (w + h)
        hc_z = h + context_amount * (w + h)
        s_z = np.sqrt(wc_z * hc_z)
        scale_z = exemplar_size / s_z
        w = w * scale_z
        h = h * scale_z
        cx, cy = imw // 2, imh // 2
        return center2corner(Center(cx, cy, w, h))

    @staticmethod
    def _draw(image, box, name):
        """Write a copy of ``image`` with ``box`` drawn in green to ``name``."""
        image = image.copy()
        x1, y1, x2, y2 = map(lambda x: int(round(x)), box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0))
        cv2.imwrite(name, image)

    def __getitem__(self, index, debug=False):
        """Return one training sample.

        Args:
            index: position in ``self.pick``.
            debug: when true, annotated images are written under ``debug/``.

        Returns:
            ``(template, search, cls, delta, delta_weight, bbox)`` where the
            two images are float32 arrays transposed to channel-first order
            and ``bbox`` is the search box as a float32 array.
        """
        index = self.pick[index]
        dataset, index = self.find_dataset(index)

        gray = self.gray and self.gray > random.random()
        neg = self.neg and self.neg > random.random()

        if neg:
            # Negative pair: template and search show unrelated targets,
            # from the same sub-dataset or from a random one.
            template = dataset.get_random_target(index)
            if self.inner_neg and self.inner_neg > random.random():
                search = dataset.get_random_target()
            else:
                search = random.choice(self.all_data).get_random_target()
        else:
            template, search = dataset.get_positive_pair(index)

        template_image, _ = self.imread(template[0])

        if self.template_small:
            template_image = self._center_crop(template_image, self.template_size)

        search_image, _ = self.imread(search[0])
        if self.crop_size > 0:
            search_image = self._center_crop(search_image, self.crop_size)

        template_box = self._to_bbox(template_image, template[1])
        search_box = self._to_bbox(search_image, search[1])

        template, template_bbox = self.template_aug(template_image, template_box, self.template_size, gray=gray)
        search, bbox = self.search_aug(search_image, search_box, self.search_size, gray=gray)

        if debug:
            self._draw(template_image, template_box, "debug/{:06d}_ot.jpg".format(index))
            self._draw(search_image, search_box, "debug/{:06d}_os.jpg".format(index))
            self._draw(template, template_bbox, "debug/{:06d}_t.jpg".format(index))
            self._draw(search, bbox, "debug/{:06d}_s.jpg".format(index))

        cls, delta, delta_weight = self.anchor_target(self.anchors, bbox, self.size, neg)

        template, search = map(lambda x: np.transpose(x, (2, 0, 1)).astype(np.float32), [template, search])

        return template, search, cls, delta, delta_weight, np.array(bbox, np.float32)
|
|
|
|