# Last commit by jmercat (5769ee4):
# "Removed history to avoid any unverified information being released"
from dataclasses import dataclass
from typing import Optional
import torch
from mmcv import Config
@dataclass
class TrackingCostParams:
    """Configuration values for :class:`TrackingCost`.

    Attributes:
        scale_longitudinal: weight on the position error component along the
            target velocity direction.
        scale_lateral: weight on the position error component perpendicular
            to the target velocity direction.
        reduce: name of the reduction over time steps; TrackingCost accepts
            "min", "max", "mean", "now" or "final".
    """

    scale_longitudinal: float
    scale_lateral: float
    reduce: str

    @staticmethod
    def from_config(cfg: Config):
        """Create tracking cost parameters from a config object.

        Args:
            cfg: config exposing ``tracking_cost_scale_longitudinal``,
                ``tracking_cost_scale_lateral`` and ``tracking_cost_reduce``.

        Returns:
            a new TrackingCostParams holding those three values
        """
        longitudinal = cfg.tracking_cost_scale_longitudinal
        lateral = cfg.tracking_cost_scale_lateral
        reduce_name = cfg.tracking_cost_reduce
        return TrackingCostParams(longitudinal, lateral, reduce_name)
class TrackingCost:
    """Quadratic trajectory tracking cost.

    At every time step the position error ``e = ego - target`` is penalised
    with the quadratic form ``e^T Q e``. ``Q`` has its principal axes aligned
    with the target velocity direction. Error along the direction of motion is
    weighted by ``scale_longitudinal``; error perpendicular to it is weighted
    by ``scale_lateral``. The per-step costs are then reduced over time
    according to ``params.reduce``.

    Args:
        params: tracking cost parameters
    """

    _SUPPORTED_REDUCE = ("min", "max", "mean", "now", "final")

    def __init__(self, params: TrackingCostParams) -> None:
        self.scale_longitudinal = params.scale_longitudinal
        self.scale_lateral = params.scale_lateral
        # NOTE: this `assert` is skipped under `python -O`; `_reduce` therefore
        # also raises on an unknown name instead of silently returning None.
        assert params.reduce in self._SUPPORTED_REDUCE, "unsupported reduce type"
        self._reduce_fun_name = params.reduce

    def __call__(
        self,
        ego_position_trajectory: torch.Tensor,
        target_position_trajectory: torch.Tensor,
        target_velocity_trajectory: torch.Tensor,
    ) -> torch.Tensor:
        """Computes the quadratic tracking cost.

        Args:
            ego_position_trajectory: (..., num_steps, 2) tensor of ego
                position trajectory
            target_position_trajectory: (..., num_steps, 2) tensor of
                ego target position trajectory
            target_velocity_trajectory: (..., num_steps, 2) tensor of
                ego target velocity trajectory

        Returns:
            (...) cost, reduced over the time (second to last input) axis
        """
        cost_matrix = self._get_quadratic_cost_matrix(target_velocity_trajectory)
        error = ego_position_trajectory - target_position_trajectory
        # e^T Q e per time step via batched (1x2) @ (2x2) @ (2x1) products,
        # then drop the two singleton matrix dimensions.
        cost = (
            (error.unsqueeze(-2) @ cost_matrix @ error.unsqueeze(-1))
            .squeeze(-1)
            .squeeze(-1)
        )
        return self._reduce(cost, dim=-1)

    def _reduce(self, cost: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
        """Reduces the cost tensor based on self._reduce_fun_name.

        Args:
            cost: cost tensor where dimension ``dim`` (the last one if ``dim``
                is None and the reduction is "now"/"final") represents time.
            dim (optional): tensor dimension to be reduced. For "min", "max"
                and "mean", None reduces over all elements. Defaults to None.

        Returns:
            reduced cost tensor

        Raises:
            ValueError: if the configured reduce name is not supported (only
                reachable when the constructor's assert has been stripped).
        """
        name = self._reduce_fun_name
        if name == "min":
            return torch.min(cost, dim=dim)[0] if dim is not None else torch.min(cost)
        if name == "max":
            return torch.max(cost, dim=dim)[0] if dim is not None else torch.max(cost)
        if name == "mean":
            return torch.mean(cost, dim=dim) if dim is not None else torch.mean(cost)
        if name in ("now", "final"):
            index = 0 if name == "now" else -1
            # Take the first/last entry along `dim`; with dim=None keep
            # indexing the last axis (equivalent to cost[..., index]).
            return cost.select(-1 if dim is None else dim, index)
        raise ValueError(f"unsupported reduce type: {name!r}")

    def _get_quadratic_cost_matrix(
        self, target_velocity_trajectory: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Gets quadratic cost matrix based on target velocity direction per time step.

        Q = R diag(scale_longitudinal, scale_lateral) R^T, where the columns of
        R are the longitudinal (velocity) direction and its 90 degree rotation.
        If target velocity is 0 in norm, then an all zero matrix is returned
        for that time step.

        Args:
            target_velocity_trajectory: (..., num_steps, 2) tensor of
                ego target velocity trajectory
            eps (optional): small positive number to ensure numerical
                stability. Defaults to 1e-8.

        Returns:
            (..., num_steps, 2, 2) quadratic cost matrix, with the same dtype
            and device as the normalised velocity
        """
        # Normalised velocity as a column vector (..., 2, 1). `eps` keeps a
        # zero velocity finite: it maps to a zero vector, so Q is zero there.
        longitudinal_direction = (
            target_velocity_trajectory
            / (
                torch.linalg.norm(target_velocity_trajectory, dim=-1, keepdim=True)
                + eps
            )
        ).unsqueeze(-1)
        # Build the constant matrices on the same dtype/device as the data:
        # torch.Tensor(...) would always give float32 on CPU and break matmul
        # for float64 or GPU inputs. 2-D constants broadcast over any batch.
        dtype = longitudinal_direction.dtype
        device = longitudinal_direction.device
        rotation_90_deg = torch.tensor(
            [[0.0, -1.0], [1.0, 0.0]], dtype=dtype, device=device
        )
        lateral_direction = rotation_90_deg @ longitudinal_direction
        orthogonal_matrix = torch.cat(
            (longitudinal_direction, lateral_direction), dim=-1
        )
        eigen_matrix = torch.tensor(
            [[self.scale_longitudinal, 0.0], [0.0, self.scale_lateral]],
            dtype=dtype,
            device=device,
        )
        return orthogonal_matrix @ eigen_matrix @ orthogonal_matrix.transpose(-1, -2)