import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, constant_init, xavier_init
from mmcv.runner import auto_fp16

from mmdet.models.builder import HEADS
from .fcn_mask_head import FCNMaskHead


@HEADS.register_module()
class CoarseMaskHead(FCNMaskHead):
    """Coarse mask head used in PointRend.

    Compared with the standard ``FCNMaskHead``, ``CoarseMaskHead`` downsamples
    the input feature map instead of upsampling it.

    Args:
        num_convs (int): Number of conv layers in the head. Default: 0.
        num_fcs (int): Number of fc layers in the head. Default: 2.
        fc_out_channels (int): Number of output channels of the fc layers.
            Default: 1024.
        downsample_factor (int): The factor by which the feature map is
            downsampled. Default: 2.
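
    Example:
        A minimal sketch; the shapes assume the ``FCNMaskHead`` defaults
        (``in_channels=256``, ``roi_feat_size=14``, ``num_classes=80``), so a
        14x14 RoI feature yields a 7x7 coarse mask prediction.

        >>> import torch
        >>> mask_head = CoarseMaskHead(num_convs=0, num_fcs=2)
        >>> feats = torch.rand(1, 256, 14, 14)
        >>> mask_pred = mask_head(feats)
        >>> assert mask_pred.shape == (1, 80, 7, 7)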
    """

    def __init__(self,
                 num_convs=0,
                 num_fcs=2,
                 fc_out_channels=1024,
                 downsample_factor=2,
                 *arg,
                 **kwarg):
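        # Reuse the ``FCNMaskHead`` conv stack but disable its upsample layer
        # (``upsample_cfg=dict(type=None)``); this head shrinks the RoI
        # feature instead of enlarging it.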
        super(CoarseMaskHead, self).__init__(
            *arg, num_convs=num_convs, upsample_cfg=dict(type=None), **kwarg)
        self.num_fcs = num_fcs
        assert self.num_fcs > 0
        self.fc_out_channels = fc_out_channels
        self.downsample_factor = downsample_factor
        assert self.downsample_factor >= 1

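        # The parent class builds a 1x1 ``conv_logits`` layer; the coarse head
        # predicts logits with ``fc_logits`` instead, so drop it.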
        delattr(self, 'conv_logits')

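        # A strided conv with kernel_size == stride == downsample_factor
        # reduces the spatial size by exactly that factor without overlap.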
        if downsample_factor > 1:
            downsample_in_channels = (
                self.conv_out_channels
                if self.num_convs > 0 else self.in_channels)
            self.downsample_conv = ConvModule(
                downsample_in_channels,
                self.conv_out_channels,
                kernel_size=downsample_factor,
                stride=downsample_factor,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        else:
            self.downsample_conv = None

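        # Spatial size of the coarse prediction: the RoI feature size divided
        # by the downsample factor (e.g. 14x14 features with factor 2 give a
        # 7x7 mask grid).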
        self.output_size = (self.roi_feat_size[0] // downsample_factor,
                            self.roi_feat_size[1] // downsample_factor)
        self.output_area = self.output_size[0] * self.output_size[1]

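        # Input dimension of the first fc: the downsampled feature map
        # flattened to a vector (channels * height * width).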
        last_layer_dim = self.conv_out_channels * self.output_area

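        # Fully connected stack: the first fc maps the flattened features to
        # ``fc_out_channels``, the remaining fcs keep that width.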
        self.fcs = nn.ModuleList()
        for i in range(num_fcs):
            fc_in_channels = (
                last_layer_dim if i == 0 else self.fc_out_channels)
            self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
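        # One logit per class per coarse-mask pixel, all predicted by a single
        # fc layer.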
        output_channels = self.num_classes * self.output_area
        self.fc_logits = Linear(last_layer_dim, output_channels)

    def init_weights(self):
        for m in self.fcs.modules():
            if isinstance(m, nn.Linear):
                xavier_init(m)
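        # ``constant_init`` sets the ``fc_logits`` weight to a small constant
        # (0.001) and its bias to 0.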
        constant_init(self.fc_logits, 0.001)

    @auto_fp16()
    def forward(self, x):
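        # Optional conv stack inherited from ``FCNMaskHead`` (empty when
        # ``num_convs == 0``).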
        for conv in self.convs:
            x = conv(x)

        if self.downsample_conv is not None:
            x = self.downsample_conv(x)

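        # Collapse the spatial dimensions and run the fc stack with ReLU
        # activations.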
        x = x.flatten(1)
        for fc in self.fcs:
            x = self.relu(fc(x))
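        # Reshape the flat logits back to a (N, num_classes, H, W) coarse
        # mask grid.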
        mask_pred = self.fc_logits(x).view(
            x.size(0), self.num_classes, *self.output_size)
        return mask_pred