The feature, motivation and pitch
Problem
The solver is implemented with jax.lax.while_loop, which blocks reverse-mode gradient computation through the environment step during gradient-based trajectory optimization. The failure occurs whenever the solver runs with iterations > 1.
Error raised when calling the jax.jit-compiled grad function:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
The current workaround, setting opt.iteration=1, avoids the error but can yield inaccurate simulation results and gradients.
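The failure can be reproduced outside the solver with a few lines of JAX. The sketch below is only an illustration: `solve_while` is a hypothetical toy stand-in (Newton's method for a square root), not the solver's actual code, and `max_iters` and `tol` are made-up parameters. It shows the forward pass running while reverse-mode differentiation of the same function raises the error quoted above.

```python
import jax
import jax.numpy as jnp
from jax import lax


def solve_while(a, max_iters=4, tol=1e-6):
    """Hypothetical toy solver: Newton's method for sqrt(a)."""

    def cond(carry):
        x, i = carry
        # The stopping test depends on runtime values, so the trip count is dynamic.
        return (i < max_iters) & (jnp.abs(x * x - a) > tol)

    def body(carry):
        x, i = carry
        return 0.5 * (x + a / x), i + 1

    x, _ = lax.while_loop(cond, body, (a, 0))
    return x


# The forward pass works as expected.
forward = solve_while(jnp.array(2.0))

# Reverse-mode differentiation through the while_loop fails.
error_message = None
try:
    jax.jit(jax.grad(solve_while))(jnp.array(2.0))
except ValueError as err:
    error_message = str(err)
```
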
Proposed Solution
Add an option to run the solver for a fixed number of iterations (e.g., 4), implemented with either lax.scan or lax.fori_loop with static bounds, so that the solver is compatible with reverse-mode differentiation.
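A minimal sketch of what a fixed-iteration variant could look like, again using a hypothetical toy Newton iteration rather than the real solver. The names `solve_fixed_scan`, `solve_fixed_fori`, and `num_iters` are assumptions for illustration only. Both versions run a statically known number of steps, which is what lets `jax.grad` differentiate through them.

```python
import jax
import jax.numpy as jnp
from jax import lax


def solve_fixed_scan(a, num_iters=4):
    """Hypothetical fixed-iteration solver built on lax.scan."""

    def step(x, _):
        # One Newton update for sqrt(a); no per-step output is collected.
        return 0.5 * (x + a / x), None

    x, _ = lax.scan(step, a, xs=None, length=num_iters)
    return x


def solve_fixed_fori(a, num_iters=4):
    """Same iteration with lax.fori_loop; num_iters must be a Python int."""
    return lax.fori_loop(0, num_iters, lambda _, x: 0.5 * (x + a / x), a)


a = jnp.array(2.0)
grad_scan = jax.jit(jax.grad(solve_fixed_scan))(a)
grad_fori = jax.jit(jax.grad(solve_fixed_fori))(a)
```

The trade-off is that early termination on a tolerance is lost: every call pays for all `num_iters` steps, and accuracy depends on choosing a count large enough for the problem at hand.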
Alternatives
No response
Additional context
No response