Skip to content

Commit 9e505f1

Browse files
Skip J reuse logic entirely for non-W-methods (IIP + OOP)
Non-W-methods (strict Rosenbrock like Rodas5P) always recompute J and W. Bypass the reuse decision, newJW hint, and jac_reuse bookkeeping entirely for these methods. The jac_reuse mutations (pending_dtgamma, last_step_iter) during the forward pass interfere with Enzyme + Krylov solvers, causing Rodas5P to lose 5th-order convergence (observed order 1.74). IIP path: use original calc_W! without newJW hint for non-W-methods. OOP path: early return to standard calc_tderivative + calc_W path. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 97e05de commit 9e505f1

1 file changed

Lines changed: 23 additions & 18 deletions

File tree

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -830,26 +830,29 @@ end
830830

831831
function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
832832
nlsolver = nothing
833+
alg = OrdinaryDiffEqCore.unwrap_alg(integrator, true)
833834
# we need to skip calculating `J` and `W` when a step is repeated
834835
new_jac = new_W = false
835836
if !repeat_step
836-
# For W-methods, use reuse logic; for strict Rosenbrock, always recompute
837-
newJW = _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
838-
new_jac,
839-
new_W = calc_W!(
840-
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, newJW
841-
)
842-
# Record pending dtgamma only when J was freshly computed; it will be
843-
# committed as last_dtgamma when the step is accepted (checked in
844-
# _rosenbrock_jac_reuse_decision). This tracks the dtgamma at the
845-
# last J computation for the gamma ratio heuristic.
846-
jac_reuse = get_jac_reuse(cache)
847-
if jac_reuse !== nothing
848-
jac_reuse.last_step_iter = integrator.iter
849-
if new_jac
850-
jac_reuse.pending_dtgamma = _jac_reuse_value(dtgamma)
851-
jac_reuse.last_u_length = length(integrator.u)
837+
if isWmethod(alg)
838+
# W-methods: use reuse logic to skip expensive J recomputations
839+
newJW = _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
840+
new_jac, new_W = calc_W!(
841+
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, newJW
842+
)
843+
jac_reuse = get_jac_reuse(cache)
844+
if jac_reuse !== nothing
845+
jac_reuse.last_step_iter = integrator.iter
846+
if new_jac
847+
jac_reuse.pending_dtgamma = _jac_reuse_value(dtgamma)
848+
jac_reuse.last_u_length = length(integrator.u)
849+
end
852850
end
851+
else
852+
# Strict Rosenbrock: use original calc_W! path (no reuse, no bookkeeping)
853+
new_jac, new_W = calc_W!(
854+
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step
855+
)
853856
end
854857
end
855858
# If the Jacobian is not updated, we won't have to update ∂/∂t either.
@@ -867,8 +870,10 @@ system matrix. Supports Jacobian reuse for W-methods via `jac_reuse` in the cach
867870
function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step)
868871
jac_reuse = get_jac_reuse(cache)
869872

870-
# If no reuse support or repeat step, use standard path
871-
if repeat_step || jac_reuse === nothing
873+
alg = OrdinaryDiffEqCore.unwrap_alg(integrator, true)
874+
# Non-W-methods always recompute; skip caching to avoid extra calc_J calls
875+
# that can corrupt shared AD state (e.g. Enzyme + Krylov).
876+
if repeat_step || jac_reuse === nothing || !isWmethod(alg)
872877
dT = calc_tderivative(integrator, cache)
873878
W = calc_W(integrator, cache, dtgamma, repeat_step)
874879
return dT, W

0 commit comments

Comments
 (0)