from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
from mmcv import Config

from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.planner_utils import AbstractState, to_state
from risk_biased.utils.risk import get_risk_estimator
@dataclass
class MPCPlannerParams:
    """Dataclass for MPC-Planner Parameters

    Args:
        dt: discrete time interval in seconds that is used for planning
        num_steps: number of time steps for which history of ego's and the other actor's
            trajectories are stored
        num_steps_future: number of time steps into the future for which ego's and the other
            actor's trajectories are considered
        acceleration_std_x_m_s2: Acceleration noise standard deviation (m/s^2) in x-direction
            that is used to initialize the Cross Entropy solver
        acceleration_std_y_m_s2: Acceleration noise standard deviation (m/s^2) in y-direction
            that is used to initialize the Cross Entropy solver
        risk_estimator_params: parameters for the Monte Carlo risk estimator used in the
            planner for ego's control optimization
        solver_params: parameters for the CrossEntropySolver
        tracking_cost_params: parameters for the TrackingCost
        ttc_cost_params: parameters for the TTCCost (i.e., collision cost between ego and the
            other actor)
    """

    dt: float
    num_steps: int
    num_steps_future: int
    acceleration_std_x_m_s2: float
    acceleration_std_y_m_s2: float
    risk_estimator_params: dict
    solver_params: CrossEntropySolverParams
    tracking_cost_params: TrackingCostParams
    ttc_cost_params: TTCCostParams

    @staticmethod
    def from_config(cfg: Config) -> "MPCPlannerParams":
        """Builds an MPCPlannerParams object from an mmcv Config.

        Args:
            cfg: configuration holding the planner fields (dt, num_steps, ...), plus the
                sub-configurations consumed by the nested parameter dataclasses.

        Returns:
            MPCPlannerParams instance populated from cfg
        """
        return MPCPlannerParams(
            cfg.dt,
            cfg.num_steps,
            cfg.num_steps_future,
            cfg.acceleration_std_x_m_s2,
            cfg.acceleration_std_y_m_s2,
            cfg.risk_estimator,
            CrossEntropySolverParams.from_config(cfg),
            TrackingCostParams.from_config(cfg),
            TTCCostParams.from_config(cfg),
        )
