# pytorch_mppi

Model Predictive Path Integral (MPPI) control library using approximate dynamics in PyTorch.
Implements batched trajectory sampling for GPU-accelerated model-based control.

## Project Structure

```
src/pytorch_mppi/
    __init__.py                         # Exports MPPI, SMPPI, KMPPI
    mppi.py                             # Core implementation (~625 lines): MPPI, SMPPI, KMPPI, run_mppi
    autotune.py                         # Hyperparameter tuning infrastructure (CMA-ES local optimizer)
    autotune_global.py                  # Ray Tune global search integration
    autotune_qd.py                      # Quality Diversity optimization (CMA-ME via pyribs)
tests/
    pendulum.py                         # MPPI on true pendulum dynamics (gym)
    pendulum_approximate.py             # MPPI with learned neural network dynamics
    pendulum_approximate_continuous.py  # Continuous angle representation variant
    smooth_mppi.py                      # Visual comparison of MPPI, SMPPI, KMPPI
    auto_tune_parameters.py             # Hyperparameter tuning example
    test_batch_wrapper.py               # Unit tests for handle_batch_input
```

## Architecture

### Class Hierarchy
- **MPPI** - Base class. Batched trajectory sampling with importance-weighted control update (Algorithm 2, Williams et al. 2017).
- **SMPPI(MPPI)** - Smooth MPPI. Lifts the control space to penalize the action rate of change. Maintains a separate `action_sequence` and `U` (control differences).
- **KMPPI(MPPI)** - Kernel MPPI. Samples fewer support points and interpolates to the full trajectory via an RBF kernel. Uses `functorch.vmap` for batched interpolation.

### Key Data Flow (per `command()` call)
1. `shift_nominal_trajectory()` - Roll U forward, append u_init
2. `_compute_perturbed_action_and_noise()` - Sample K noise trajectories (K x T x nu), add to U, bound
3. `_compute_rollout_costs(perturbed_actions)` - **Hot loop**: iterate T timesteps, call user dynamics+cost each step
4. `_compute_total_cost_batch()` - Combine rollout cost + perturbation cost
5. `_compute_weighting()` - Softmax-like exponential weighting (omega)
6. Update U with weighted sum of noise perturbations
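
The flow above can be sketched as one minimal update step. This is not the library's API: `mppi_step_sketch` and its arguments are hypothetical stand-ins, and action bounding, the M replicates, and `u_scale` are omitted.

```python
import torch

def mppi_step_sketch(U, dynamics, running_cost, state, noise_dist, lam=1.0):
    """Illustrative sketch of one MPPI command() cycle (hypothetical names)."""
    K = 512                                 # num_samples
    T, nu = U.shape
    # 2. Sample K noise trajectories and perturb the nominal plan U
    noise = noise_dist.rsample((K, T))      # K x T x nu
    perturbed = U.unsqueeze(0) + noise      # K x T x nu
    # 3. Roll out dynamics and accumulate running cost (the hot loop)
    cost = torch.zeros(K)
    s = state.repeat(K, 1)                  # K x nx
    for t in range(T):
        s = dynamics(s, perturbed[:, t])
        cost = cost + running_cost(s, perturbed[:, t])
    # 5. Softmax-like exponential weighting (omega), shifted by the best cost
    beta = cost.min()
    omega = torch.exp(-(cost - beta) / lam)
    omega = omega / omega.sum()
    # 6. Update U with the weighted sum of noise perturbations
    return U + (omega.view(K, 1, 1) * noise).sum(dim=0)
```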

### Key Dimensions
- **K** = `num_samples` (trajectory samples, typically 100-1000)
- **T** = `horizon` (timesteps, typically 15-30)
- **M** = `rollout_samples` (stochastic dynamics replicates, usually 1)
- **nu** = control dimensions, **nx** = state dimensions

### User-Provided Functions
- `dynamics(state, action) -> next_state` — state is K x nx, action is K x nu
- `running_cost(state, action) -> cost` — cost is K x 1
- `terminal_state_cost(states, actions) -> cost` — optional, states is K x T x nx
- Wrapped by `@handle_batch_input(n=2)` from `arm_pytorch_utilities`
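
As a concrete illustration (a hypothetical pendulum-like system, not code from this repo), user functions operating on the 2-D batched shapes look like:

```python
import torch

# Hypothetical user-provided functions for a pendulum-like system.
# With @handle_batch_input(n=2), they only need to handle 2-D batches:
# state is K x nx (nx=2: angle, angular velocity), action is K x nu (nu=1).

def dynamics(state, action):
    theta, theta_dot = state[:, 0], state[:, 1]
    # Simple explicit-Euler step with made-up constants
    theta_dot = theta_dot + 0.05 * (action[:, 0] - 9.81 * torch.sin(theta))
    theta = theta + 0.05 * theta_dot
    return torch.stack((theta, theta_dot), dim=1)   # K x nx

def running_cost(state, action):
    # Quadratic cost on angle, velocity, and control effort, per sample
    return state[:, 0] ** 2 + 0.1 * state[:, 1] ** 2 + 0.001 * action[:, 0] ** 2
```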

## Dependencies
- **torch** — core tensor operations
- **arm_pytorch_utilities** — `handle_batch_input` decorator for flexible batch dimensions
- **functorch** — `vmap` used in KMPPI for batched kernel interpolation
- **numpy** — minimal use (only in `get_params()` for display and in autotune)
- Optional: cma, ray[tune], bayesian-optimization, hyperopt (for autotune)

## Planned Optimization Refactor

Goal: Remove Python loops and make code compatible with `torch.compile`, similar to what was done for `pytorch_kinematics`.

### Performance-Critical Hot Path
The main bottleneck is `_compute_rollout_costs()` (mppi.py:254-267):
```python
for t in range(T):
    u = self.u_scale * perturbed_actions[:, t].repeat(self.M, 1, 1)
    next_state = self._dynamics(state, u, t)
    next_state = self._sample_specific_dynamics(next_state, state, u, t)
    state = next_state
    c = self._running_cost(state, u, t)
    cost_samples = cost_samples + c
    states.append(state)
    actions.append(u)
actions = torch.stack(actions, dim=-2)
states = torch.stack(states, dim=-2)
```

Similarly `get_rollouts()` (mppi.py:357-361) has a horizon loop.

### torch.compile Blockers
1. **Horizon for-loop with list appends** — `states.append()` / `torch.stack()` pattern. Fix: pre-allocate tensors, use index assignment.
2. **Shape-dependent control flow** (mppi.py:244): `if self.state.shape == (K, self.nx)`. Fix: use attribute flags or always reshape.
3. **`@handle_batch_input` decorator** — wraps dynamics/cost with runtime shape inspection. May need a compile-friendly alternative or be applied outside the compiled region.
4. **Optional feature branching** — `if self.terminal_state_cost`, `if self.M > 1`, `if self.specific_action_sampler is not None`. Fix: use guards at init time or compile separate variants.
5. **`from functorch import vmap`** — should migrate to `torch.vmap` (functorch is merged into PyTorch core).
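
A sketch of the pre-allocation fix for blocker 1, under simplifying assumptions: no M replicates, no `u_scale`, and hypothetical dynamics/cost signatures rather than the library's own.

```python
import torch

def rollout_preallocated(dynamics, running_cost, state, perturbed_actions):
    """Hot loop with pre-allocated outputs instead of list-append +
    torch.stack, which causes torch.compile graph breaks."""
    K, T, nu = perturbed_actions.shape
    nx = state.shape[-1]
    states = torch.empty(K, T, nx, dtype=state.dtype, device=state.device)
    cost = torch.zeros(K, dtype=state.dtype, device=state.device)
    for t in range(T):
        u = perturbed_actions[:, t]
        state = dynamics(state, u)
        cost = cost + running_cost(state, u)
        states[:, t] = state          # index assignment, no Python list
    return cost, states
```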

### Vectorization Opportunities
- The horizon loop is inherently sequential (state[t+1] depends on state[t]), so it cannot be parallelized across T. However, `torch.compile` can still optimize the unrolled loop if graph breaks are eliminated.
- Pre-allocating `states` and `actions` tensors avoids dynamic list building.
- The `_compute_weighting` softmax computation is already vectorized.
- Noise sampling and action cost computation are already fully vectorized.

### Already Optimized
- Noise sampling: fully batched via `MultivariateNormal.rsample((K, T))`
- Action cost: matrix multiply K x T x nu @ nu x nu
- KMPPI kernel interpolation: uses vmap for batch kernel solve
- Cost weighting: vectorized exp + normalize
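
For reference, the batched noise sampling pattern: one `rsample` call produces all K x T x nu perturbations with no Python loop (illustrative sizes).

```python
import torch

K, T, nu = 500, 15, 2
noise_sigma = 0.5 * torch.eye(nu)   # control-noise covariance (nu x nu)
dist = torch.distributions.MultivariateNormal(torch.zeros(nu), noise_sigma)
noise = dist.rsample((K, T))        # K x T x nu, one fully batched call
```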

### Migration Notes
- `functorch.vmap` → `torch.vmap` (available since PyTorch 2.0)
- Consider whether `arm_pytorch_utilities.handle_batch_input` can be replaced with simpler reshape logic to reduce external dependencies
- User dynamics/cost functions must also be compile-friendly for full graph compilation; consider documenting this requirement or providing a fallback path
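
The vmap migration is a drop-in rename. As an illustration (hypothetical shapes, not the library's actual kernel code), batching a per-trajectory interpolation across K samples:

```python
import torch

def interp_one(weights, basis):
    # weights: c x nu support weights, basis: T x c kernel basis matrix
    # -> interpolated T x nu trajectory for one sample
    return basis @ weights

# torch.vmap replaces functorch.vmap with the same call signature
batched_interp = torch.vmap(interp_one)
W = torch.randn(8, 4, 2)      # K=8 samples, c=4 support points, nu=2
B = torch.randn(8, 15, 4)     # K=8 basis matrices for a T=15 horizon
trajs = batched_interp(W, B)  # 8 x 15 x 2
```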

## Development

```shell
pip install -e .         # Dev install
pip install -e .[test]   # With test deps
KMP_DUPLICATE_LIB_OK=TRUE pytest tests/test_mppi.py tests/test_batch_wrapper.py -v   # Run tests
KMP_DUPLICATE_LIB_OK=TRUE python tests/benchmark_mppi.py   # Run benchmarks (saves benchmark_results.json)
```

Note: `KMP_DUPLICATE_LIB_OK=TRUE` is needed on this machine due to duplicate OpenMP libraries.

## Testing

### Test files
- `tests/test_mppi.py` — 64 tests covering MPPI/SMPPI/KMPPI correctness and solution quality
- `tests/test_batch_wrapper.py` — 2 tests for the handle_batch_input decorator
- `tests/benchmark_mppi.py` — Timing + solution quality benchmarks

### Test categories in test_mppi.py
- **TestMPPI** (26): shapes, bounds, features, determinism, state handling
- **TestSMPPI** (11): SMPPI-specific behavior, smoothness, action sequences
- **TestKMPPI** (12): kernel interpolation, support points, stability
- **TestSpecificActionSampler** (1): custom sampler integration
- **TestEdgeCases** (5): numpy input, high-dim, 1-sample, float32
- **TestSolutionQuality** (9): goal convergence, cost bounds, determinism, bounded actions

### Baseline performance (CPU, K=500, T=15)
| Controller | per-command | 20-step loop |
|---|---|---|
| MPPI | 0.63 ms | 12.9 ms |
| SMPPI | 0.68 ms | 13.7 ms |
| KMPPI | 1.05 ms | 20.5 ms |

### Baseline solution quality (CPU, K=500, T=15, 20 steps, 5 trials)
| Controller | Accum Cost | Final Dist | Control Smoothness |
|---|---|---|---|
| MPPI | 113.5 ± 17.6 | 1.59 ± 0.96 | 57.8 ± 9.0 |
| KMPPI | 111.0 ± 12.2 | 1.61 ± 0.58 | 25.9 ± 3.2 |

Note: SMPPI quality is highly environment-dependent; it requires careful tuning (action bounds, terminal cost) per environment. KMPPI achieves similar cost to MPPI but with 2x smoother control.