File size: 1,051 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Union

import torch
from torch import Tensor

from mmdet3d.registry import TASK_UTILS


@TASK_UTILS.register_module()
class BBox3DL1Cost:
    """Pairwise L1 matching cost between predicted and ground-truth boxes.

    The cost for every (prediction, ground truth) pair is the sum of absolute
    differences over the box encoding, multiplied by ``weight``.

    Args:
        weight (Union[float, int]): Scalar factor applied to the cost matrix.
            Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.):
        # Stored as-is; applied on every call.
        self.weight = weight

    def __call__(self, bbox_pred: Tensor, gt_bboxes: Tensor) -> Tensor:
        """Build the weighted L1 cost matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes of shape [num_query, D]. The
                encoding is expected to be normalized, e.g.
                (cx, cy, l, w, cz, h, sin(φ), cos(φ), v_x, v_y) with D = 10.
            gt_bboxes (Tensor): Ground-truth boxes of shape [num_gt, D], in
                the same normalized encoding as ``bbox_pred``.

        Returns:
            Tensor: Cost matrix of shape [num_query, num_gt] where entry
            (i, j) is ``weight * sum(|bbox_pred[i] - gt_bboxes[j]|)``.
        """
        # p=1 makes cdist compute the Manhattan (L1) distance per row pair.
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        weighted_cost = self.weight * pairwise_l1
        return weighted_cost