import math

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])


class MaskedConv2dFunction(Function):

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride):
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)
    @staticmethod
    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Only stride 1 is supported in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        batch_size = features.size(0)
        # standard conv output size: (in + 2 * pad - kernel) / stride + 1
        out_h = int(
            math.floor((features.size(2) + 2 * pad_h -
                        (kernel_h - 1) - 1) / stride_h + 1))
        out_w = int(
            math.floor((features.size(3) + 2 * pad_w -
                        (kernel_w - 1) - 1) / stride_w + 1))
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            # gather im2col columns only at the masked spatial locations
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)

            # GEMM on the gathered columns: bias + weight @ data_col
            # (modern addmm signature; beta and alpha default to 1)
            masked_output = torch.addmm(bias[:, None],
                                        weight.view(out_channel, -1),
                                        data_col)
            # scatter the convolved columns back into the dense output map
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # The masked forward is used for inference only, so no gradients
        # are computed for any of its inputs.
        return (None, ) * 5


masked_conv2d = MaskedConv2dFunction.apply
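

# Reference sketch (illustrative only; `_masked_conv2d_reference` is not
# part of the original module): a plain-PyTorch re-implementation of the
# masked im2col -> GEMM -> col2im pipeline above, assuming stride 1 and
# 'same'-style padding (output size == input size, e.g. kernel_size=3 with
# padding=1). It convolves everywhere and then zeroes the unmasked
# positions, so it is slow but handy for checking the extension path.
def _masked_conv2d_reference(features, mask, weight, bias, padding=0):
    import torch.nn.functional as F
    out_channel = weight.size(0)
    # im2col over the whole map: (in_channel * kh * kw, H * W)
    data_col = F.unfold(features, weight.shape[2:], padding=padding)[0]
    # same GEMM as in forward(): bias + weight @ columns
    out = torch.addmm(bias[:, None], weight.view(out_channel, -1), data_col)
    out = out.view(1, out_channel, *features.shape[2:])
    # masked_col2im_forward only scatters to masked positions; emulate
    # that by zeroing every other location
    return out * (mask[:, None] > 0).to(out.dtype)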


class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    The masked forward does not implement the backward function and
    currently only supports stride 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
        # fall back to the ordinary Conv2d forward when no mask is given
        if mask is None:
            return super(MaskedConv2d, self).forward(input)
        else:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)
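

# Minimal usage sketch (illustrative, not part of the module): it assumes
# the '_ext' extension was built with CUDA support, since the masked
# im2col/col2im kernels are GPU ops. Shapes are arbitrary placeholders.
if __name__ == '__main__':
    if torch.cuda.is_available():
        conv = MaskedConv2d(3, 8, kernel_size=3, padding=1).cuda()
        x = torch.randn(1, 3, 16, 16, device='cuda')
        # binary mask of shape (1, H, W): only positions where the mask is
        # positive are convolved; every other output position stays zero
        mask = (torch.rand(1, 16, 16, device='cuda') > 0.5).float()
        out = conv(x, mask)  # masked forward path
        ref = conv(x)        # ordinary nn.Conv2d forward
        print(out.shape, ref.shape)  # both torch.Size([1, 8, 16, 16])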