Skip to content

Commit 1aafe65

Browse files
Add Jacobian reuse for Rosenbrock-W methods
Implement CVODE-inspired Jacobian reuse heuristics for Rosenbrock-W methods (Rosenbrock23, Rosenbrock32, Rodas23W, ROS2S, ROS34PW1a, ROS34PW2, ROS34PW3, ROK4a). W-methods guarantee convergence order even with inexact Jacobians, allowing safe reuse across multiple steps. Key changes: - Add `isWmethod` trait to distinguish W-methods from strict Rosenbrock - Add `JacReuseState` struct tracking Jacobian age and step size ratios - Add `_rosenbrock_jac_reuse_decision` implementing reuse heuristics: gamma ratio threshold (30%), max age (50 steps), Newton convergence - Integrate reuse into both IIP and OOP `calc_rosenbrock_differentiation` - Deferred dtgamma commit: pending value only committed on step accept - Add comprehensive test suite (convergence, accuracy, J count reduction) Closes #1043 Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 30f97be commit 1aafe65

15 files changed

Lines changed: 955 additions & 201 deletions

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
2525
using DiffEqBase: TimeGradientWrapper,
2626
UJacobianWrapper, TimeDerivativeWrapper,
2727
UDerivativeWrapper
28-
import SciMLBase: SciMLBase, constructorof, @set, isinplace, has_jvp, unwrapped_f, DEIntegrator, ODEFunction, SplitFunction, DynamicalODEFunction, DAEFunction, islinear, remake, solve!, isconstant
28+
import SciMLBase: SciMLBase, constructorof, @set, isinplace, has_jvp, unwrapped_f,
29+
DEIntegrator, ODEFunction, SplitFunction, DynamicalODEFunction,
30+
DAEFunction, islinear, remake, solve!, isconstant
2931
using SciMLBase: @set, @reset
30-
import SciMLOperators: SciMLOperators, IdentityOperator, update_coefficients, update_coefficients!, MatrixOperator, AbstractSciMLOperator, ScalarOperator
31-
import SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm, ColoringProblem,
32+
import SciMLOperators: SciMLOperators, IdentityOperator, update_coefficients,
33+
update_coefficients!, MatrixOperator, AbstractSciMLOperator,
34+
ScalarOperator
35+
import SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
36+
ColoringProblem,
3237
ncolors, column_colors, coloring, sparsity_pattern
3338
import OrdinaryDiffEqCore
3439
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
@@ -41,10 +46,11 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
4146
isnewton, _unwrap_val,
4247
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir,
4348
get_W, isfirstcall, isfirststage, isJcurrent,
44-
get_new_W_γdt_cutoff,
49+
get_new_W_γdt_cutoff, isWmethod,
4550
TryAgain, DIRK, COEFFICIENT_MULTISTEP, NORDSIECK_MULTISTEP, GLM,
4651
FastConvergence, Convergence, SlowConvergence,
47-
VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue, @SciMLMessage
52+
VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue,
53+
@SciMLMessage
4854

4955
import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff,
5056
_get_fwd_tag
@@ -79,10 +85,18 @@ function get_nzval end
7985
function set_all_nzval! end
8086

