Commit 4f171f8

Merge pull request #9 from UM-ARM-Lab/optimize-mppi
Optimize MPPI hot path for GPU performance and add torch.compile support
2 parents bfcc915 + e661e80 commit 4f171f8

7 files changed

Lines changed: 1952 additions & 104 deletions

.github/workflows/tests.yml

Lines changed: 29 additions & 0 deletions

```yaml
name: Tests

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.12"]
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          pip install pytest
          pip install -e .

      - name: Run tests
        run: pytest tests/test_mppi.py tests/test_batch_wrapper.py -v
```

CLAUDE.md

Lines changed: 143 additions & 0 deletions

# pytorch_mppi

Model Predictive Path Integral (MPPI) control library using approximate dynamics in PyTorch.
Implements batched trajectory sampling for GPU-accelerated model-based control.

## Project Structure

```
src/pytorch_mppi/
    __init__.py                         # Exports MPPI, SMPPI, KMPPI
    mppi.py                             # Core implementation (~625 lines): MPPI, SMPPI, KMPPI, run_mppi
    autotune.py                         # Hyperparameter tuning infrastructure (CMA-ES local optimizer)
    autotune_global.py                  # Ray Tune global search integration
    autotune_qd.py                      # Quality Diversity optimization (CMA-ME via pyribs)
tests/
    pendulum.py                         # MPPI on true pendulum dynamics (gym)
    pendulum_approximate.py             # MPPI with learned neural network dynamics
    pendulum_approximate_continuous.py  # Continuous angle representation variant
    smooth_mppi.py                      # Visual comparison of MPPI, SMPPI, KMPPI
    auto_tune_parameters.py             # Hyperparameter tuning example
    test_batch_wrapper.py               # Unit tests for handle_batch_input
```

## Architecture

### Class Hierarchy
- **MPPI** - Base class. Batched trajectory sampling with importance-weighted control update (Algorithm 2, Williams et al. 2017).
- **SMPPI(MPPI)** - Smooth MPPI. Lifts control space to penalize action rate of change. Maintains separate `action_sequence` and `U` (control differences).
- **KMPPI(MPPI)** - Kernel MPPI. Samples fewer support points, interpolates to full trajectory via RBF kernel (sketch below). Uses `functorch.vmap` for batched interpolation.
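
To make the KMPPI bullet concrete: a hedged sketch of RBF interpolation from P support points to a T-step trajectory, batched over the K samples with `vmap`. The kernel, scale, and function names are illustrative assumptions, not the library's exact implementation.

```python
import torch

def rbf_kernel(t1, t2, scale=2.0):
    # Gaussian RBF between two sets of scalar time points -> |t1| x |t2| matrix
    return torch.exp(-((t1[:, None] - t2[None, :]) ** 2) / (2 * scale ** 2))

def interpolate_controls(support_u, support_t, query_t):
    # support_u: P x nu control values at P support times; returns T x nu controls
    K_ss = rbf_kernel(support_t, support_t)        # P x P
    K_qs = rbf_kernel(query_t, support_t)          # T x P
    weights = torch.linalg.solve(K_ss, support_u)  # P x nu kernel regression weights
    return K_qs @ weights                          # T x nu

# vmap maps the same interpolation over the K sampled support-point trajectories.
batched_interp = torch.vmap(interpolate_controls, in_dims=(0, None, None))

P, T, nu, K = 5, 15, 2, 100
support_t = torch.linspace(0, T - 1, P)
query_t = torch.arange(T, dtype=torch.float32)
full_controls = batched_interp(torch.randn(K, P, nu), support_t, query_t)  # K x T x nu
```
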
### Key Data Flow (per `command()` call)
1. `shift_nominal_trajectory()` - Roll U forward, append u_init
2. `_compute_perturbed_action_and_noise()` - Sample K noise trajectories (K x T x nu), add to U, bound
3. `_compute_rollout_costs(perturbed_actions)` - **Hot loop**: iterate T timesteps, call user dynamics+cost each step
4. `_compute_total_cost_batch()` - Combine rollout cost + perturbation cost
5. `_compute_weighting()` - Softmax-like exponential weighting (omega)
6. Update U with weighted sum of noise perturbations (see the sketch below)
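
Steps 5-6 amount to an exponentiated-cost average over the K rollouts. A minimal sketch of that update, in algorithm notation rather than the library's exact variable names:

```python
import torch

def mppi_update(U, noise, total_cost, lambda_=1.0):
    # U: T x nu nominal controls, noise: K x T x nu perturbations, total_cost: K rollout costs
    beta = total_cost.min()                            # subtract min cost for numerical stability
    omega = torch.exp(-(total_cost - beta) / lambda_)  # unnormalized weight per rollout
    omega = omega / omega.sum()                        # softmax-like normalization (sums to 1)
    return U + (omega.view(-1, 1, 1) * noise).sum(dim=0)  # weighted perturbation sum -> T x nu
```
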
### Key Dimensions
- **K** = `num_samples` (trajectory samples, typically 100-1000)
- **T** = `horizon` (timesteps, typically 15-30)
- **M** = `rollout_samples` (stochastic dynamics replicates, usually 1)
- **nu** = control dimensions, **nx** = state dimensions

### User-Provided Functions
- `dynamics(state, action) -> next_state` — state is K x nx, action is K x nu
- `running_cost(state, action) -> cost` — cost is K x 1
- `terminal_state_cost(states, actions) -> cost` — optional, states is K x T x nx
- Wrapped by `@handle_batch_input(n=2)` from `arm_pytorch_utilities` (usage sketch below)
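
A minimal usage sketch showing how these callbacks plug into the controller on a toy double integrator. Constructor keywords follow the README-style API; treat the exact signature as an assumption rather than verified documentation.

```python
import torch
from pytorch_mppi import MPPI

nx, nu, dt = 2, 1, 0.05  # state = [position, velocity], single control input

def dynamics(state, action):
    # state: K x nx, action: K x nu -> next state: K x nx
    pos = state[:, 0] + dt * state[:, 1]
    vel = state[:, 1] + dt * action[:, 0]
    return torch.stack((pos, vel), dim=1)

def running_cost(state, action):
    # one scalar cost per sampled rollout: drive position/velocity to zero cheaply
    return state[:, 0] ** 2 + 0.1 * state[:, 1] ** 2 + 1e-3 * action[:, 0] ** 2

ctrl = MPPI(dynamics, running_cost, nx,
            noise_sigma=torch.eye(nu),
            num_samples=500, horizon=15, lambda_=1.0)

state = torch.tensor([1.0, 0.0])
action = ctrl.command(state)  # first action of the newly optimized control sequence
```
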
## Dependencies
- **torch** — core tensor operations
- **arm_pytorch_utilities** — `handle_batch_input` decorator for flexible batch dimensions
- **functorch** — `vmap` used in KMPPI for batched kernel interpolation
- **numpy** — minimal use (only in `get_params()` for display and in autotune)
- Optional: cma, ray[tune], bayesian-optimization, hyperopt (for autotune)

## Planned Optimization Refactor

Goal: Remove Python loops and make code compatible with `torch.compile`, similar to what was done for `pytorch_kinematics`.

### Performance-Critical Hot Path
The main bottleneck is `_compute_rollout_costs()` (mppi.py:254-267):
```python
for t in range(T):
    u = self.u_scale * perturbed_actions[:, t].repeat(self.M, 1, 1)
    next_state = self._dynamics(state, u, t)
    next_state = self._sample_specific_dynamics(next_state, state, u, t)
    state = next_state
    c = self._running_cost(state, u, t)
    cost_samples = cost_samples + c
    states.append(state)
    actions.append(u)
actions = torch.stack(actions, dim=-2)
states = torch.stack(states, dim=-2)
```

Similarly `get_rollouts()` (mppi.py:357-361) has a horizon loop.

### torch.compile Blockers
1. **Horizon for-loop with list appends** — `states.append()` / `torch.stack()` pattern. Fix: pre-allocate tensors, use index assignment (see the sketch after this list).
2. **Shape-dependent control flow** (mppi.py:244): `if self.state.shape == (K, self.nx)`. Fix: use attribute flags or always reshape.
3. **`@handle_batch_input` decorator** — wraps dynamics/cost with runtime shape inspection. May need compile-friendly alternative or be applied outside compiled region.
4. **Optional feature branching** — `if self.terminal_state_cost`, `if self.M > 1`, `if self.specific_action_sampler is not None`. Fix: use guards at init time or compile separate variants.
5. **`from functorch import vmap`** — should migrate to `torch.vmap` (functorch is merged into PyTorch core).
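
A minimal sketch of the fix named in blocker 1: the same horizon loop with pre-allocated output tensors and index assignment in place of list appends and `torch.stack`. It omits the M/u_scale handling from the excerpt above and is not the merged implementation.

```python
import torch

def rollout_preallocated(dynamics, running_cost, state, perturbed_actions):
    # state: K x nx, perturbed_actions: K x T x nu; dynamics/running_cost are the user callbacks
    K, T, nu = perturbed_actions.shape
    nx = state.shape[-1]
    states = torch.empty(K, T, nx, dtype=state.dtype, device=state.device)
    actions = torch.empty(K, T, nu, dtype=state.dtype, device=state.device)
    cost_samples = torch.zeros(K, dtype=state.dtype, device=state.device)
    for t in range(T):
        u = perturbed_actions[:, t]
        state = dynamics(state, u)
        cost_samples = cost_samples + running_cost(state, u)
        states[:, t] = state      # index assignment instead of states.append(state)
        actions[:, t] = u         # no torch.stack needed after the loop
    return cost_samples, states, actions
```
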
### Vectorization Opportunities
- The horizon loop is inherently sequential (state[t+1] depends on state[t]), so it cannot be parallelized across T. However, `torch.compile` can still optimize the unrolled loop if graph breaks are eliminated.
- Pre-allocating `states` and `actions` tensors avoids dynamic list building.
- The `_compute_weighting` softmax computation is already vectorized.
- Noise sampling and action cost computation are already fully vectorized.

### Already Optimized
- Noise sampling: fully batched via `MultivariateNormal.rsample((K, T))` (sketch below)
- Action cost: matrix multiply K x T x nu @ nu x nu
- KMPPI kernel interpolation: uses vmap for batch kernel solve
- Cost weighting: vectorized exp + normalize
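
For reference, a sketch of the batched patterns above with illustrative dimensions; the quadratic form shown is the generic MPPI control cost, not necessarily the library's exact expression.

```python
import torch

K, T, nu = 500, 15, 2
noise_sigma = 0.5 * torch.eye(nu)

# Noise sampling: one rsample call draws every perturbation for all K rollouts.
noise_dist = torch.distributions.MultivariateNormal(torch.zeros(nu), covariance_matrix=noise_sigma)
noise = noise_dist.rsample((K, T))                      # K x T x nu

# Action cost: a single batched matmul (K x T x nu) @ (nu x nu), then a reduction.
noise_sigma_inv = torch.inverse(noise_sigma)            # nu x nu
weighted = noise @ noise_sigma_inv                      # K x T x nu
perturbation_cost = (noise * weighted).sum(dim=(1, 2))  # one scalar per rollout (K)
```
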
### Migration Notes
- `functorch.vmap` → `torch.vmap` (available since PyTorch 2.0)
- Consider whether `arm_pytorch_utilities.handle_batch_input` can be replaced with simpler reshape logic to reduce external dependencies (sketch below)
- User dynamics/cost functions must also be compile-friendly for full graph compilation; consider documenting this requirement or providing a fallback path
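
A hedged illustration of the first two notes: `torch.vmap` is the in-core replacement for the functorch import, and `flatten_leading_dims` is a hypothetical reshape-based stand-in for `handle_batch_input(n=2)`, not an existing utility.

```python
import torch

# functorch.vmap -> torch.vmap: same call, now in the torch namespace (PyTorch >= 2.0).
vmap = torch.vmap

def flatten_leading_dims(f):
    """Hypothetical reshape-based alternative: collapse any extra leading batch
    dims before calling f on 2-D inputs, then restore them on the output."""
    def wrapped(state, action):
        lead = state.shape[:-1]                       # e.g. (M, K) or (K,)
        out = f(state.reshape(-1, state.shape[-1]),
                action.reshape(-1, action.shape[-1]))
        return out.reshape(*lead, *out.shape[1:])
    return wrapped
```
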
## Development

```shell
pip install -e .          # Dev install
pip install -e .[test]    # With test deps
KMP_DUPLICATE_LIB_OK=TRUE pytest tests/test_mppi.py tests/test_batch_wrapper.py -v  # Run tests
KMP_DUPLICATE_LIB_OK=TRUE python tests/benchmark_mppi.py  # Run benchmarks (saves benchmark_results.json)
```

Note: `KMP_DUPLICATE_LIB_OK=TRUE` is needed on this machine due to duplicate OpenMP libraries.

## Testing

### Test files
- `tests/test_mppi.py` — 64 tests covering MPPI/SMPPI/KMPPI correctness and solution quality
- `tests/test_batch_wrapper.py` — 2 tests for the handle_batch_input decorator
- `tests/benchmark_mppi.py` — Timing + solution quality benchmarks

### Test categories in test_mppi.py
- **TestMPPI** (26): shapes, bounds, features, determinism, state handling
- **TestSMPPI** (11): SMPPI-specific behavior, smoothness, action sequences
- **TestKMPPI** (12): kernel interpolation, support points, stability
- **TestSpecificActionSampler** (1): custom sampler integration
- **TestEdgeCases** (5): numpy input, high-dim, 1-sample, float32
- **TestSolutionQuality** (9): goal convergence, cost bounds, determinism, bounded actions

### Baseline performance (CPU, K=500, T=15)
| Controller | per-command | 20-step loop |
|---|---|---|
| MPPI | 0.63ms | 12.9ms |
| SMPPI | 0.68ms | 13.7ms |
| KMPPI | 1.05ms | 20.5ms |

### Baseline solution quality (CPU, K=500, T=15, 20 steps, 5 trials)
| Controller | Accum Cost | Final Dist | Control Smoothness |
|---|---|---|---|
| MPPI | 113.5 +/- 17.6 | 1.59 +/- 0.96 | 57.8 +/- 9.0 |
| KMPPI | 111.0 +/- 12.2 | 1.61 +/- 0.58 | 25.9 +/- 3.2 |

Note: SMPPI quality is highly environment-dependent; it requires careful tuning (action bounds, terminal cost) per environment. KMPPI achieves similar cost to MPPI but with 2x smoother control.

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "pytorch_mppi"
-version = "0.8.0"
+version = "0.9.0"
 description = "Model Predictive Path Integral (MPPI) implemented in pytorch"
 readme = "README.md" # Optional
```

src/pytorch_mppi/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-from pytorch_mppi.mppi import MPPI, SMPPI, KMPPI
+from pytorch_mppi.mppi import MPPI, SMPPI, KMPPI, MPPI_Batched
```
