"""SiamMask "sharp" model: ResNet-50 backbone adapter, RPN head, mask head,
and mask refinement network, assembled into the `Custom` tracker."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from SiamMask.experiments.siammask_sharp.resnet import resnet50
from SiamMask.models.features import MultiStageFeature
from SiamMask.models.mask import Mask
from SiamMask.models.rpn import RPN, DepthCorr
from SiamMask.models.siammask_sharp import SiamMask
from SiamMask.utils.load_helper import load_pretrain


class ResDownS(nn.Module):
    """Adjust layer: a 1x1 conv + BatchNorm that maps backbone channels down
    to the correlation feature size, center-cropping small template maps."""

    def __init__(self, inplane, outplane):
        super(ResDownS, self).__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(inplane, outplane, kernel_size=1, bias=False),
            nn.BatchNorm2d(outplane))

    def forward(self, x):
        x = self.downsample(x)
        # Template feature maps are small (under 20 px per side, e.g. 15x15
        # from a 127x127 crop); crop 4 px from each border to keep only the
        # center region. Larger search-region features pass through uncropped.
        if x.size(3) < 20:
            left, right = 4, -4
            x = x[:, :, left:right, left:right]
        return x


class ResDown(MultiStageFeature):
    """ResNet-50 backbone (through layer3) followed by the adjust layer,
    with staged unfreezing driven by MultiStageFeature."""

    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')

        self.downsample = ResDownS(1024, 256)

        # Unfreezing schedule: start with only the adjust layer trainable,
        # then open layer2/layer3 once training passes the 50% mark.
        self.layers = [self.downsample, self.features.layer2, self.features.layer3]
        self.train_nums = [1, 3]
        self.change_point = [0, 0.5]

        self.unfix(0.0)

    def param_groups(self, start_lr, feature_mult=1):
        lr = start_lr * feature_mult

        def _params(module, mult=1):
            # Only trainable parameters join the group; frozen stages
            # contribute nothing.
            params = list(filter(lambda x: x.requires_grad, module.parameters()))
            if len(params):
                return [{'params': params, 'lr': lr * mult}]
            else:
                return []

        groups = []
        groups += _params(self.downsample)
        groups += _params(self.features, 0.1)  # backbone at 10% of the head lr
        return groups

    def forward(self, x):
        output = self.features(x)
        p3 = self.downsample(output[-1])
        return p3

    def forward_all(self, x):
        # Same as forward(), but also return the intermediate backbone
        # stages, which the Refine module consumes.
        output = self.features(x)
        p3 = self.downsample(output[-1])
        return output, p3


class UP(RPN):
    """RPN head: depthwise cross-correlation branches for classification
    and box regression."""

    def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
        super(UP, self).__init__()

        self.anchor_num = anchor_num
        self.feature_in = feature_in
        self.feature_out = feature_out

        self.cls_output = 2 * self.anchor_num  # fg/bg score per anchor
        self.loc_output = 4 * self.anchor_num  # box offsets per anchor

        self.cls = DepthCorr(feature_in, feature_out, self.cls_output)
        self.loc = DepthCorr(feature_in, feature_out, self.loc_output)

    def forward(self, z_f, x_f):
        cls = self.cls(z_f, x_f)
        loc = self.loc(z_f, x_f)
        return cls, loc
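
# Shape sketch for the head above (illustrative sizes assuming the standard
# SiamMask 127x127 template and 255x255 search crops; nothing in this file
# enforces them): z_f of (1, 256, 7, 7) correlated with x_f of
# (1, 256, 31, 31) gives a 25x25 score map, so cls comes out
# (1, 2*anchor_num, 25, 25) and loc (1, 4*anchor_num, 25, 25).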


class MaskCorr(Mask):
    """Mask head: a single depthwise-correlation branch that predicts a
    flattened oSz x oSz mask for every position of the score map."""

    def __init__(self, oSz=63):
        super(MaskCorr, self).__init__()
        self.oSz = oSz
        self.mask = DepthCorr(256, 256, self.oSz**2)

    def forward(self, z, x):
        return self.mask(z, x)


class Refine(nn.Module):
    """Mask refinement: a top-down path that expands the 1x1 correlation
    feature and fuses it with progressively shallower backbone features
    (U-Net style), upsampling 1 -> 15 -> 31 -> 61 -> 127."""

    def __init__(self):
        super(Refine, self).__init__()
        # Vertical (skip) branches: compress backbone features at each scale.
        self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 4, 3, padding=1), nn.ReLU())

        self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())

        self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())

        # Horizontal branches: refine the upsampled top-down feature.
        self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())

        self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())

        self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())

        # Expand the 1x1x256 correlation feature to a 15x15x32 map.
        self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)

        self.post0 = nn.Conv2d(32, 16, 3, padding=1)
        self.post1 = nn.Conv2d(16, 4, 3, padding=1)
        self.post2 = nn.Conv2d(4, 1, 3, padding=1)

        for modules in [self.v0, self.v1, self.v2, self.h2, self.h1, self.h0,
                        self.deconv, self.post0, self.post1, self.post2]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_uniform_(layer.weight, a=1)

    def forward(self, f, corr_feature, pos=None, test=False):
        if test:
            # Tracking: crop the receptive field around the single chosen
            # score-map position `pos` at each feature scale.
            p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4 * pos[0]:4 * pos[0] + 61, 4 * pos[1]:4 * pos[1] + 61]
            p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2 * pos[0]:2 * pos[0] + 31, 2 * pos[1]:2 * pos[1] + 31]
            p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0] + 15, pos[1]:pos[1] + 15]
        else:
            # Training: unfold every sliding window, then optionally keep
            # only the sampled positions in `pos`.
            p0 = F.unfold(f[0], (61, 61), padding=0, stride=4).permute(0, 2, 1).contiguous().view(-1, 64, 61, 61)
            if pos is not None:
                p0 = torch.index_select(p0, 0, pos)
            p1 = F.unfold(f[1], (31, 31), padding=0, stride=2).permute(0, 2, 1).contiguous().view(-1, 256, 31, 31)
            if pos is not None:
                p1 = torch.index_select(p1, 0, pos)
            p2 = F.unfold(f[2], (15, 15), padding=0, stride=1).permute(0, 2, 1).contiguous().view(-1, 512, 15, 15)
            if pos is not None:
                p2 = torch.index_select(p2, 0, pos)

        # Correlation feature at the selected position(s), as a 1x1 map.
        if pos is not None:
            p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1)
        else:
            p3 = corr_feature.permute(0, 2, 3, 1).contiguous().view(-1, 256, 1, 1)

        # Top-down fusion; F.interpolate replaces the deprecated F.upsample
        # (same default nearest-neighbour behaviour).
        out = self.deconv(p3)
        out = self.post0(F.interpolate(self.h2(out) + self.v2(p2), size=(31, 31)))
        out = self.post1(F.interpolate(self.h1(out) + self.v1(p1), size=(61, 61)))
        out = self.post2(F.interpolate(self.h0(out) + self.v0(p0), size=(127, 127)))
        out = out.view(-1, 127 * 127)
        return out

    def param_groups(self, start_lr, feature_mult=1):
        params = filter(lambda x: x.requires_grad, self.parameters())
        params = [{'params': params, 'lr': start_lr * feature_mult}]
        return params
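
# Hypothetical optimizer wiring (a sketch; the repo's training scripts build
# their own schedules), using only the param_groups helpers defined in this
# file:
#
#   model = Custom(anchors=...)
#   groups = (model.features.param_groups(0.005)
#             + model.refine_model.param_groups(0.005))
#   optimizer = torch.optim.SGD(groups, momentum=0.9)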


class Custom(SiamMask):
    """Full SiamMask model: shared backbone, RPN head, mask head, and the
    refinement module applied at test time."""

    def __init__(self, pretrain=False, **kwargs):
        super(Custom, self).__init__(**kwargs)
        self.features = ResDown(pretrain=pretrain)
        self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
        self.mask_model = MaskCorr()
        self.refine_model = Refine()

    def refine(self, f, pos=None):
        # Refine.forward takes (f, corr_feature, pos, test); pass the
        # correlation feature cached by track_mask() so the call matches
        # that signature.
        return self.refine_model(f, self.corr_feature, pos=pos)

    def template(self, template):
        # Embed the target exemplar once; reused by every later track call.
        self.zf = self.features(template)

    def track(self, search):
        search = self.features(search)
        rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search)
        return rpn_pred_cls, rpn_pred_loc

    def track_mask(self, search):
        # Cache backbone stages and the correlation feature for track_refine().
        self.feature, self.search = self.features.forward_all(search)
        rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, self.search)
        self.corr_feature = self.mask_model.mask.forward_corr(self.zf, self.search)
        pred_mask = self.mask_model.mask.head(self.corr_feature)
        return rpn_pred_cls, rpn_pred_loc, pred_mask

    def track_refine(self, pos):
        # Refine the mask at the chosen score-map position.
        pred_mask = self.refine_model(self.feature, self.corr_feature, pos=pos, test=True)
        return pred_mask
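

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original pipeline). It
    # assumes the SiamMask base class accepts an `anchors` config dict with
    # `ratios` and `scales` (as in the upstream repo) and the conventional
    # 127x127 template / 255x255 search crops; the anchor values below are
    # illustrative only.
    anchors = {'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8],
               'round_dight': 0}
    model = Custom(pretrain=False, anchors=anchors)
    model.eval()
    with torch.no_grad():
        model.template(torch.randn(1, 3, 127, 127))
        cls, loc, mask = model.track_mask(torch.randn(1, 3, 255, 255))
    print(cls.shape, loc.shape, mask.shape)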