class MPCPlanner:
    """MPC Planner with a Cross Entropy solver

    Args:
        params: MPCPlannerParams object
        predictor: LitTrajectoryPredictor object
        normalizer: function that takes in an unnormalized trajectory and that outputs the
            normalized trajectory and the offset in this order
    """

    def __init__(
        self,
        params: MPCPlannerParams,
        predictor: LitTrajectoryPredictor,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        self.params = params
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        # Initial sampling distribution for the Cross Entropy solver: zero-mean control
        # inputs with the configured per-axis acceleration standard deviations.
        self.control_input_mean_init = torch.zeros(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_init = torch.Tensor(
            [
                params.acceleration_std_x_m_s2,
                params.acceleration_std_y_m_s2,
            ]
        ).expand_as(self.control_input_mean_init)
        self.solver = CrossEntropySolver(
            params=params.solver_params,
            dynamics_model=self.dynamics_model,
            control_input_mean=self.control_input_mean_init,
            control_input_std=self.control_input_std_init,
            tracking_cost_function=TrackingCost(params.tracking_cost_params),
            interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator_params),
        )
        self.predictor = predictor
        self.normalizer = normalizer
        # Rolling buffers of past states (bounded to params.num_steps entries) and the
        # latest planning products; populated by replan(...).
        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def replan(
        self,
        current_ado_state: AbstractState,
        current_ego_state: AbstractState,
        target_velocity: torch.Tensor,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> None:
        """Performs re-planning given the current_ado_state, current_ego_state, and
        target_velocity. Updates ego_state_planned_trajectory. Note that all the information
        given to the solver.solve(...) is expressed in the ego-centric frame, whose origin is
        the initial ego position in ego_state_history and the x-direction is parallel to the
        initial ego velocity.

        Args:
            current_ado_state: ado state
            current_ego_state: ego state
            target_velocity: ((1), 2) tensor
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the entire prediction-planning
                pipeline. If 0.0, risk-neutral prediction and planning are used.
                Defaults to 0.0.
            resample_prediction (optional): If True, prediction is re-sampled in each
                cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the
                solver becomes risk-neutral. If False, risk-neutral prediction is used and
                the solver becomes risk-sensitive. Defaults to False.
        """
        self._update_ado_state_history(current_ado_state)
        self._update_ego_state_history(current_ego_state)
        self._update_ego_state_target_trajectory(current_ego_state, target_velocity)
        # Only plan once enough history has been accumulated (shape[-1] is assumed to be
        # the time dimension of the stacked history state — TODO confirm in AbstractState).
        if not self.ado_state_history.shape[-1] < self.params.num_steps:
            self.solver.solve(
                self.predictor,
                self._map_to_ego_centric_frame(self.ego_state_history),
                self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
                self._map_to_ego_centric_frame(self.ado_state_history),
                self.normalizer,
                num_prediction_samples=num_prediction_samples,
                risk_level=risk_level,
                resample_prediction=resample_prediction,
                risk_in_predictor=risk_in_predictor,
            )
            # Roll out the optimized control sequence from the latest ego state to get
            # the planned trajectory, then map it back to the world frame.
            ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
                self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
                self.solver.control_sequence,
            )
            self._ego_state_planned_trajectory = self._map_to_world_frame(
                ego_state_planned_trajectory_in_ego_frame
            )
            latest_ado_position_future_samples_in_ego_frame = (
                self.solver.fetch_latest_prediction()
            )
            if latest_ado_position_future_samples_in_ego_frame is not None:
                self._latest_ado_position_future_samples = self._map_to_world_frame(
                    latest_ado_position_future_samples_in_ego_frame
                )
            else:
                self._latest_ado_position_future_samples = None

    def get_planned_next_ego_state(self) -> AbstractState:
        """Returns the next ego state according to the ego_state_planned_trajectory

        Returns:
            Planned state
        """
        assert (
            self._ego_state_planned_trajectory is not None
        ), "call self.replan(...) first"
        return self._ego_state_planned_trajectory[..., 0]

    def reset(self) -> None:
        """Resets the planner's internal state. This will fully reset the solver's internal
        state, including solver.control_input_mean_init and solver.control_input_std_init."""
        # Detached clones so the solver cannot mutate the planner's pristine initial values.
        self.solver.control_input_mean_init = (
            self.control_input_mean_init.detach().clone()
        )
        self.solver.control_input_std_init = (
            self.control_input_std_init.detach().clone()
        )
        self.solver.reset()
        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def fetch_latest_prediction(self) -> Optional[torch.Tensor]:
        """Returns the latest ado position future samples in the world frame, or None if
        replan(...) has not produced a prediction yet."""
        return self._latest_ado_position_future_samples

    @property
    def ego_state_history(self) -> AbstractState:
        """Returns ego_state_history as a concatenated state trajectory

        Returns:
            ego_state_history state trajectory
        """
        assert len(self._ego_state_history) > 0
        return to_state(
            torch.stack(
                [ego_state.get_states(4) for ego_state in self._ego_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    @property
    def ado_state_history(self) -> AbstractState:
        """Returns ado_state_history as a concatenated state trajectory

        Returns:
            ado_state_history state trajectory
        """
        assert len(self._ado_state_history) > 0
        return to_state(
            torch.stack(
                [ado_state.get_states(4) for ado_state in self._ado_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
        """Updates ego_state_history with the current_ego_state, keeping at most
        params.num_steps entries (oldest entry is dropped first).

        Args:
            current_ego_state: ego state to append
        """
        if len(self._ego_state_history) >= self.params.num_steps:
            self._ego_state_history = self._ego_state_history[1:]
        self._ego_state_history.append(current_ego_state)
        assert len(self._ego_state_history) <= self.params.num_steps

    def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
        """Updates ado_state_history with the current_ado_state, keeping at most
        params.num_steps entries (oldest entry is dropped first).

        Args:
            current_ado_state: states of the current non-ego vehicles
        """
        if len(self._ado_state_history) >= self.params.num_steps:
            self._ado_state_history = self._ado_state_history[1:]
        self._ado_state_history.append(current_ado_state)
        assert len(self._ado_state_history) <= self.params.num_steps

    def _update_ego_state_target_trajectory(
        self, current_ego_state: AbstractState, target_velocity: torch.Tensor
    ) -> None:
        """Updates ego_state_target_trajectory based on the current_ego_state and the
        target_velocity by integrating the constant target_velocity forward for
        num_steps_future steps.

        Args:
            current_ego_state: state
            target_velocity: (1, 2) tensor
        """
        target_displacement = self.params.dt * target_velocity
        # Build positions starting from the current position, then drop it so the target
        # trajectory contains only future steps.
        target_position_list = [current_ego_state.position]
        for _ in range(self.params.num_steps_future):
            target_position_list.append(target_position_list[-1] + target_displacement)
        target_position_list = target_position_list[1:]
        target_position = torch.cat(target_position_list, dim=-2)
        target_state = to_state(
            torch.cat(
                (target_position, target_velocity.expand_as(target_position)), dim=-1
            ),
            self.params.dt,
        )
        self._ego_state_target_trajectory = target_state

    def _map_to_ego_centric_frame(
        self, trajectory_in_world_frame: AbstractState
    ) -> AbstractState:
        """Maps trajectory expressed in the world frame to the ego-centric frame, whose
        origin is the initial ego position in ego_state_history and the x-direction is
        parallel to the initial ego velocity

        Args:
            trajectory_in_world_frame: sequence of states

        Returns:
            trajectory mapped to the ego-centric frame
        """
        # NOTE(review): the reference pose is taken from index -1 (the most recent entry of
        # the history), while the docstring says "initial" — confirm intended semantics.
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_vel_init = self.ego_state_history.velocity[..., -1, :]
        # Heading derived from the velocity direction of the reference ego state.
        ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
        trajectory_in_ego_frame = trajectory_in_world_frame.translate(
            -ego_pos_init
        ).rotate(-ego_rot_init)
        return trajectory_in_ego_frame

    def _map_to_world_frame(
        self, trajectory_in_ego_frame: AbstractState
    ) -> AbstractState:
        """Maps trajectory expressed in the ego-centric frame to the world frame

        Args:
            trajectory_in_ego_frame: trajectory expressed in the ego-centric frame, whose
                origin is the initial ego position in ego_state_history and the x-direction
                is parallel to the initial ego velocity

        Returns:
            trajectory mapped to the world frame
        """
        # NOTE(review): here the rotation comes from the stored state angle, whereas
        # _map_to_ego_centric_frame derives it from velocity — confirm the two agree.
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_rot_init = self.ego_state_history.angle[..., -1, :]
        trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
            ego_rot_init
        ).translate(ego_pos_init)
        return trajectory_in_world_frame