|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule, kaiming_init |
|
from mmcv.runner import auto_fp16, force_fp32 |
|
|
|
from mmdet.models.builder import HEADS |
|
|
|
|
|
@HEADS.register_module() |
|
class FusedSemanticHead(nn.Module): |
|
r"""Multi-level fused semantic segmentation head. |
|
|
|
.. code-block:: none |
|
|
|
in_1 -> 1x1 conv --- |
|
| |
|
in_2 -> 1x1 conv -- | |
|
|| |
|
in_3 -> 1x1 conv - || |
|
||| /-> 1x1 conv (mask prediction) |
|
in_4 -> 1x1 conv -----> 3x3 convs (*4) |
|
| \-> 1x1 conv (feature) |
|
in_5 -> 1x1 conv --- |
|
""" |
|
|
|
def __init__(self, |
|
num_ins, |
|
fusion_level, |
|
num_convs=4, |
|
in_channels=256, |
|
conv_out_channels=256, |
|
num_classes=183, |
|
ignore_label=255, |
|
loss_weight=0.2, |
|
conv_cfg=None, |
|
norm_cfg=None): |
|
super(FusedSemanticHead, self).__init__() |
|
self.num_ins = num_ins |
|
self.fusion_level = fusion_level |
|
self.num_convs = num_convs |
|
self.in_channels = in_channels |
|
self.conv_out_channels = conv_out_channels |
|
self.num_classes = num_classes |
|
self.ignore_label = ignore_label |
|
self.loss_weight = loss_weight |
|
self.conv_cfg = conv_cfg |
|
self.norm_cfg = norm_cfg |
|
self.fp16_enabled = False |
|
|
|
self.lateral_convs = nn.ModuleList() |
|
for i in range(self.num_ins): |
|
self.lateral_convs.append( |
|
ConvModule( |
|
self.in_channels, |
|
self.in_channels, |
|
1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg, |
|
inplace=False)) |
|
|
|
self.convs = nn.ModuleList() |
|
for i in range(self.num_convs): |
|
in_channels = self.in_channels if i == 0 else conv_out_channels |
|
self.convs.append( |
|
ConvModule( |
|
in_channels, |
|
conv_out_channels, |
|
3, |
|
padding=1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg)) |
|
self.conv_embedding = ConvModule( |
|
conv_out_channels, |
|
conv_out_channels, |
|
1, |
|
conv_cfg=self.conv_cfg, |
|
norm_cfg=self.norm_cfg) |
|
self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1) |
|
|
|
self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label) |
|
|
|
def init_weights(self): |
|
kaiming_init(self.conv_logits) |
|
|
|
@auto_fp16() |
|
def forward(self, feats): |
|
x = self.lateral_convs[self.fusion_level](feats[self.fusion_level]) |
|
fused_size = tuple(x.shape[-2:]) |
|
for i, feat in enumerate(feats): |
|
if i != self.fusion_level: |
|
feat = F.interpolate( |
|
feat, size=fused_size, mode='bilinear', align_corners=True) |
|
x += self.lateral_convs[i](feat) |
|
|
|
for i in range(self.num_convs): |
|
x = self.convs[i](x) |
|
|
|
mask_pred = self.conv_logits(x) |
|
x = self.conv_embedding(x) |
|
return mask_pred, x |
|
|
|
@force_fp32(apply_to=('mask_pred', )) |
|
def loss(self, mask_pred, labels): |
|
labels = labels.squeeze(1).long() |
|
loss_semantic_seg = self.criterion(mask_pred, labels) |
|
loss_semantic_seg *= self.loss_weight |
|
return loss_semantic_seg |
|
|