Spaces:
Sleeping
Sleeping
File size: 2,410 Bytes
d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import logging

import torch
import torch.nn as nn
# Module-wide logger. logging.getLogger returns the same object for the same
# name, so every module that asks for 'global' shares this logger.
logger = logging.getLogger('global')
class Features(nn.Module):
    """Base class for backbone feature extractors.

    Subclasses must implement ``forward``. ``feature_size`` starts at -1;
    presumably concrete subclasses overwrite it with their output size
    (TODO confirm against the subclasses).
    """

    def __init__(self):
        super(Features, self).__init__()
        self.feature_size = -1

    def forward(self, x):
        """Compute features for ``x``; must be overridden by subclasses."""
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1):
        """Build optimizer parameter groups for this module.

        Args:
            start_lr: base learning rate.
            feature_mult: multiplier applied to ``start_lr`` for this module.

        Returns:
            A one-element list ``[{'params': [...], 'lr': start_lr * feature_mult}]``
            containing only parameters whose ``requires_grad`` is set.
        """
        # Materialise a list: a bare ``filter`` object is a one-shot iterator
        # and would look empty if the group were iterated a second time.
        params = [p for p in self.parameters() if p.requires_grad]
        return [{'params': params, 'lr': start_lr * feature_mult}]

    def load_model(self, f='pretrain.model'):
        """Load matching weights from the checkpoint file at path ``f``.

        Checkpoint keys absent from ``self.state_dict()`` are discarded; keys
        missing from the checkpoint keep their current values. The checkpoint
        key set is printed before and after filtering.
        """
        # Checkpoints are binary, and torch.load fails on a text-mode handle.
        # A distinct handle name keeps the path argument from being shadowed.
        with open(f, 'rb') as fh:
            pretrained_dict = torch.load(fh)
        model_dict = self.state_dict()
        print(pretrained_dict.keys())
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        print(pretrained_dict.keys())
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
class MultiStageFeature(Features):
    """Feature extractor whose leading layers are unfrozen in stages.

    Attributes (initialised empty here; presumably filled by subclasses --
    TODO confirm):
        layers: sub-modules; ``layers[:train_num]`` are the trainable ones.
        change_point: progress thresholds compared against ``ratio`` in ``unfix``.
        train_nums: layer count to train once the paired threshold is reached.
        train_num: current number of trainable layers; -1 until ``unfix`` runs.
    """

    def __init__(self):
        super(MultiStageFeature, self).__init__()
        self.layers = []
        self.train_num = -1
        self.change_point = []
        self.train_nums = []

    def unfix(self, ratio=0.0):
        """Update which layers are trainable for training progress ``ratio``.

        On the first call, every parameter is frozen and the module is put in
        eval mode. The (change_point, train_nums) pairs are then scanned from
        last to first. The first pair whose threshold ``ratio`` has reached
        decides the layer count (presumably change_point is ascending, so this
        is the highest threshold reached -- TODO confirm).

        Returns:
            True if ``train_num`` changed (and parameters were re-unlocked),
            otherwise False.
        """
        if self.train_num == -1:
            self.train_num = 0
            self.unlock()
            self.eval()
        for p, t in reversed(list(zip(self.change_point, self.train_nums))):
            if ratio >= p:
                if self.train_num != t:
                    self.train_num = t
                    self.unlock()
                    return True
                break
        return False

    def train_layers(self):
        """Return the first ``train_num`` entries of ``layers``."""
        return self.layers[:self.train_num]

    def unlock(self):
        """Freeze every parameter, then unfreeze those of ``train_layers()``."""
        for p in self.parameters():
            p.requires_grad = False
        # Both the count and the layer list get a placeholder so the list is logged.
        logger.info('Current training {} layers:\n\t{}'.format(self.train_num, self.train_layers()))
        for m in self.train_layers():
            for p in m.parameters():
                p.requires_grad = True

    def train(self, mode=True):
        """Set training mode; the default matches ``nn.Module.train(mode=True)``.

        ``mode`` false: this module and all sub-modules go to eval mode.
        ``mode`` true: only this module's flag and the current
        ``train_layers()`` switch to training; other sub-modules keep
        whatever mode they already had.

        Returns:
            self
        """
        self.training = mode
        if not mode:
            super(MultiStageFeature, self).train(False)
        else:
            for m in self.train_layers():
                m.train(True)
        return self
|