Skip to content

Commit b85c233

Browse files
Add Jacobian reuse for Rosenbrock-W methods
W-methods guarantee correctness with a stale Jacobian, but the existing code recomputes J every accepted step. This adds CVODE-inspired reuse logic that skips Jacobian recomputation when conditions allow it, reducing the dominant cost for large stiff systems. Reuse strategy (W-methods only): - Recompute J on first iteration, error test failure, callback modification, gamma ratio change > 30%, or every 50 accepted steps - W is always rebuilt since it depends on current dt - Strict Rosenbrock methods are unchanged (always recompute J) Changes: - Add JacReuseState struct and reuse decision logic to derivative_utils.jl - Add jac_reuse field to all Rosenbrock mutable caches (hand-written and macro-generated) - Wire reuse decision into calc_rosenbrock_differentiation! - Add comprehensive test suite for convergence, Jacobian savings, and benchmark accuracy Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e52f9ad commit b85c233

7 files changed

Lines changed: 346 additions & 20 deletions

File tree

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
4141
isnewton, _unwrap_val,
4242
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir,
4343
get_W, isfirstcall, isfirststage, isJcurrent,
44-
get_new_W_γdt_cutoff,
44+
get_new_W_γdt_cutoff, isWmethod,
4545
TryAgain, DIRK, COEFFICIENT_MULTISTEP, NORDSIECK_MULTISTEP, GLM,
4646
FastConvergence, Convergence, SlowConvergence,
4747
VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue, @SciMLMessage

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,91 @@
11
using SciMLOperators: StaticWOperator, WOperator
22

3+
"""
4+
JacReuseState{T}
5+
6+
Lightweight mutable state for tracking Jacobian reuse in Rosenbrock-W methods.
7+
W-methods guarantee correctness with a stale Jacobian, so we can skip expensive
8+
Jacobian recomputations when conditions allow it.
9+
10+
Fields:
11+
- `last_dtgamma`: The dtgamma value from the last Jacobian computation
12+
- `steps_since_jac`: Number of accepted steps since last Jacobian update
13+
- `max_jac_age`: Maximum number of accepted steps between Jacobian updates (default 50)
14+
"""
15+
mutable struct JacReuseState{T}
16+
last_dtgamma::T
17+
steps_since_jac::Int
18+
max_jac_age::Int
19+
end
20+
21+
JacReuseState(dtgamma::T) where {T} = JacReuseState{T}(dtgamma, 0, 50)
22+
23+
"""
24+
get_jac_reuse(cache)
25+
26+
Duck-typed accessor for the `jac_reuse` field. Returns `nothing` if the cache
27+
does not have a `jac_reuse` field.
28+
"""
29+
get_jac_reuse(cache) = hasproperty(cache, :jac_reuse) ? cache.jac_reuse : nothing
30+
31+
"""
32+
_rosenbrock_jac_reuse_decision(integrator, cache, dtgamma) -> (new_jac, new_W)
33+
34+
Decide whether to recompute the Jacobian and/or W matrix for Rosenbrock methods.
35+
For W-methods (where `isWmethod(alg) == true`), implements CVODE-inspired reuse:
36+
- Always recompute on first iteration
37+
- Recompute J on error test failure (EEst > 1)
38+
- Recompute when gamma ratio changes too much: |dtgamma/last_dtgamma - 1| > 0.3
39+
- Recompute every `max_jac_age` accepted steps (default 50)
40+
- Recompute when u_modified (callback modification)
41+
- W is always rebuilt (since W = J - M/(dt*gamma) depends on current dt)
42+
43+
For strict Rosenbrock methods, always returns (true, true).
44+
"""
45+
function _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
46+
alg = OrdinaryDiffEqCore.unwrap_alg(integrator, true)
47+
48+
# Non-W-methods always recompute
49+
if !isWmethod(alg)
50+
return (true, true)
51+
end
52+
53+
jac_reuse = get_jac_reuse(cache)
54+
# If no reuse state (e.g. OOP cache), always recompute
55+
if jac_reuse === nothing
56+
return (true, true)
57+
end
58+
59+
# First iteration: always compute
60+
if integrator.iter <= 1
61+
return (true, true)
62+
end
63+
64+
# Callback modification: recompute
65+
if integrator.u_modified
66+
return (true, true)
67+
end
68+
69+
# Error test failure: recompute
70+
if integrator.EEst > one(integrator.EEst)
71+
return (true, true)
72+
end
73+
74+
# Gamma ratio check
75+
last_dtg = jac_reuse.last_dtgamma
76+
if !iszero(last_dtg) && abs(dtgamma / last_dtg - 1) > 0.3
77+
return (true, true)
78+
end
79+
80+
# Step counter check
81+
if jac_reuse.steps_since_jac >= jac_reuse.max_jac_age
82+
return (true, true)
83+
end
84+
85+
# Reuse J, but always rebuild W (since dtgamma changes)
86+
return (false, true)
87+
end
88+
389
function calc_tderivative!(integrator, cache, dtd1, repeat_step)
490
return @inbounds begin
591
(; t, dt, uprev, u, f, p) = integrator
@@ -689,9 +775,21 @@ function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repe
689775
# we need to skip calculating `J` and `W` when a step is repeated
690776
new_jac = new_W = false
691777
if !repeat_step
778+
# For W-methods, use reuse logic; for strict Rosenbrock, always recompute
779+
newJW = _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
692780
new_jac, new_W = calc_W!(
693-
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step
781+
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, newJW
694782
)
783+
# Update reuse state after W computation
784+
jac_reuse = get_jac_reuse(cache)
785+
if jac_reuse !== nothing
786+
if new_jac
787+
jac_reuse.last_dtgamma = dtgamma
788+
jac_reuse.steps_since_jac = 0
789+
else
790+
jac_reuse.steps_since_jac += 1
791+
end
792+
end
695793
end
696794
# If the Jacobian is not updated, we won't have to update ∂/∂t either.
697795
calc_tderivative!(integrator, cache, dtd1, repeat_step || !new_jac)

lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ using OrdinaryDiffEqDifferentiation: TimeDerivativeWrapper, TimeGradientWrapper,
3333
build_jac_config, issuccess_W, jacobian2W!,
3434
resize_jac_config!, resize_grad_config!,
3535
calc_W, calc_rosenbrock_differentiation!, build_J_W,
36-
UJacobianWrapper, dolinsolve, WOperator, resize_J_W!
36+
UJacobianWrapper, dolinsolve, WOperator, resize_J_W!,
37+
JacReuseState
3738

3839
using Reexport
3940
@reexport using SciMLBase

lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function gen_cache_struct(tab::RosenbrockTableau,cachename::Symbol,constcachenam
176176
end
177177
end
178178
cacheexpr=quote
179-
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: GenericRosenbrockMutableCache
179+
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType,JRType} <: GenericRosenbrockMutableCache
180180
u::uType
181181
uprev::uType
182182
du::rateType
@@ -198,6 +198,7 @@ function gen_cache_struct(tab::RosenbrockTableau,cachename::Symbol,constcachenam
198198
linsolve::F
199199
jac_config::JCType
200200
grad_config::GCType
201+
jac_reuse::JRType
201202
end
202203
end
203204
constcacheexpr,cacheexpr
@@ -246,7 +247,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
246247
tf = TimeGradientWrapper(f,uprev,p)
247248
uf = UJacobianWrapper(f,t,p)
248249
linsolve_tmp = zero(rate_prototype)
249-
250+
250251
grad_config = build_grad_config(alg,f,tf,du1,t)
251252
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2)
252253
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
@@ -255,7 +256,8 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
255256
linsolve = init(linprob,alg.linsolve,alias = LinearAliasSpecifier(alias_A=true,alias_b=true),
256257
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
257258
Pr = Diagonal(_vec(weight)),
258-
verbose = verbose.linear_verbosity)
259+
verbose = verbose.linear_verbosity)
260+
jac_reuse = JacReuseState(zero(dt))
259261
$cachename($(valsyms...))
260262
end
261263
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414

