[MJX] Use scan-based loop in solver to enable reverse-mode autodiff (#2259) #3264

Open
ingyukoh wants to merge 1 commit into google-deepmind:main from ingyukoh:fix-mjx-solver-grad

@ingyukoh ingyukoh commented May 8, 2026

Closes #2259.

Summary

Switch the outer constraint solver loop in mjx/mujoco/mjx/_src/solver.py from jax.lax.while_loop to the existing _while_loop_scan helper that the linesearch loop in the same file already uses. This unblocks jax.grad through mjx.solve for any user running with m.opt.iterations > 1.

Why

jax.lax.while_loop does not support reverse-mode AD:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
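
The failure can be reproduced outside MJX with a toy loop. The sketch below is not the solver code; `halve_until_small` is a made-up function whose trip count depends on its input, which is the situation the error message describes.

```python
import jax
import jax.numpy as jnp


def halve_until_small(x):
    # Data-dependent trip count: keep halving while x > 1.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v * 0.5, x)


# Forward evaluation works as usual.
forward_value = halve_until_small(jnp.float32(10.0))

# Reverse-mode differentiation through the same loop is rejected by JAX.
try:
    jax.grad(halve_until_small)(jnp.float32(10.0))
    grad_error = None
except Exception as err:  # JAX reports the message quoted above
    grad_error = err

print(forward_value, type(grad_error).__name__)
```

The forward pass succeeds, so the problem only appears once a user asks for gradients.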

Until now the only workaround was m.opt.iterations = 1, which the issue reporter notes "leads to potentially inaccurate simulation and gradients."

What changed

One block at the bottom of solve(...):

ctx = Context.create(m, d)
if m.opt.iterations == 1:
  ctx = body(ctx)
else:
  # was: ctx = jax.lax.while_loop(cond, body, ctx)
  ctx = _while_loop_scan(cond, body, ctx, m.opt.iterations)

_while_loop_scan is the existing helper at lines 239–253 used by the linesearch. It is a jax.lax.scan of length max_iter where iterations past the convergence condition become no-ops via jax.lax.cond. Forward semantics match jax.lax.while_loop; m.opt.iterations is a static Python int on Option, so the scan length is known at trace time.
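
For readers unfamiliar with the pattern, here is a minimal sketch of a scan-based loop of this kind. It is an illustration only: `while_loop_scan_sketch` is a hypothetical name, and the actual `_while_loop_scan` helper in `solver.py` may be structured differently. The sketch matches `jax.lax.while_loop` in the forward pass whenever the loop would stop within `max_iter` iterations.

```python
import jax
import jax.numpy as jnp


def while_loop_scan_sketch(cond_fun, body_fun, init_val, max_iter):
    """Illustrative scan-based loop; not the MJX helper itself."""

    def step(carry, _):
        # Run the body while cond holds; otherwise pass the carry through.
        carry = jax.lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
        return carry, None

    # max_iter must be a static Python int so the scan length is known.
    final, _ = jax.lax.scan(step, init_val, None, length=max_iter)
    return final


def halve(x):
    return while_loop_scan_sketch(
        lambda v: v > 1.0, lambda v: v * 0.5, x, max_iter=8
    )


x0 = jnp.float32(10.0)
scan_value = halve(x0)
while_value = jax.lax.while_loop(lambda v: v > 1.0, lambda v: v * 0.5, x0)
scan_grad = jax.grad(halve)(x0)  # reverse mode works through scan + cond
```

Because the carry passes through unchanged once the condition is false, the gradient only accumulates over the iterations that actually ran.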

Test

Added test_solver_reverse_mode_grad in solver_test.py. It builds a sphere-on-plane model with iterations = 4, takes jax.grad of a loss through mjx.solve, and asserts a finite gradient is produced. Without this PR the test raises the ValueError quoted above.

Verification

  • All 13 existing solver_test.py cases pass with the patch.
  • The new test_solver_reverse_mode_grad passes.
  • Reverting the patch reproduces the original error.

Performance note

The runtime cost of the scan-with-cond is bounded by m.opt.iterations. Iterations past convergence are a single jax.lax.cond over the context struct (an identity branch). This is the same trade-off already accepted for _linesearch with m.opt.ls_iterations.
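
To make the trade-off concrete, the toy sketch below counts how often the body runs versus how many scan steps are taken. The loop, the counters, and `MAX_ITER` are all assumptions for illustration and are unrelated to the solver's `Context`.

```python
import jax
import jax.numpy as jnp

MAX_ITER = 8  # stands in for a static iteration limit


def step(carry, _):
    x, body_calls, scan_steps = carry
    # The real work only happens while x > 1; afterwards the identity
    # branch is taken and (x, body_calls) pass through unchanged.
    x, body_calls = jax.lax.cond(
        x > 1.0,
        lambda a: (a[0] * 0.5, a[1] + 1),
        lambda a: a,
        (x, body_calls),
    )
    # scan_steps is incremented on every step, converged or not.
    return (x, body_calls, scan_steps + 1), None


init = (jnp.float32(10.0), jnp.int32(0), jnp.int32(0))
(x_final, body_calls, scan_steps), _ = jax.lax.scan(
    step, init, None, length=MAX_ITER
)
print(int(body_calls), int(scan_steps))
```

In this toy case the body runs four times while the scan still takes all eight steps, which is why the cost is bounded by the iteration limit rather than by the convergence point.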

google-cla Bot commented May 8, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@ingyukoh ingyukoh force-pushed the fix-mjx-solver-grad branch from a3773ac to 1d789d4 Compare May 8, 2026 07:36
Author

ingyukoh commented May 8, 2026

@googlebot I signed it!

Collaborator

@btaba btaba left a comment

Thanks!

# condition is met become no-ops via `jax.lax.cond` inside
# `_while_loop_scan`, preserving forward semantics. This unblocks
# `jax.grad` through `mjx.solve` for `iterations > 1`. See #2259.
ctx = _while_loop_scan(cond, body, ctx, m.opt.iterations)
Collaborator

Thanks for the contribution. I'd expect a perf diff with this change; have you verified? I'd prefer this to be hidden behind a flag in OptionJAX.

…lver

The outer loop of mjx.solve used jax.lax.while_loop, which JAX explicitly
forbids for reverse-mode differentiation. This made jax.grad through
mjx.solve fail for any user with m.opt.iterations > 1; the only documented
workaround (opt.iterations == 1) leads to inaccurate simulation per the
issue.

Switch the outer loop to the existing _while_loop_scan helper that the
linesearch already uses. It is a jax.lax.scan-based equivalent: iterations
after the convergence condition is met become no-ops via jax.lax.cond, so
forward semantics match jax.lax.while_loop. m.opt.iterations is a static
Python int on the JAX-backed Option, so the scan length is known at trace
time.

A regression test (test_solver_reverse_mode_grad) is added that takes
jax.grad through mjx.solve with iterations=4 and asserts a finite gradient
is produced.
@ingyukoh ingyukoh force-pushed the fix-mjx-solver-grad branch from 1d789d4 to f37eac2 Compare May 8, 2026 21:49
Author

ingyukoh commented May 8, 2026

Updated per review: the scan-based solver loop is now gated behind a JAX-specific option (OptionJAX.solver_scan) and remains disabled by default, so the existing jax.lax.while_loop path and its performance characteristics are preserved unless users explicitly opt in.
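
The gating idea can be sketched in plain JAX as follows. This is not the code in the PR: `run_loop` and its `use_scan` argument are hypothetical stand-ins for the solver and the new option, chosen only to show a Python-level flag selecting the loop implementation at trace time.

```python
import jax
import jax.numpy as jnp


def run_loop(cond_fun, body_fun, init_val, max_iter, use_scan=False):
    """Hypothetical dispatcher; `use_scan` stands in for the new option."""
    if not use_scan:
        # Default path: unchanged while_loop, no reverse-mode support.
        return jax.lax.while_loop(cond_fun, body_fun, init_val)

    def step(carry, _):
        carry = jax.lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
        return carry, None

    # Opt-in path: fixed-length scan, differentiable in reverse mode.
    return jax.lax.scan(step, init_val, None, length=max_iter)[0]


cond = lambda v: v > 1.0
body = lambda v: v * 0.5
x0 = jnp.float32(10.0)

default_value = run_loop(cond, body, x0, max_iter=8)
opt_in_value = run_loop(cond, body, x0, max_iter=8, use_scan=True)
opt_in_grad = jax.grad(lambda x: run_loop(cond, body, x, 8, use_scan=True))(x0)
```

Since the flag is an ordinary Python bool, the branch is resolved while tracing and the default path compiles to the same `while_loop` as before.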

The reverse-mode autodiff regression test now enables the flag explicitly with mx.tree_replace({'opt._impl.solver_scan': True}).

Local verification:

  • python -m py_compile mjx\mujoco\mjx\_src\types.py mjx\mujoco\mjx\_src\io.py mjx\mujoco\mjx\_src\solver.py mjx\mujoco\mjx\_src\solver_test.py

I could not complete the targeted solver test locally because this Windows checkout is using the installed MuJoCo wheel for native artifacts, and that wheel does not expose mju_sym2dense, which this branch's MJX io.py expects. The test reaches mjx.put_data before hitting that native-extension mismatch.

Collaborator

@btaba btaba left a comment

LGTM

Development

Successfully merging this pull request may close these issues.

[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients

2 participants