|
import torch |
|
from mmdet.models.task_modules import BaseBBoxCoder |
|
|
|
from mmdet3d.registry import TASK_UTILS |
|
from .util import denormalize_bbox |
|
|
|
|
|
@TASK_UTILS.register_module() |
|
class NMSFreeCoder(BaseBBoxCoder): |
|
"""Bbox coder for NMS-free detector. |
|
|
|
Args: |
|
pc_range (list[float]): Range of point cloud. |
|
post_center_range (list[float]): Limit of the center. |
|
Default: None. |
|
max_num (int): Max number to be kept. Default: 100. |
|
score_threshold (float): Threshold to filter boxes based on score. |
|
Default: None. |
|
""" |
|
|
|
def __init__(self, |
|
pc_range=None, |
|
voxel_size=None, |
|
post_center_range=None, |
|
max_num=100, |
|
score_threshold=None, |
|
num_classes=10): |
|
|
|
self.pc_range = pc_range |
|
self.voxel_size = voxel_size |
|
self.post_center_range = post_center_range |
|
self.max_num = max_num |
|
self.score_threshold = score_threshold |
|
self.num_classes = num_classes |
|
|
|
def encode(self): |
|
pass |
|
|
|
def decode_single(self, cls_scores, bbox_preds): |
|
"""Decode bboxes. |
|
|
|
Args: |
|
cls_scores (Tensor): Outputs from the classification head, |
|
shape [num_query, cls_out_channels]. Note that |
|
cls_out_channels should includes background. |
|
bbox_preds (Tensor): Outputs from the regression |
|
head with normalized coordinate |
|
(cx, cy, l, w, cz, h, rot_sine, rot_cosine, vx, vy). |
|
Shape [num_query, 10]. |
|
Returns: |
|
list[dict]: Decoded boxes. |
|
""" |
|
max_num = self.max_num |
|
|
|
cls_scores = cls_scores.sigmoid() |
|
scores, indexes = cls_scores.view(-1).topk(max_num) |
|
labels = indexes % self.num_classes |
|
bbox_index = indexes // self.num_classes |
|
bbox_preds = bbox_preds[bbox_index] |
|
|
|
|
|
final_box_preds = denormalize_bbox(bbox_preds, None) |
|
final_scores = scores |
|
final_preds = labels |
|
|
|
|
|
if self.score_threshold is not None: |
|
thresh_mask = final_scores > self.score_threshold |
|
if self.post_center_range is not None: |
|
self.post_center_range = torch.tensor( |
|
self.post_center_range, device=scores.device) |
|
mask = (final_box_preds[..., :3] >= |
|
self.post_center_range[:3]).all(1) |
|
mask &= (final_box_preds[..., :3] <= |
|
self.post_center_range[3:]).all(1) |
|
|
|
if self.score_threshold: |
|
mask &= thresh_mask |
|
|
|
boxes3d = final_box_preds[mask] |
|
scores = final_scores[mask] |
|
labels = final_preds[mask] |
|
predictions_dict = { |
|
'bboxes': boxes3d, |
|
'scores': scores, |
|
'labels': labels |
|
} |
|
|
|
else: |
|
raise NotImplementedError( |
|
'Need to reorganize output as a batch, only ' |
|
'support post_center_range is not None for now!') |
|
return predictions_dict |
|
|
|
def decode(self, preds_dicts): |
|
"""Decode bboxes. |
|
|
|
Args: |
|
all_cls_scores (Tensor): Outputs from the classification head, |
|
shape [nb_dec, bs, num_query, cls_out_channels]. Note |
|
cls_out_channels should includes background. |
|
all_bbox_preds (Tensor): Sigmoid outputs from the regression |
|
head with normalized coordinate format |
|
(cx, cy, l, w, cz, h, rot_sine, rot_cosine, vx, vy). |
|
Shape [nb_dec, bs, num_query, 10]. |
|
Returns: |
|
list[dict]: Decoded boxes. |
|
""" |
|
|
|
all_cls_scores = preds_dicts['all_cls_scores'][-1] |
|
all_bbox_preds = preds_dicts['all_bbox_preds'][-1] |
|
|
|
batch_size = all_cls_scores.size()[0] |
|
predictions_list = [] |
|
for i in range(batch_size): |
|
predictions_list.append( |
|
self.decode_single(all_cls_scores[i], all_bbox_preds[i])) |
|
return predictions_list |
|
|