1515
mutable struct RosenbrockCache{
1616
uType, rateType, tabType, uNoUnitsType, JType, WType, TabType,
17-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter,
17+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter, JRType,
1818
} <:
1919
RosenbrockMutableCache
2020
u::uType
@@ -46,6 +46,7 @@ mutable struct RosenbrockCache{
4646
step_limiter!::StepLimiter
4747
stage_limiter!::StageLimiter
4848
interp_order::Int
49+
jac_reuse::JRType
4950
end
5051
function full_cache(c::RosenbrockCache)
5152
return [
@@ -69,7 +70,7 @@ end
6970
@cache mutable struct Rosenbrock23Cache{
7071
uType, rateType, uNoUnitsType, JType, WType,
7172
TabType, TFType, UFType, F, JCType, GCType,
72-
RTolType, A, AV, StepLimiter, StageLimiter,
73+
RTolType, A, AV, StepLimiter, StageLimiter, JRType,
7374
} <: RosenbrockMutableCache
7475
u::uType
7576
uprev::uType
@@ -99,12 +100,13 @@ end
99100
algebraic_vars::AV
100101
step_limiter!::StepLimiter
101102
stage_limiter!::StageLimiter
103+
jac_reuse::JRType
102104
end
103105

104106
@cache mutable struct Rosenbrock32Cache{
105107
uType, rateType, uNoUnitsType, JType, WType,
106108
TabType, TFType, UFType, F, JCType, GCType,
107-
RTolType, A, AV, StepLimiter, StageLimiter,
109+
RTolType, A, AV, StepLimiter, StageLimiter, JRType,
108110
} <: RosenbrockMutableCache
109111
u::uType
110112
uprev::uType
@@ -134,6 +136,7 @@ end
134136
algebraic_vars::AV
135137
step_limiter!::StepLimiter
136138
stage_limiter!::StageLimiter
139+
jac_reuse::JRType
137140
end
138141

139142
function alg_cache(
@@ -185,12 +188,14 @@ function alg_cache(
185188
algebraic_vars = f.mass_matrix === I ? nothing :
186189
[all(iszero, x) for x in eachcol(f.mass_matrix)]
187190

191+
jac_reuse = JacReuseState(zero(dt))
192+
188193
return Rosenbrock23Cache(
189194
u, uprev, k₁, k₂, k₃, du1, du2, f₁,
190195
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
191196
linsolve_tmp,
192197
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
193-
alg.stage_limiter!
198+
alg.stage_limiter!, jac_reuse
194199
)
195200
end
196201

@@ -244,10 +249,13 @@ function alg_cache(
244249
algebraic_vars = f.mass_matrix === I ? nothing :
245250
[all(iszero, x) for x in eachcol(f.mass_matrix)]
246251

252+
jac_reuse = JacReuseState(zero(dt))
253+
247254
return Rosenbrock32Cache(
248255
u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
249256
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
250-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!
257+
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!,
258+
jac_reuse
251259
)
252260
end
253261

@@ -336,7 +344,7 @@ end
336344
@cache mutable struct Rosenbrock33Cache{
337345
uType, rateType, uNoUnitsType, JType, WType,
338346
TabType, TFType, UFType, F, JCType, GCType,
339-
RTolType, A, StepLimiter, StageLimiter,
347+
RTolType, A, StepLimiter, StageLimiter, JRType,
340348
} <: RosenbrockMutableCache
341349
u::uType
342350
uprev::uType
@@ -366,6 +374,7 @@ end
366374
alg::A
367375
step_limiter!::StepLimiter
368376
stage_limiter!::StageLimiter
377+
jac_reuse::JRType
369378
end
370379

371380
function alg_cache(
@@ -412,12 +421,14 @@ function alg_cache(
412421
verbose = verbose.linear_verbosity
413422
)
414423

424+
jac_reuse = JacReuseState(zero(dt))
425+
415426
return Rosenbrock33Cache(
416427
u, uprev, du, du1, du2, k1, k2, k3, k4,
417428
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
418429
linsolve_tmp,
419430
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
420-
alg.stage_limiter!
431+
alg.stage_limiter!, jac_reuse
421432
)
422433
end
423434

@@ -443,7 +454,7 @@ end
443454

444455
@cache mutable struct Rosenbrock34Cache{
445456
uType, rateType, uNoUnitsType, JType, WType,
446-
TabType, TFType, UFType, F, JCType, GCType, StepLimiter, StageLimiter,
457+
TabType, TFType, UFType, F, JCType, GCType, StepLimiter, StageLimiter, JRType,
447458
} <:
448459
RosenbrockMutableCache
449460
u::uType
@@ -472,6 +483,7 @@ end
472483
grad_config::GCType
473484
step_limiter!::StepLimiter
474485
stage_limiter!::StageLimiter
486+
jac_reuse::JRType
475487
end
476488

477489
function alg_cache(
@@ -520,12 +532,14 @@ function alg_cache(
520532
verbose = verbose.linear_verbosity
521533
)
522534

535+
jac_reuse = JacReuseState(zero(dt))
536+
523537
return Rosenbrock34Cache(
524538
u, uprev, du, du1, du2, k1, k2, k3, k4,
525539
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
526540
linsolve_tmp,
527541
linsolve, jac_config, grad_config, alg.step_limiter!,
528-
alg.stage_limiter!
542+
alg.stage_limiter!, jac_reuse
529543
)
530544
end
531545

@@ -611,7 +625,7 @@ end
611625

612626
@cache mutable struct Rodas23WCache{
613627
uType, rateType, uNoUnitsType, JType, WType, TabType,
614-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter,
628+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter, JRType,
615629
} <:
616630
RosenbrockMutableCache
617631
u::uType
@@ -646,11 +660,12 @@ end
646660
alg::A
647661
step_limiter!::StepLimiter
648662
stage_limiter!::StageLimiter
663+
jac_reuse::JRType
649664
end
650665

651666
@cache mutable struct Rodas3PCache{
652667
uType, rateType, uNoUnitsType, JType, WType, TabType,
653-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter,
668+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter, JRType,
654669
} <:
655670
RosenbrockMutableCache
656671
u::uType
@@ -685,6 +700,7 @@ end
685700
alg::A
686701
step_limiter!::StepLimiter
687702
stage_limiter!::StageLimiter
703+
jac_reuse::JRType
688704
end
689705

690706
function alg_cache(
@@ -737,11 +753,13 @@ function alg_cache(
737753
verbose = verbose.linear_verbosity
738754
)
739755

756+
jac_reuse = JacReuseState(zero(dt))
757+
740758
return Rodas23WCache(
741759
u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
742760
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
743761
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
744-
alg.stage_limiter!
762+
alg.stage_limiter!, jac_reuse
745763
)
746764
end
747765

@@ -795,11 +813,13 @@ function alg_cache(
795813
verbose = verbose.linear_verbosity
796814
)
797815

816+
jac_reuse = JacReuseState(zero(dt))
817+
798818
return Rodas3PCache(
799819
u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
800820
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
801821
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
802-
alg.stage_limiter!
822+
alg.stage_limiter!, jac_reuse
803823
)
804824
end
805825

@@ -934,12 +954,14 @@ function alg_cache(
934954
)
935955

936956

957+
jac_reuse = JacReuseState(zero(dt))
958+
937959
# Return the cache struct with vectors
938960
return RosenbrockCache(
939961
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
940962
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
941963
linsolve, jac_config, grad_config, reltol, alg,
942-
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1)
964+
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1), jac_reuse
943965
)
944966
end
945967

0 commit comments

Comments
 (0)