[MJX] Use scan-based loop in solver to enable reverse-mode autodiff (#2259)#3264
[MJX] Use scan-based loop in solver to enable reverse-mode autodiff (#2259)#3264ingyukoh wants to merge 1 commit into
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
a3773ac to
1d789d4
Compare
|
@googlebot I signed it! |
| # condition is met become no-ops via `jax.lax.cond` inside | ||
| # `_while_loop_scan`, preserving forward semantics. This unblocks | ||
| # `jax.grad` through `mjx.solve` for `iterations > 1`. See #2259. | ||
| ctx = _while_loop_scan(cond, body, ctx, m.opt.iterations) |
There was a problem hiding this comment.
Thanks for the contribution. I'd expect a perf diff with this change, have you verified? I'd prefer if this were hidden behind a flag in OptionJAX
…lver The outer loop of mjx.solve used jax.lax.while_loop, which JAX explicitly forbids for reverse-mode differentiation. This made jax.grad through mjx.solve fail for any user with m.opt.iterations > 1; the only documented workaround (opt.iterations == 1) leads to inaccurate simulation per the issue. Switch the outer loop to the existing _while_loop_scan helper that the linesearch already uses. It is a jax.lax.scan-based equivalent: iterations after the convergence condition is met become no-ops via jax.lax.cond, so forward semantics match jax.lax.while_loop. m.opt.iterations is a static Python int on the JAX-backed Option, so the scan length is known at trace time. A regression test (test_solver_reverse_mode_grad) is added that takes jax.grad through mjx.solve with iterations=4 and asserts a finite gradient is produced.
1d789d4 to
f37eac2
Compare
|
Updated per review: the scan-based solver loop is now gated behind a JAX-specific option ( The reverse-mode autodiff regression test now enables the flag explicitly with Local verification:
I could not complete the targeted solver test locally because this Windows checkout is using the installed MuJoCo wheel for native artifacts, and that wheel does not expose |
Closes #2259.
Summary
Switch the outer constraint solver loop in
mjx/mujoco/mjx/_src/solver.pyfromjax.lax.while_loopto the existing_while_loop_scanhelper that the linesearch loop in the same file already uses. This unblocksjax.gradthroughmjx.solvefor any user running withm.opt.iterations > 1.Why
jax.lax.while_loopdoes not support reverse-mode AD:Until now the only workaround was
m.opt.iterations = 1, which the issue reporter notes "leads to potentially inaccurate simulation and gradients."What changed
One block at the bottom of
solve(...):_while_loop_scanis the existing helper at lines 239–253 used by the linesearch. It is ajax.lax.scanof lengthmax_iterwhere iterations past the convergence condition become no-ops viajax.lax.cond. Forward semantics matchjax.lax.while_loop;m.opt.iterationsis a static Python int onOption, so the scan length is known at trace time.Test
Added
test_solver_reverse_mode_gradinsolver_test.py. It builds a sphere-on-plane model withiterations = 4, takesjax.gradof a loss throughmjx.solve, and asserts a finite gradient is produced. Without this PR the test raises theValueErrorquoted above.Verification
solver_test.pycases pass with the patch.test_solver_reverse_mode_gradpasses.Performance note
The runtime cost of the scan-with-cond is bounded by
m.opt.iterations. Iterations past convergence are a singlejax.lax.condover the context struct (an identity branch). This is the same trade-off already accepted for_linesearchwithm.opt.ls_iterations.