import torch
import torch.nn as nn
import torch.nn.functional as F

from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version

@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform: If ``True``, replace convolution with deformable
            convolution. Default: ``False``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv that predicts the per-location switch between the
        # small- and large-dilation branches.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Learnable weight difference added to the large-dilation branch.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        # Global context convs applied before and after the convolution.
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # Offset predictors for the two deformable branches
            # (18 = 2 offsets x 3x3 kernel positions).
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        # Initialize the switch to output 1 everywhere, so training starts
        # from the small-dilation branch.
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def forward(self, x):
        # pre-context: add global average-pooled features to the input
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch: predict a per-location gate from locally averaged features
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac: small-dilation branch
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)
        if self.use_deform:
            offset = self.offset_s(avg_x)
            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            if (TORCH_VERSION == 'parrots'
                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
                out_s = super().conv2d_forward(x, weight)
            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
                # bias is a required argument of _conv_forward in torch 1.8.0
                out_s = super()._conv_forward(x, weight, zero_bias)
            else:
                out_s = super()._conv_forward(x, weight)
        # sac: large-dilation branch (3x padding and dilation, shifted weight)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        if self.use_deform:
            offset = self.offset_l(avg_x)
            out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            if (TORCH_VERSION == 'parrots'
                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
                out_l = super().conv2d_forward(x, weight)
            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
                # bias is a required argument of _conv_forward in torch 1.8.0
                out_l = super()._conv_forward(x, weight, zero_bias)
            else:
                out_l = super()._conv_forward(x, weight)
        # combine the two branches with the switch, then restore the original
        # padding and dilation
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context: add global average-pooled features to the output
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
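

# Minimal usage sketch, assuming the annotator.uniformer.mmcv package above is
# importable; the use_deform=True path additionally requires the compiled
# deform_conv2d op. Channel counts and input sizes below are illustrative only.
if __name__ == '__main__':
    # Plain SAC layer: a 3x3 conv whose dilation is switched per location
    # between 1 (out_s) and 3 (out_l).
    sac = SAConv2d(
        in_channels=16,
        out_channels=32,
        kernel_size=3,
        padding=1,
        use_deform=False)
    x = torch.randn(2, 16, 64, 64)
    out = sac(x)
    # Spatial size is preserved: padding=1/dilation=1 on the small branch and
    # padding=3/dilation=3 on the large branch both keep 64x64 for a 3x3 kernel.
    assert out.shape == (2, 32, 64, 64)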