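# PSAMask op from PSANet (point-wise spatial attention), wrapping the
# compiled `psamask_forward`/`psamask_backward` kernels in an autograd
# Function and an nn.Module front end.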
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

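# Load the compiled extension exposing the PSA mask CPU/CUDA kernels.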
ext_module = ext_loader.load_ext('_ext',
                                 ['psamask_forward', 'psamask_backward'])


class PSAMaskFunction(Function):

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
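        # Register the op for ONNX export under the custom mmcv domain.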
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input, psa_type, mask_size):
        ctx.psa_type = psa_type
        ctx.mask_size = _pair(mask_size)
        ctx.save_for_backward(input)

        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
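        # Each input channel corresponds to one offset inside the
        # h_mask x w_mask attention window centred on a pixel.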
        assert channels == h_mask * w_mask
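        # The kernel scatters the windowed attention into a dense map with
        # one output channel per spatial location of the feature map.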
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))

        ext_module.psamask_forward(
            input,
            output,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors[0]
        psa_type = ctx.psa_type
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        grad_input = grad_output.new_zeros(
            (batch_size, channels, h_feature, w_feature))
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
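        # Only the input tensor receives a gradient; psa_type and
        # mask_size are non-tensor arguments.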
        return grad_input, None, None, None


psa_mask = PSAMaskFunction.apply


class PSAMask(nn.Module):
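    """Point-wise spatial attention mask (PSANet).

    Takes an input of shape (N, h_mask * w_mask, H, W) and produces an
    attention map of shape (N, H * W, H, W), in either 'collect' or
    'distribute' mode.
    """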

    def __init__(self, psa_type, mask_size=None):
        super(PSAMask, self).__init__()
        assert psa_type in ['collect', 'distribute']
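        # The kernels encode the mode as an integer:
        # 0 = 'collect', 1 = 'distribute'.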
        if psa_type == 'collect':
            psa_type_enum = 0
        else:
            psa_type_enum = 1
        self.psa_type_enum = psa_type_enum
        self.mask_size = mask_size
        self.psa_type = psa_type

    def forward(self, input):
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(psa_type={self.psa_type}, '
        s += f'mask_size={self.mask_size})'
        return s
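
# Minimal usage sketch (shapes only; assumes the compiled extension is
# available and that `torch` is imported at the call site):
#     psa = PSAMask('collect', mask_size=(7, 7))
#     x = torch.randn(2, 7 * 7, 32, 32)  # (N, h_mask * w_mask, H, W)
#     out = psa(x)                       # (N, 32 * 32, 32, 32)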