Optimize MPPI hot path for GPU performance and add torch.compile support #9
Merged
Conversation
…efactor

- tests/test_mppi.py: 64 tests covering MPPI/SMPPI/KMPPI correctness, feature coverage, edge cases, and solution quality regression guards
- tests/benchmark_mppi.py: timing and solution quality benchmarks across controllers, sample counts, horizons, and feature variations
- CLAUDE.md: project documentation with architecture, optimization plan, and baseline performance/quality numbers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…g, cache KMPPI kernel matrix

Phase 1-2 optimizations:

- Replace `MultivariateNormal.rsample` with direct `randn @ cholesky + mu`
- Replace `torch.inverse` with `torch.linalg.inv`
- Pre-allocate states/actions tensors instead of list + stack
- Use `.expand()` instead of `.repeat()` where possible
- Resolve step_dependency branching at init time
- Pre-compute variance discount factors for M>1
- Cache KMPPI `Ktktk` kernel matrix (constant across calls)
- Migrate `functorch.vmap` to `torch.vmap` with `functools.partial` for KMPPI
- Use `torch.linalg.solve` instead of `inverse @ matmul` in KMPPI interpolation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
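The first bullet above swaps distribution machinery for a direct Cholesky transform. A minimal standalone sketch of the equivalence (sizes and the toy `mu`/`sigma` here are illustrative, not from the PR): if `z ~ N(0, I)`, then `z @ L.T + mu ~ N(mu, L L.T) = N(mu, sigma)`.

```python
import torch

torch.manual_seed(0)

nu = 3        # control dimension (illustrative)
K, T = 8, 5   # samples, horizon (illustrative)

mu = torch.zeros(nu)
sigma = torch.diag(torch.tensor([0.5, 1.0, 2.0]))  # toy covariance

# Slow path: rebuild distribution machinery on every sample call
dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
noise_slow = dist.rsample((K, T))

# Fast path: standard normals mapped through the Cholesky factor
L = torch.linalg.cholesky(sigma)
z = torch.randn(K, T, nu)
noise_fast = z @ L.T + mu

assert noise_fast.shape == (K, T, nu)
```

Both paths produce samples with the same distribution; the fast path simply avoids re-entering `torch.distributions` in the hot loop.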
…torch.compile support

Phase 3 optimizations:

- Call dynamics/cost functions directly with pre-flattened (M*K, nx) tensors, bypassing `handle_batch_input` decorator overhead in the rollout loop
- Resolve `_bound_action` to always use `torch.clamp` with ±inf defaults (no branch)
- Resolve `terminal_state_cost` to a zero-returning lambda when not provided
- Resolve SMPPI `action_min`/`action_max` with ±inf defaults
- Add `compile()` method that wraps dynamics/cost with `torch.compile`
- Add compile tests for MPPI and KMPPI
- Add compiled vs eager benchmarks

~10% eager speedup on both CPU and CUDA across all controllers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
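The pre-flattening and branchless-clamp ideas from this commit can be sketched in isolation (the toy `dynamics` and all sizes below are illustrative, not the PR's code):

```python
import torch

M, K, T, nx, nu = 2, 4, 3, 5, 2  # illustrative sizes

def dynamics(state, action):
    # toy linear dynamics; expects plain 2-D (batch, nx) input
    return state + 0.1 * action.sum(dim=1, keepdim=True)

state = torch.randn(M, K, nx)
actions = torch.randn(M, K, T, nu)

# Hot path: flatten (M, K, ...) -> (M*K, ...) once and call dynamics
# directly, instead of reshaping through a decorator on every call.
flat_state = state.reshape(M * K, nx)
for t in range(T):
    flat_state = dynamics(flat_state, actions[:, :, t].reshape(M * K, nu))
state = flat_state.reshape(M, K, nx)

# Branchless bounding: with ±inf defaults, torch.clamp is a no-op when
# no limits were given, so the loop needs no `if u_min is not None` branch.
u_min = torch.full((nu,), -float('inf'))
u_max = torch.full((nu,), float('inf'))
bounded = torch.clamp(actions, u_min, u_max)
assert torch.equal(bounded, actions)
```

Resolving these branches once at init keeps the per-step loop a straight line of tensor ops, which is also what makes it friendlier to `torch.compile`.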
…azy storage

- Use `torch.einsum` for weighted perturbation sum (avoids broadcasting K×T×nu)
- Detect diagonal covariance at init for element-wise sampling and action cost
- Split `_compute_rollout_costs` into single/multi paths; M=1 skips repeat/variance
- Lazy state/action storage: skip allocation when no `terminal_state_cost`

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
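The einsum point above can be shown with a small equivalence check (sizes illustrative): contracting over the sample dimension K directly avoids building the explicit broadcasted product before reducing.

```python
import torch

K, T, nu = 512, 20, 3  # samples, horizon, control dim (illustrative)
omega = torch.softmax(-torch.rand(K), dim=0)  # importance weights, sum to 1
noise = torch.randn(K, T, nu)                 # per-sample perturbations

# Broadcast version: builds a (K, T, nu) temporary, then reduces over K
delta_broadcast = (omega.view(K, 1, 1) * noise).sum(dim=0)

# einsum version: contracts over K in one fused op, no extra temporary
delta_einsum = torch.einsum('k,ktu->tu', omega, noise)

assert torch.allclose(delta_broadcast, delta_einsum, atol=1e-6)
```

Both yield the weighted mean perturbation of shape (T, nu); the einsum form skips materializing the intermediate tensor.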
Concatenates N×K states into one (N*K, nx) batch per timestep, amortizing kernel launch overhead on GPU. Includes diagonal sigma, einsum weighting, and `compile()` support. Adds 6 tests for correctness, bounds, and compile.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
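The fusing trick described in this commit reduces to a reshape around a single dynamics call per timestep. A minimal sketch (the toy `dynamics` and sizes are illustrative, not the `MPPI_Batched` implementation):

```python
import torch

N, K, nx, nu = 4, 256, 6, 2  # environments, samples per env (illustrative)

def dynamics(state, action):
    # toy dynamics; one call handles the whole (N*K, nx) batch at once
    return state + 0.05 * torch.tanh(action @ torch.ones(nu, nx))

states = torch.randn(N, K, nx)   # per-environment rollout states
actions = torch.randn(N, K, nu)

# One dynamics call per timestep instead of N separate calls: fold the
# environment and sample dimensions together, step, then unfold.
next_states = dynamics(states.reshape(N * K, nx),
                       actions.reshape(N * K, nu)).reshape(N, K, nx)
assert next_states.shape == (N, K, nx)
```

On GPU the per-call kernel launch cost is roughly fixed, so one (N*K)-sized launch amortizes what would otherwise be N launches per timestep.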
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Runs test_mppi.py and test_batch_wrapper.py on Python 3.9 and 3.12.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- `torch.compile` support via `mppi.compile()` (beneficial at large K)
- `MPPI_Batched` class for running N parallel environments in a single dynamics call

Optimizations
- Pre-allocated tensors instead of list + `torch.stack` in rollout loop
- Direct `randn @ cholesky + mu` sampling instead of `MultivariateNormal.rsample`
- Skip `.repeat()`, variance computation, and M dimension when `rollout_samples=1` (default)
- Lazy state/action storage when no `terminal_state_cost`
- `torch.einsum` for weighted perturbation sum (avoids materializing K×T×nu broadcast)
- Bypass `handle_batch_input` in hot path — pre-flatten to (M*K, nx) and call dynamics/cost directly
- KMPPI `Ktktk` kernel matrix computed once at init, not every call
- `functorch.vmap` → `torch.vmap`, `torch.inverse` → `torch.linalg.inv`

CUDA benchmark results (RTX 4070)
Notes on `torch.compile`

The `mppi.compile()` API is provided for users with large batch sizes and compile-friendly dynamics/cost functions, but the gains are situational. At K=5000 compiled mode matches the eager speedup (~1.95x vs original), but at typical K=500 it actually adds overhead (0.88x) and hurts in multi-step loops (0.81x). This is because the sequential horizon loop (state[t+1] depends on state[t]) prevents meaningful graph fusion — compile just adds tracing and guard-check overhead. The bulk of the speedup comes from the eager-mode optimizations listed above.
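The wrapping pattern behind `compile()` can be sketched standalone. The toy `dynamics`/`running_cost` below are illustrative, and `backend="eager"` is used only so the sketch runs anywhere without an inductor toolchain; in practice you would use the default backend.

```python
import torch

def dynamics(state, action):
    # toy per-sample dynamics over a (batch, nx) tensor
    return state + 0.1 * action

def running_cost(state, action):
    # toy quadratic cost, one scalar per batch row
    return (state ** 2).sum(dim=-1) + 0.01 * (action ** 2).sum(dim=-1)

# compile() in the PR wraps the user's dynamics/cost with torch.compile;
# the same effect, standalone ("eager" backend keeps this runnable anywhere):
compiled_dynamics = torch.compile(dynamics, backend="eager")
compiled_cost = torch.compile(running_cost, backend="eager")

state = torch.randn(500, 4)
action = torch.randn(500, 4)
out = compiled_dynamics(state, action)
assert torch.allclose(out, dynamics(state, action))
```

As the notes above explain, the win is situational: compiling the per-step functions helps only when the batch is large enough that fused kernels outweigh tracing and guard-check overhead.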
New features

- `mppi.compile(**kwargs)` — compiles dynamics/cost with `torch.compile`
- `MPPI_Batched` — runs N environments with shared noise and a single (N*K, nx) dynamics call per timestep

Test plan
🤖 Generated with Claude Code