from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
from mmcv import Config

from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.planner_utils import AbstractState, to_state
from risk_biased.utils.risk import get_risk_estimator


@dataclass
class MPCPlannerParams:
    """Dataclass for MPC-Planner Parameters

    Args:
        dt: discrete time interval in seconds that is used for planning
        num_steps: number of time steps for which history of ego's and the other actor's
          trajectories are stored
        num_steps_future: number of time steps into the future for which ego's and the other actor's
          trajectories are considered
        acceleration_std_x_m_s2: Acceleration noise standard deviation (m/s^2) in x-direction that
          is used to initialize the Cross Entropy solver
        acceleration_std_y_m_s2: Acceleration noise standard deviation (m/s^2) in y-direction that
          is used to initialize the Cross Entropy solver
        risk_estimator_params: parameters for the Monte Carlo risk estimator used in the planner for
          ego's control optimization
        solver_params: parameters for the CrossEntropySolver
        tracking_cost_params: parameters for the TrackingCost
        ttc_cost_params: parameters for the TTCCost (i.e., collision cost between ego and the other
          actor)
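
    Example:
        A minimal sketch, assuming cfg is an mmcv Config whose fields match the
        attribute names above (the nested from_config helpers read their own
        fields from the same config; the file name is hypothetical):

            cfg = Config.fromfile("mpc_planner_config.py")
            params = MPCPlannerParams.from_config(cfg)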
    """

    dt: float
    num_steps: int
    num_steps_future: int
    acceleration_std_x_m_s2: float
    acceleration_std_y_m_s2: float

    risk_estimator_params: dict
    solver_params: CrossEntropySolverParams
    tracking_cost_params: TrackingCostParams
    ttc_cost_params: TTCCostParams

    @staticmethod
    def from_config(cfg: Config) -> "MPCPlannerParams":
        return MPCPlannerParams(
            cfg.dt,
            cfg.num_steps,
            cfg.num_steps_future,
            cfg.acceleration_std_x_m_s2,
            cfg.acceleration_std_y_m_s2,
            cfg.risk_estimator,
            CrossEntropySolverParams.from_config(cfg),
            TrackingCostParams.from_config(cfg),
            TTCCostParams.from_config(cfg),
        )


class MPCPlanner:
    """MPC Planner with a Cross Entropy solver

    Args:
        params: MPCPlannerParams object
        predictor: LitTrajectoryPredictor object
        normalizer: function that takes an unnormalized trajectory and returns the
          normalized trajectory and the offset, in this order
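
    Example:
        A rough usage sketch; params, predictor, normalizer, ado_state, and
        ego_state are assumed to be constructed elsewhere:

            planner = MPCPlanner(params, predictor, normalizer)
            planner.replan(ado_state, ego_state, torch.tensor([[1.0, 0.0]]))
            next_ego_state = planner.get_planned_next_ego_state()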
    """

    def __init__(
        self,
        params: MPCPlannerParams,
        predictor: LitTrajectoryPredictor,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:

        self.params = params
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.control_input_mean_init = torch.zeros(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_init = torch.Tensor(
            [
                params.acceleration_std_x_m_s2,
                params.acceleration_std_y_m_s2,
            ]
        ).expand_as(self.control_input_mean_init)
        self.solver = CrossEntropySolver(
            params=params.solver_params,
            dynamics_model=self.dynamics_model,
            control_input_mean=self.control_input_mean_init,
            control_input_std=self.control_input_std_init,
            tracking_cost_function=TrackingCost(params.tracking_cost_params),
            interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator_params),
        )
        self.predictor = predictor
        self.normalizer = normalizer

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None

        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def replan(
        self,
        current_ado_state: AbstractState,
        current_ego_state: AbstractState,
        target_velocity: torch.Tensor,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> None:
        """Performs re-planning given the current_ado_position, current_ego_state, and
        target_velocity. Updates ego_state_planned_trajectory. Note that all the information given
        to the solver.solve(...) is expressed in the ego-centric frame, whose origin is the initial
        ego position in ego_state_history and the x-direction is parallel to the initial ego
        velocity.

        Args:
            current_ado_state: current state of the other (ado) actor
            current_ego_state: current ego state
            target_velocity: ((1), 2) target velocity tensor (leading dimension optional)
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the entire prediction-planning pipeline.
              If 0.0, risk-neutral prediction and planning are used. Defaults to 0.0.
            resample_prediction (optional): If True, prediction is re-sampled in each cross-entropy
              iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the solver
              becomes risk-neutral. If False, risk-neutral prediction is used and the solver becomes
              risk-sensitive. Defaults to False.
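
        Example:
            A sketch of a risk-sensitive call (argument values are illustrative only):

                planner.replan(
                    current_ado_state=ado_state,
                    current_ego_state=ego_state,
                    target_velocity=torch.tensor([[1.0, 0.0]]),
                    num_prediction_samples=32,
                    risk_level=0.95,
                    risk_in_predictor=False,  # risk handled by the solver instead
                )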
        """
        self._update_ado_state_history(current_ado_state)
        self._update_ego_state_history(current_ego_state)
        self._update_ego_state_target_trajectory(current_ego_state, target_velocity)
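        # Only invoke the solver once enough ado history has accumulated; until then the
        # existing (initial) control sequence is simply simulated forward below.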
        if self.ado_state_history.shape[-1] >= self.params.num_steps:
            self.solver.solve(
                self.predictor,
                self._map_to_ego_centric_frame(self.ego_state_history),
                self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
                self._map_to_ego_centric_frame(self.ado_state_history),
                self.normalizer,
                num_prediction_samples=num_prediction_samples,
                risk_level=risk_level,
                resample_prediction=resample_prediction,
                risk_in_predictor=risk_in_predictor,
            )
        ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
            self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
            self.solver.control_sequence,
        )
        self._ego_state_planned_trajectory = self._map_to_world_frame(
            ego_state_planned_trajectory_in_ego_frame
        )
        latest_ado_position_future_samples_in_ego_frame = (
            self.solver.fetch_latest_prediction()
        )
        if latest_ado_position_future_samples_in_ego_frame is not None:
            self._latest_ado_position_future_samples = self._map_to_world_frame(
                latest_ado_position_future_samples_in_ego_frame
            )
        else:
            self._latest_ado_position_future_samples = None

    def get_planned_next_ego_state(self) -> AbstractState:
        """Returns the next ego state according to the ego_state_planned_trajectory

        Returns:
            Planned state
        """
        assert (
            self._ego_state_planned_trajectory is not None
        ), "call self.replan(...) first"
        return self._ego_state_planned_trajectory[..., 0]

    def reset(self) -> None:
        """Resets the planner's internal state. This will fully reset the solver's internal state,
        including solver.control_input_mean_init and solver.control_input_std_init."""
        self.solver.control_input_mean_init = (
            self.control_input_mean_init.detach().clone()
        )
        self.solver.control_input_std_init = (
            self.control_input_std_init.detach().clone()
        )
        self.solver.reset()

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None

        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def fetch_latest_prediction(self) -> Optional[torch.Tensor]:
        """Returns the latest ado position future samples in the world frame, or None if
        no prediction has been made yet

        Returns:
            latest ado position future samples, or None
        """
        return self._latest_ado_position_future_samples

    @property
    def ego_state_history(self) -> AbstractState:
        """Returns ego_state_history stacked into a single state

        Returns:
            ego_state_history state
        """
        assert len(self._ego_state_history) > 0
        return to_state(
            torch.stack(
                [ego_state.get_states(4) for ego_state in self._ego_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    @property
    def ado_state_history(self) -> AbstractState:
        """Returns ado_state_history stacked into a single state

        Returns:
            ado_state_history state
        """
        assert len(self._ado_state_history) > 0
        return to_state(
            torch.stack(
                [ado_state.get_states(4) for ado_state in self._ado_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
        """Updates ego_state_history with the current_ego_state

        Args:
            current_ego_state: current ego state
        """

        if len(self._ego_state_history) >= self.params.num_steps:
            self._ego_state_history = self._ego_state_history[1:]
        self._ego_state_history.append(current_ego_state)
        assert len(self._ego_state_history) <= self.params.num_steps

    def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
        """Updates ego_state_history with the current_ado_position

        Args:
            current_ado_state states of the current non-ego vehicles
        """

        if len(self._ado_state_history) >= self.params.num_steps:
            self._ado_state_history = self._ado_state_history[1:]
        self._ado_state_history.append(current_ado_state)
        assert len(self._ado_state_history) <= self.params.num_steps

    def _update_ego_state_target_trajectory(
        self, current_ego_state: AbstractState, target_velocity: torch.Tensor
    ) -> None:
        """Updates ego_state_target_trajectory based on the current_ego_state and the target_velocity

        Args:
            current_ego_state: state
            target_velocity: (1, 2) tensor
        """

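        # Integrate the constant target velocity forward: the k-th future waypoint is
        # current_position + k * dt * target_velocity, paired with the target velocity.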
        target_displacement = self.params.dt * target_velocity
        target_position_list = [current_ego_state.position]
        for _ in range(self.params.num_steps_future):
            target_position_list.append(target_position_list[-1] + target_displacement)
        target_position_list = target_position_list[1:]
        target_position = torch.cat(target_position_list, dim=-2)
        target_state = to_state(
            torch.cat(
                (target_position, target_velocity.expand_as(target_position)), dim=-1
            ),
            self.params.dt,
        )
        self._ego_state_target_trajectory = target_state

    def _map_to_ego_centric_frame(
        self, trajectory_in_world_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the world frame to the ego-centric frame, whose origin is
        the initial ego position in ego_state_history and the x-direction is parallel to the initial
        ego velocity

        Args:
            trajectory_in_world_frame: sequence of states

        Returns:
            trajectory mapped to the ego-centric frame
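
        Example:
            If the latest ego position is (1, 0) with velocity along +x, a
            world-frame point at (2, 0) maps to (1, 0) in the ego frame.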
        """
        # If trajectory_in_world_frame is of shape (..., state_dim) then use the associated
        # dynamics model in translate_position and rotate_angle. Otherwise assume that the
        # trajectory is in the 2D position space.

        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_vel_init = self.ego_state_history.velocity[..., -1, :]
        ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
        trajectory_in_ego_frame = trajectory_in_world_frame.translate(
            -ego_pos_init
        ).rotate(-ego_rot_init)
        return trajectory_in_ego_frame

    def _map_to_world_frame(
        self, trajectory_in_ego_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the ego-centric frame to the world frame

        Args:
            trajectory_in_ego_frame: (..., 2) position trajectory or (..., markov_state_dim) state
              trajectory expressed in the ego-centric frame, whose origin is the initial ego
              position in ego_state_history and the x-direction is parallel to the initial ego
              velocity

        Returns:
            trajectory mapped to the world frame
        """
        # state starts with x, y, angle
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_rot_init = self.ego_state_history.angle[..., -1, :]
        trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
            ego_rot_init
        ).translate(ego_pos_init)
        return trajectory_in_world_frame