|
import torch |
|
|
|
from mmdet.core import bbox2result, bbox2roi |
|
from ..builder import HEADS, build_head, build_roi_extractor |
|
from .standard_roi_head import StandardRoIHead |
|
|
|
|
|
@HEADS.register_module()
class GridRoIHead(StandardRoIHead):
    """Grid roi head for Grid R-CNN.

    https://arxiv.org/abs/1811.12030

    On top of the parent ``StandardRoIHead`` this class owns a grid head
    (``self.grid_head``) and the RoI extractor feeding it
    (``self.grid_roi_extractor``).

    Args:
        grid_roi_extractor: Config handed to ``build_roi_extractor``. When
            ``None`` the parent's ``bbox_roi_extractor`` is reused and
            ``share_roi_extractor`` is set to True.
        grid_head: Config handed to ``build_head``. Must not be ``None``.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, grid_roi_extractor, grid_head, **kwargs):
        assert grid_head is not None
        super(GridRoIHead, self).__init__(**kwargs)
        if grid_roi_extractor is not None:
            self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
            self.share_roi_extractor = False
        else:
            # No dedicated extractor configured: share the bbox one.
            self.share_roi_extractor = True
            self.grid_roi_extractor = self.bbox_roi_extractor
        self.grid_head = build_head(grid_head)

    def init_weights(self, pretrained):
        """Initialize the weights in head.

        Args:
            pretrained: Value forwarded to the parent ``init_weights``
                (e.g. a path to pre-trained weights). It is a required
                argument here; there is no default.
        """
        super(GridRoIHead, self).init_weights(pretrained)
        self.grid_head.init_weights()
        # A shared extractor is the parent's bbox extractor, which the
        # parent call above is responsible for.
        if not self.share_roi_extractor:
            self.grid_roi_extractor.init_weights()

    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
        """Random jitter positive proposals for training.

        For every image the positive boxes are read as
        ``(x1, y1, x2, y2)``; their centre is shifted and their size is
        scaled by independent uniform offsets in
        ``[-amplitude, amplitude]`` (relative to the box width/height),
        and the result is clipped to the image.

        Args:
            sampling_results: Per-image sampling results; each must expose
                a ``pos_bboxes`` tensor with at least 4 columns.
            img_metas: Per-image meta dicts; ``img_meta['img_shape']`` is
                read as ``(height, width, ...)`` or ``None`` (no clipping).
            amplitude (float): Half-width of the uniform offset range.

        Returns:
            The same ``sampling_results`` objects, whose ``pos_bboxes``
            attributes have been replaced in place by the jittered boxes.
        """
        for sampling_result, img_meta in zip(sampling_results, img_metas):
            bboxes = sampling_result.pos_bboxes
            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
                -amplitude, amplitude)

            # Box centre and size before jittering.
            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()

            # Shift the centre and rescale the size.
            new_cxcy = cxcy + wh * random_offsets[:, :2]
            new_wh = wh * (1 + random_offsets[:, 2:])

            # Centre/size back to corner format.
            new_x1y1 = new_cxcy - new_wh / 2
            new_x2y2 = new_cxcy + new_wh / 2
            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)

            # Clip to the image: even columns are x, odd columns are y.
            max_shape = img_meta['img_shape']
            if max_shape is not None:
                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)

            sampling_result.pos_bboxes = new_bboxes
        return sampling_results

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Runs the bbox head (if present), the grid head and the mask head
        (if present) and concatenates their raw outputs into one tuple.

        Args:
            x: Multi-level features; sliced with
                ``x[:self.grid_roi_extractor.num_inputs]`` for the grid
                branch.
            proposals: Proposals of a single image, wrapped into a
                one-element list for ``bbox2roi``.

        Returns:
            tuple: ``cls_score`` and ``bbox_pred`` (only with a bbox
            head), then the grid head output, then ``mask_pred`` (only
            with a mask head).
        """
        # bbox head
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])

        # grid head: only the first 100 rois are used
        grid_rois = rois[:100]
        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], grid_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        grid_pred = self.grid_head(grid_feats)
        outs = outs + (grid_pred, )

        # mask head: only the first 100 rois are used
        if self.with_mask:
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            outs = outs + (mask_results['mask_pred'], )
        return outs

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        The parent's bbox branch runs first on the unmodified sampling
        results. Afterwards the positive boxes are jittered (this
        replaces ``pos_bboxes`` on the given ``sampling_results``
        objects) and the grid loss is merged into
        ``bbox_results['loss_bbox']``.

        Returns:
            The parent's ``bbox_results`` dict; its ``'loss_bbox'`` entry
            additionally holds the grid loss unless there were no
            positive rois.
        """
        bbox_results = super(GridRoIHead,
                             self)._bbox_forward_train(x, sampling_results,
                                                       gt_bboxes, gt_labels,
                                                       img_metas)

        # Grid head forward and loss.
        sampling_results = self._random_jitter(sampling_results, img_metas)
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        # Without positive rois the grid branch is skipped entirely
        # (presumably because the grid head cannot handle an empty
        # batch -- TODO confirm).
        if pos_rois.shape[0] == 0:
            return bbox_results

        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], pos_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)

        # Train the grid head on a random subset of at most
        # ``max_num_grid`` rois (192 when the key is absent from train_cfg).
        max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
        num_grid_rois = grid_feats.shape[0]
        sample_idx = torch.randperm(
            num_grid_rois)[:min(num_grid_rois, max_sample_num_grid)]
        grid_feats = grid_feats[sample_idx]

        grid_pred = self.grid_head(grid_feats)

        # Targets are built for all positives, then reduced to the same
        # subset as the features.
        grid_targets = self.grid_head.get_targets(sampling_results,
                                                  self.train_cfg)
        grid_targets = grid_targets[sample_idx]

        loss_grid = self.grid_head.loss(grid_pred, grid_targets)

        bbox_results['loss_bbox'].update(loss_grid)
        return bbox_results

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        Args:
            x: Multi-level features.
            proposal_list: Per-image proposals passed to
                ``simple_test_bboxes``.
            img_metas: Per-image meta dicts; ``'scale_factor'`` is read
                when ``rescale`` is True.
            proposals: Unused; kept for interface compatibility.
            rescale (bool): If True, refined boxes are divided by the
                image's ``scale_factor``.

        Returns:
            list: One entry per image. Without a mask head each entry is
            the output of ``bbox2result`` (also for images without any
            detection); with a mask head each entry is a
            ``(bbox_result, segm_result)`` tuple.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        # Boxes are deliberately kept in the network input scale here;
        # rescaling happens after the grid refinement below.
        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=False)

        # Pack the detections of all images into one roi batch.
        grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
        grid_pred = None
        if grid_rois.shape[0] != 0:
            grid_feats = self.grid_roi_extractor(
                x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
            # Mirror the training path, which feeds the grid head with
            # shared-head features when a shared head is configured.
            if self.with_shared_head:
                grid_feats = self.shared_head(grid_feats)
            self.grid_head.test_mode = True
            grid_pred = self.grid_head(grid_feats)

            # Split the batched predictions back to each image.
            num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
            grid_pred = {
                k: v.split(num_roi_per_img, 0)
                for k, v in grid_pred.items()
            }

        # Post-process each image individually. Images without detections
        # skip the refinement but still go through ``bbox2result`` so that
        # every entry of the returned list has the same per-class format.
        num_classes = self.bbox_head.num_classes
        bbox_results = []
        for i, (det_bbox, det_label) in enumerate(zip(det_bboxes,
                                                      det_labels)):
            if grid_pred is not None and det_bbox.shape[0] != 0:
                det_bbox = self.grid_head.get_bboxes(
                    det_bbox, grid_pred['fused'][i], [img_metas[i]])
                if rescale:
                    # Convert explicitly so the divisor shares dtype and
                    # device with the boxes.
                    det_bbox[:, :4] /= det_bbox.new_tensor(
                        img_metas[i]['scale_factor'])
            bbox_results.append(
                bbox2result(det_bbox, det_label, num_classes))

        if not self.with_mask:
            return bbox_results

        segm_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        return list(zip(bbox_results, segm_results))
|
|