8187
# Provide error messages if these are called without extension
82-
nonzeros(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
83-
spzeros(args...) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
84-
get_nzval(A) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
85-
set_all_nzval!(A, val) = error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
88+
function nonzeros(A)
89+
error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
90+
end
91+
function spzeros(args...)
92+
error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
93+
end
94+
function get_nzval(A)
95+
error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
96+
end
97+
function set_all_nzval!(A, val)
98+
error("SparseArrays extension not loaded. Please load SparseArrays to use sparse matrix functionality.")
99+
end
86100

87101
include("alg_utils.jl")
88102
include("linsolve_utils.jl")

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 176 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,89 @@
11
using SciMLOperators: StaticWOperator, WOperator
22

3+
"""
4+
get_jac_reuse(cache)
5+
6+
Duck-typed accessor for the `jac_reuse` field. Returns `nothing` if the cache
7+
does not have a `jac_reuse` field.
8+
"""
9+
get_jac_reuse(cache) = hasproperty(cache, :jac_reuse) ? cache.jac_reuse : nothing
10+
11+
"""
12+
_rosenbrock_jac_reuse_decision(integrator, cache, dtgamma) -> Union{Nothing, NTuple{2,Bool}}
13+
14+
Decide whether to recompute the Jacobian and/or W matrix for Rosenbrock methods.
15+
For W-methods (where `isWmethod(alg) == true`), implements CVODE-inspired reuse:
16+
- Always recompute on first iteration
17+
- Recompute J on error test failure (EEst > 1)
18+
- Recompute when gamma ratio changes too much: |dtgamma/last_dtgamma - 1| > 0.3
19+
- Recompute every `max_jac_age` accepted steps (default 50)
20+
- Recompute when u_modified (callback modification)
21+
- W is always rebuilt (since W = J - M/(dt*gamma) depends on current dt)
22+
23+
For strict Rosenbrock methods, returns `nothing` to delegate to `do_newJW`.
24+
"""
25+
function _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
26+
alg = OrdinaryDiffEqCore.unwrap_alg(integrator, true)
27+
28+
# Non-W-methods: delegate to do_newJW (preserves linear problem optimization etc.)
29+
if !isWmethod(alg)
30+
return nothing
31+
end
32+
33+
jac_reuse = get_jac_reuse(cache)
34+
# If no reuse state (e.g. OOP cache without jac_reuse), delegate to do_newJW
35+
if jac_reuse === nothing
36+
return nothing
37+
end
38+
39+
# Linear problems: delegate to do_newJW (which returns (false, false) for islin)
40+
islin, _ = islinearfunction(integrator)
41+
if islin
42+
return nothing
43+
end
44+
45+
# First iteration: always compute J and W.
46+
if integrator.iter <= 1
47+
return (true, true)
48+
end
49+
50+
# Commit pending_dtgamma from previous step if it was accepted.
51+
# This ensures rejected steps don't pollute last_dtgamma, keeping
52+
# IIP-adaptive and OOP-non-adaptive reuse decisions synchronized.
53+
naccept = integrator.stats.naccept
54+
if naccept > jac_reuse.last_naccept
55+
jac_reuse.last_dtgamma = jac_reuse.pending_dtgamma
56+
jac_reuse.last_naccept = naccept
57+
end
58+
59+
# Fresh cache (e.g., algorithm switch where iter > 1 but the Rosenbrock
60+
# cache is freshly created with cached_J = nothing).
61+
if iszero(jac_reuse.last_dtgamma)
62+
return (true, true)
63+
end
64+
65+
# Callback modification: recompute
66+
if integrator.u_modified
67+
return (true, true)
68+
end
69+
70+
# Gamma ratio check (uses only accepted-step dtgamma)
71+
last_dtg = jac_reuse.last_dtgamma
72+
if !iszero(last_dtg) && abs(dtgamma / last_dtg - 1) > 0.3
73+
return (true, true)
74+
end
75+
76+
# Age check: recompute J after max_jac_age accepted steps.
77+
# Uses naccept (not a local counter) so rejected steps don't desynchronize
78+
# IIP-adaptive and OOP-non-adaptive solves.
79+
if (naccept - jac_reuse.last_naccept) >= jac_reuse.max_jac_age
80+
return (true, true)
81+
end
82+
83+
# Reuse J, but always rebuild W (since dtgamma changes)
84+
return (false, true)
85+
end
86+
387
function calc_tderivative!(integrator, cache, dtd1, repeat_step)
488
return @inbounds begin
589
(; t, dt, uprev, u, f, p) = integrator
@@ -689,15 +773,101 @@ function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repe
689773
# we need to skip calculating `J` and `W` when a step is repeated
690774
new_jac = new_W = false
691775
if !repeat_step
692-
new_jac, new_W = calc_W!(
693-
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step
776+
# For W-methods, use reuse logic; for strict Rosenbrock, always recompute
777+
newJW = _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
778+
new_jac,
779+
new_W = calc_W!(
780+
cache.W, integrator, nlsolver, cache, dtgamma, repeat_step, newJW
694781
)
782+
# Record pending dtgamma only when J was freshly computed; it will be
783+
# committed as last_dtgamma when the step is accepted (checked in
784+
# _rosenbrock_jac_reuse_decision). This tracks the dtgamma at the
785+
# last J computation for the gamma ratio heuristic.
786+
jac_reuse = get_jac_reuse(cache)
787+
if jac_reuse !== nothing && new_jac
788+
jac_reuse.pending_dtgamma = dtgamma
789+
end
695790
end
696791
# If the Jacobian is not updated, we won't have to update ∂/∂t either.
697792
calc_tderivative!(integrator, cache, dtd1, repeat_step || !new_jac)
698793
return new_W
699794
end
700795

796+
"""
797+
calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step)
798+
799+
Non-mutating (OOP) version of `calc_rosenbrock_differentiation!`.
800+
Returns `(dT, W)` where `dT` is the time derivative and `W` is the factorized
801+
system matrix. Supports Jacobian reuse for W-methods via `jac_reuse` in the cache.
802+
"""
803+
function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step)
804+
jac_reuse = get_jac_reuse(cache)
805+
806+
# If no reuse support or repeat step, use standard path
807+
if repeat_step || jac_reuse === nothing
808+
dT = calc_tderivative(integrator, cache)
809+
W = calc_W(integrator, cache, dtgamma, repeat_step)
810+
return dT, W
811+
end
812+
813+
newJW = _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
814+
815+
if newJW === nothing
816+
# Delegate to standard path (linear problems, non-W-methods, etc.)
817+
dT = calc_tderivative(integrator, cache)
818+
W = calc_W(integrator, cache, dtgamma, repeat_step)
819+
return dT, W
820+
end
821+
822+
new_jac, _ = newJW
823+
824+
# For complex W types (operators), delegate to standard calc_W
825+
if cache.W isa StaticWOperator || cache.W isa WOperator ||
826+
cache.W isa AbstractSciMLOperator
827+
dT = calc_tderivative(integrator, cache)
828+
W = calc_W(integrator, cache, dtgamma, repeat_step)
829+
jac_reuse.pending_dtgamma = dtgamma
830+
return dT, W
831+
end
832+
833+
mass_matrix = integrator.f.mass_matrix
834+
update_coefficients!(mass_matrix, integrator.uprev, integrator.p, integrator.t)
835+
836+
# Safety: if cached_J is nothing (e.g. first use after algorithm switch),
837+
# force a fresh computation regardless of the decision.
838+
if !new_jac && jac_reuse.cached_J === nothing
839+
new_jac = true
840+
end
841+
842+
if new_jac
843+
J = calc_J(integrator, cache)
844+
dT = calc_tderivative(integrator, cache)
845+
846+
# Cache for future reuse
847+
jac_reuse.cached_J = J
848+
jac_reuse.cached_dT = dT
849+
else
850+
# Reuse cached J and dT
851+
J = jac_reuse.cached_J
852+
dT = jac_reuse.cached_dT
853+
end
854+
855+
# Record pending dtgamma only when J was freshly computed;
856+
# committed as last_dtgamma on the next accepted step.
857+
if new_jac
858+
jac_reuse.pending_dtgamma = dtgamma
859+
end
860+
861+
# Build W from J
862+
W = J - mass_matrix * inv(dtgamma)
863+
if !isa(W, Number)
864+
W = DiffEqBase.default_factorize(W)
865+
end
866+
integrator.stats.nw += 1
867+
868+
return dT, W
869+
end
870+
701871
# update W matrix (only used in Newton method)
702872
function update_W!(integrator, cache, dtgamma, repeat_step, newJW = nothing)
703873
return update_W!(cache.nlsolver, integrator, cache, dtgamma, repeat_step, newJW)
@@ -709,7 +879,8 @@ function update_W!(
709879
repeat_step::Bool, newJW = nothing
710880
)
711881
if isnewton(nlsolver)
712-
new_jac, new_W = calc_W!(
882+
new_jac,
883+
new_W = calc_W!(
713884
get_W(nlsolver), integrator, nlsolver, cache, dtgamma, repeat_step,
714885
newJW
715886
)
@@ -894,7 +1065,8 @@ function build_J_W(
8941065
elseif f.jac_prototype === nothing
8951066
if alg_autodiff(alg) isa AutoSparse
8961067
if isnothing(f.sparsity)
897-
!isnothing(jac_config) ? convert.(eltype(u), sparsity_pattern(jac_config[1])) :
1068+
!isnothing(jac_config) ?
1069+
convert.(eltype(u), sparsity_pattern(jac_config[1])) :
8981070
spzeros(eltype(u), length(u), length(u))
8991071
elseif eltype(f.sparsity) == Bool
9001072
convert.(eltype(u), f.sparsity)

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,9 @@ function sparsity_colorvec(f::F, x) where {F}
451451
col_alg = GreedyColoringAlgorithm()
452452
col_prob = ColoringProblem()
453453
colorvec = SciMLBase.has_colorvec(f) ? f.colorvec :
454-
(isnothing(sparsity) ? (1:length(x)) : column_colors(coloring(sparsity, col_prob, col_alg)))
454+
(
455+
isnothing(sparsity) ? (1:length(x)) :
456+
column_colors(coloring(sparsity, col_prob, col_alg))
457+
)
455458
return sparsity, colorvec
456459
end

0 commit comments

Comments
 (0)