
Optimize MPPI hot path for GPU performance and add torch.compile support#9

Merged
LemonPi merged 7 commits into master from optimize-mppi
Mar 12, 2026

Conversation

@LemonPi (Member) commented Mar 12, 2026

Summary

  • Optimize the MPPI hot path for ~1.5-2x speedup on CUDA (eager mode), with no breaking API changes
  • Add optional torch.compile support via mppi.compile() (beneficial at large K)
  • Add MPPI_Batched class for running N parallel environments in a single dynamics call
  • Add comprehensive test suite (75 tests) and benchmarks
  • Bump version to 0.9.0

Optimizations

  • Pre-allocate tensors instead of list + torch.stack in rollout loop
  • Direct Cholesky noise sampling replacing MultivariateNormal.rsample
  • Diagonal covariance detection at init — element-wise ops instead of matmul when sigma is diagonal (common case)
  • M=1 fast path — skip .repeat(), variance computation, and M dimension when rollout_samples=1 (default)
  • Lazy state/action storage — skip allocating (M, K, T, nx) tensors when no terminal_state_cost
  • torch.einsum for weighted perturbation sum (avoids materializing K×T×nu broadcast)
  • Bypass handle_batch_input in hot path — pre-flatten to (M*K, nx) and call dynamics/cost directly
  • Resolve all branches at init — bounds default to ±inf, dynamics/cost wrapped to uniform signature, action cost function pre-selected
  • Cache KMPPI kernel matrix — Ktktk computed once at init, not every call
  • functorch.vmap → torch.vmap, torch.inverse → torch.linalg.inv
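The direct Cholesky sampling swap above can be sketched as follows (shapes are illustrative): instead of constructing a `MultivariateNormal` and calling `rsample` on every iteration, the covariance is factored once at init and standard-normal draws are transformed through the factor.

```python
import torch

# Illustrative shapes: K rollout samples, T horizon steps, nu control dims.
K, T, nu = 500, 15, 2
mu = torch.zeros(nu)
sigma = torch.diag(torch.tensor([0.5, 0.3]))  # noise covariance

# Slow path: distribution object built and sampled each call.
dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
noise_slow = dist.rsample((K, T))             # (K, T, nu)

# Fast path: factor once at init, then transform cheap standard normals.
chol = torch.linalg.cholesky(sigma)           # (nu, nu), computed once
noise_fast = torch.randn(K, T, nu) @ chol.T + mu

# Both draws have covariance chol @ chol.T == sigma.
assert noise_fast.shape == (K, T, nu)
```

Since `x = z @ L.T + mu` with `z ~ N(0, I)` has covariance `L L^T = sigma`, the two paths are distributionally equivalent while the fast path avoids per-call distribution construction.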

CUDA benchmark results (RTX 4070)

| Config | Before | After | Speedup |
|---|---|---|---|
| MPPI K=500 T=15 | 1.38 ms | 0.92 ms | 1.50x |
| MPPI K=5000 T=15 | 3.78 ms | 2.02 ms | 1.87x |
| MPPI 20-step loop | 27.5 ms | 18.5 ms | 1.49x |

Notes on torch.compile

The mppi.compile() API is provided for users with large batch sizes and compile-friendly dynamics/cost functions, but the gains are situational. At K=5000 compiled mode matches the eager speedup (~1.95x vs original), but at typical K=500 it actually adds overhead (0.88x) and hurts in multi-step loops (0.81x). This is because the sequential horizon loop (state[t+1] depends on state[t]) prevents meaningful graph fusion — compile just adds tracing and guard-check overhead. The bulk of the speedup comes from the eager-mode optimizations listed above.

New features

  • mppi.compile(**kwargs) — compiles dynamics/cost with torch.compile
  • MPPI_Batched — runs N environments with shared noise and a single (N*K, nx) dynamics call per timestep
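The batching idea behind `MPPI_Batched` can be sketched like this (names and toy dynamics are illustrative, not the class's actual internals): N environments each carry K samples, and every timestep flattens them into one `(N*K, nx)` batch so the dynamics function is called once instead of N times.

```python
import torch

N, K, nx, nu = 4, 500, 3, 2  # N envs, K samples each, toy state/action dims

def dynamics(state, action):
    # Toy batched dynamics: any (B, nx), (B, nu) -> (B, nx) function works.
    return state + 0.1 * torch.nn.functional.pad(action, (0, nx - nu))

states = torch.randn(N, K, nx)
actions = torch.randn(N, K, nu)

# One flattened dynamics call per timestep amortizes kernel launch overhead.
next_states = dynamics(states.reshape(N * K, nx),
                       actions.reshape(N * K, nu)).reshape(N, K, nx)
assert next_states.shape == (N, K, nx)
```

On GPU this matters because launching one kernel over `N*K` rows is much cheaper than N separate launches over K rows each.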

Test plan

  • 75 unit tests passing on CPU and CUDA
  • Solution quality regression tests verify controllers still reach goals within bounds
  • Benchmarks track timing + solution quality (accumulated cost, final distance, control smoothness)
  • Verify downstream users are unaffected (no API changes to existing classes)

🤖 Generated with Claude Code

LemonPi and others added 7 commits March 10, 2026 14:05
…efactor

- tests/test_mppi.py: 64 tests covering MPPI/SMPPI/KMPPI correctness,
  feature coverage, edge cases, and solution quality regression guards
- tests/benchmark_mppi.py: timing and solution quality benchmarks across
  controllers, sample counts, horizons, and feature variations
- CLAUDE.md: project documentation with architecture, optimization plan,
  and baseline performance/quality numbers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…g, cache KMPPI kernel matrix

Phase 1-2 optimizations:
- Replace MultivariateNormal.rsample with direct randn @ cholesky + mu
- Replace torch.inverse with torch.linalg.inv
- Pre-allocate states/actions tensors instead of list + stack
- Use .expand() instead of .repeat() where possible
- Resolve step_dependency branching at init time
- Pre-compute variance discount factors for M>1
- Cache KMPPI Ktktk kernel matrix (constant across calls)
- Migrate functorch.vmap to torch.vmap with functools.partial for KMPPI
- Use torch.linalg.solve instead of inverse @ matmul in KMPPI interpolation
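The `torch.linalg.solve` swap in the last bullet can be illustrated with a small system (the SPD matrix here is a stand-in for the KMPPI kernel matrix): solving directly avoids forming an explicit inverse and is better-conditioned.

```python
import torch

# Stand-in for a kernel Gram matrix: symmetric positive definite.
A = torch.randn(8, 8)
A = A @ A.T + 8 * torch.eye(8)
B = torch.randn(8, 3)

x_inv = torch.linalg.inv(A) @ B      # explicit inverse then matmul
x_solve = torch.linalg.solve(A, B)   # single factorize-and-solve

assert torch.allclose(x_inv, x_solve, atol=1e-4)
```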

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…torch.compile support

Phase 3 optimizations:
- Call dynamics/cost functions directly with pre-flattened (M*K, nx) tensors,
  bypassing handle_batch_input decorator overhead in the rollout loop
- Resolve _bound_action to always use torch.clamp with ±inf defaults (no branch)
- Resolve terminal_state_cost to zero-returning lambda when not provided
- Resolve SMPPI action_min/action_max with ±inf defaults
- Add compile() method that wraps dynamics/cost with torch.compile
- Add compile tests for MPPI and KMPPI
- Add compiled vs eager benchmarks

~10% eager speedup on both CPU and CUDA across all controllers.
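The branch-free `_bound_action` resolution can be sketched as follows (variable names are illustrative): unset bounds are replaced with ±inf at init, so the hot path is always a single `torch.clamp` with no `if` on whether bounds exist.

```python
import torch

# Resolve once at init: None bounds become +/-inf so clamp is a no-op there.
u_min, u_max = None, torch.tensor([1.0, 0.5])
lo = torch.tensor(float('-inf')) if u_min is None else u_min
hi = torch.tensor(float('inf')) if u_max is None else u_max

def bound_action(u):
    return torch.clamp(u, min=lo, max=hi)  # branch-free hot path

u = torch.tensor([[2.0, -3.0]])
# First dim clamped to 1.0 by u_max; second dim unbounded below, unchanged.
assert torch.equal(bound_action(u), torch.tensor([[1.0, -3.0]]))
```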

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…azy storage

- Use torch.einsum for weighted perturbation sum (avoids broadcasting K×T×nu)
- Detect diagonal covariance at init for element-wise sampling and action cost
- Split _compute_rollout_costs into single/multi paths; M=1 skips repeat/variance
- Lazy state/action storage: skip allocation when no terminal_state_cost
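The einsum weighting in the first bullet can be sketched like this: contracting the sample weights against the perturbations in one `einsum` expression gives the same result as the broadcast-and-sum form without explicitly building the weighted `(K, T, nu)` intermediate in user code.

```python
import torch

K, T, nu = 500, 15, 2
omega = torch.softmax(torch.randn(K), dim=0)  # importance weights, sum to 1
noise = torch.randn(K, T, nu)                 # per-sample perturbations

# Broadcast form materializes a full (K, T, nu) product before reducing.
broadcast = (omega.view(K, 1, 1) * noise).sum(dim=0)

# einsum expresses the contraction over K directly.
fused = torch.einsum('k,ktu->tu', omega, noise)

assert torch.allclose(broadcast, fused, atol=1e-5)
```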

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Concatenates N×K states into one (N*K, nx) batch per timestep, amortizing
kernel launch overhead on GPU. Includes diagonal sigma, einsum weighting,
and compile() support. Adds 6 tests for correctness, bounds, and compile.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Runs test_mppi.py and test_batch_wrapper.py on Python 3.9 and 3.12.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@LemonPi LemonPi merged commit 4f171f8 into master Mar 12, 2026
2 checks passed