|
from abc import ABCMeta, abstractmethod |
|
|
|
import torch.nn as nn |
|
|
|
|
|
class BaseDenseHead(nn.Module, metaclass=ABCMeta):
    """Base class for DenseHeads.

    Subclasses must implement :meth:`loss` and :meth:`get_bboxes`. They are
    also expected to provide ``forward`` (via ``nn.Module``), because
    :meth:`forward_train` invokes the head as ``self(x)``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def loss(self, **kwargs):
        """Compute losses of the head."""

    @abstractmethod
    def get_bboxes(self, **kwargs):
        """Transform network output for a batch into bbox predictions."""

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Run the head on ``x`` and compute its training losses.

        Args:
            x (list[Tensor]): Features from FPN.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor, optional): Ground truth labels of each box,
                shape (num_gts,). When None, the labels are left out of
                the positional arguments passed to :meth:`loss`.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4). Forwarded to
                :meth:`loss` as the ``gt_bboxes_ignore`` keyword.
            proposal_cfg (mmcv.Config, optional): Test / postprocessing
                configuration forwarded to :meth:`get_bboxes` as ``cfg``.
                When None, no proposals are computed and only the losses
                are returned.
            **kwargs: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            dict[str, Tensor] | tuple: When ``proposal_cfg`` is None, the
            value returned by :meth:`loss` (a dictionary of loss
            components). Otherwise a tuple ``(losses, proposal_list)``,
            where ``proposal_list`` is the value returned by
            :meth:`get_bboxes` (proposals of each image).
        """
        # The head's forward output must be a tuple: it is concatenated
        # with the ground-truth tuple below and star-unpacked afterwards.
        outs = self(x)

        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)

        if proposal_cfg is None:
            return losses

        proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
        return losses, proposal_list
|
|