Skip to content

Commit ef15360

Browse files
Reuse W (LU factorization) for Rosenbrock-W methods
Instead of always rebuilding W when J is reused, try the old W (including its LU factorization) and only recompute when the step is rejected (EEst > 1). The LU is the expensive part, so this aggressively avoids refactorization. Key changes: - _rosenbrock_jac_reuse_decision returns (false, false) by default instead of (false, true), reusing both J and W - EEst > 1 check forces full recompute after step rejection - IIP perform_step functions pass A = new_W ? W : nothing to dolinsolve, preventing refactorization of LU-corrupted W - Fix macro-generated gen_perform_step to capture and use new_W return value from calc_rosenbrock_differentiation! - Add cached_W field to JacReuseState for future use - Loosen borderline step-count bounds in convergence tests Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 073bce0 commit ef15360

7 files changed

Lines changed: 60 additions & 96 deletions

File tree

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ get_jac_reuse(cache) = hasproperty(cache, :jac_reuse) ? cache.jac_reuse : nothin
1414
Decide whether to recompute the Jacobian and/or W matrix for Rosenbrock methods.
1515
For W-methods (where `isWmethod(alg) == true`), implements CVODE-inspired reuse:
1616
- Always recompute on first iteration
17-
- Recompute J on error test failure (EEst > 1)
17+
- Recompute after step rejection (EEst > 1), since the old W wasn't good enough
1818
- Recompute when gamma ratio changes too much: |dtgamma/last_dtgamma - 1| > 0.3
1919
- Recompute every `max_jac_age` accepted steps (default 50)
2020
- Recompute when u_modified (callback modification)
21-
- W is always rebuilt (since W = J - M/(dt*gamma) depends on current dt)
21+
- Otherwise reuse both J and W (including LU factorization). The LU is the
22+
expensive part, so we try the old W and only recompute if the step fails.
2223
2324
For strict Rosenbrock methods, returns `nothing` to delegate to `do_newJW`.
2425
"""
@@ -67,6 +68,12 @@ function _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
6768
return (true, true)
6869
end
6970

71+
# Previous step was rejected (EEst > 1): the old W wasn't good enough.
72+
# Recompute everything since we're retrying with a different dt anyway.
73+
if integrator.EEst > 1
74+
return (true, true)
75+
end
76+
7077
# Gamma ratio check (uses only accepted-step dtgamma)
7178
last_dtg = jac_reuse.last_dtgamma
7279
if !iszero(last_dtg) && abs(dtgamma / last_dtg - 1) > 0.3
@@ -80,8 +87,9 @@ function _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
8087
return (true, true)
8188
end
8289

83-
# Reuse J, but always rebuild W (since dtgamma changes)
84-
return (false, true)
90+
# Reuse both J and W. The LU factorization is expensive, so try with
91+
# the old W and only recompute if the step is rejected (caught above).
92+
return (false, false)
8593
end
8694

8795
function calc_tderivative!(integrator, cache, dtd1, repeat_step)
@@ -819,7 +827,7 @@ function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step
819827
return dT, W
820828
end
821829

822-
new_jac, _ = newJW
830+
new_jac, new_W = newJW
823831

824832
# For complex W types (operators), delegate to standard calc_W
825833
if cache.W isa StaticWOperator || cache.W isa WOperator ||
@@ -858,7 +866,7 @@ function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step
858866
jac_reuse.pending_dtgamma = dtgamma
859867
end
860868

861-
# Build W from J
869+
# Build W from (possibly cached) J
862870
W = J - mass_matrix * inv(dtgamma)
863871
if !isa(W, Number)
864872
W = DiffEqBase.default_factorize(W)

lib/OrdinaryDiffEqRosenbrock/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1616
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
1717
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
1818
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
19+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
1920
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2021
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2122
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ function gen_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachename::Symbo
407407

408408
if $(i==1)
409409
# Must be a part of the first linsolve for preconditioner step
410-
linres = dolinsolve(integrator, linsolve; A = !repeat_step ? W : nothing, b = _vec(linsolve_tmp))
410+
linres = dolinsolve(integrator, linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp))
411411
else
412412
linres = dolinsolve(integrator, linsolve; b = _vec(linsolve_tmp))
413413
end
@@ -480,7 +480,7 @@ function gen_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachename::Symbo
480480
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
481481
integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t)
482482

483-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
483+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
484484

485485
linsolve = cache.linsolve
486486

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Fields:
1515
- `max_jac_age`: Maximum number of accepted steps between Jacobian updates (default 50)
1616
- `cached_J`: Cached Jacobian for OOP reuse (type-erased for flexibility)
1717
- `cached_dT`: Cached time derivative for OOP reuse
18+
- `cached_W`: Cached factorized W for OOP reuse (LU factorization is expensive)
1819
"""
1920
mutable struct JacReuseState{T}
2021
last_dtgamma::T
@@ -23,10 +24,11 @@ mutable struct JacReuseState{T}
2324
max_jac_age::Int
2425
cached_J::Any
2526
cached_dT::Any
27+
cached_W::Any
2628
end
2729

2830
function JacReuseState(dtgamma::T) where {T}
29-
return JacReuseState{T}(dtgamma, dtgamma, 0, 50, nothing, nothing)
31+
return JacReuseState{T}(dtgamma, dtgamma, 0, 50, nothing, nothing, nothing)
3032
end
3133

3234
# Fake values since non-FSAL

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 36 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,19 @@ end
5454
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
5555
end
5656

57-
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
57+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
5858

5959
calculate_residuals!(
6060
weight, fill!(weight, one(eltype(u))), uprev, uprev,
6161
integrator.opts.abstol, integrator.opts.reltol,
6262
integrator.opts.internalnorm, t
6363
)
6464

65-
if repeat_step
66-
linres = dolinsolve(
67-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
68-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
69-
solverdata = (; gamma = dtγ)
70-
)
71-
else
72-
linres = dolinsolve(
73-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
74-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
75-
solverdata = (; gamma = dtγ)
76-
)
77-
end
65+
linres = dolinsolve(
66+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
67+
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
68+
solverdata = (; gamma = dtγ)
69+
)
7870

7971
vecu = _vec(linres.u)
8072
veck₁ = _vec(k₁)
@@ -178,27 +170,19 @@ end
178170
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
179171
end
180172

181-
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
173+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
182174

183175
calculate_residuals!(
184176
weight, fill!(weight, one(eltype(u))), uprev, uprev,
185177
integrator.opts.abstol, integrator.opts.reltol,
186178
integrator.opts.internalnorm, t
187179
)
188180

189-
if repeat_step
190-
linres = dolinsolve(
191-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
192-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
193-
solverdata = (; gamma = dtγ)
194-
)
195-
else
196-
linres = dolinsolve(
197-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
198-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
199-
solverdata = (; gamma = dtγ)
200-
)
201-
end
181+
linres = dolinsolve(
182+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
183+
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
184+
solverdata = (; gamma = dtγ)
185+
)
202186

203187
vecu = _vec(linres.u)
204188
veck₁ = _vec(k₁)
@@ -578,27 +562,19 @@ end
578562
dtd3 = dt * d3
579563
dtgamma = dt * gamma
580564

581-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
565+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
582566

583567
calculate_residuals!(
584568
weight, fill!(weight, one(eltype(u))), uprev, uprev,
585569
integrator.opts.abstol, integrator.opts.reltol,
586570
integrator.opts.internalnorm, t
587571
)
588572

589-
if repeat_step
590-
linres = dolinsolve(
591-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
592-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
593-
solverdata = (; gamma = dtgamma)
594-
)
595-
else
596-
linres = dolinsolve(
597-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
598-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
599-
solverdata = (; gamma = dtgamma)
600-
)
601-
end
573+
linres = dolinsolve(
574+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
575+
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
576+
solverdata = (; gamma = dtgamma)
577+
)
602578

603579
vecu = _vec(linres.u)
604580
veck1 = _vec(k1)
@@ -793,27 +769,19 @@ end
793769
dtd4 = dt * d4
794770
dtgamma = dt * gamma
795771

796-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
772+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
797773

798774
calculate_residuals!(
799775
weight, fill!(weight, one(eltype(u))), uprev, uprev,
800776
integrator.opts.abstol, integrator.opts.reltol,
801777
integrator.opts.internalnorm, t
802778
)
803779

804-
if repeat_step
805-
linres = dolinsolve(
806-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
807-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
808-
solverdata = (; gamma = dtgamma)
809-
)
810-
else
811-
linres = dolinsolve(
812-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
813-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
814-
solverdata = (; gamma = dtgamma)
815-
)
816-
end
780+
linres = dolinsolve(
781+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
782+
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
783+
solverdata = (; gamma = dtgamma)
784+
)
817785

818786
vecu = _vec(linres.u)
819787
veck1 = _vec(k1)
@@ -1127,27 +1095,19 @@ end
11271095
f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation!
11281096
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
11291097

1130-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
1098+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
11311099

11321100
calculate_residuals!(
11331101
weight, fill!(weight, one(eltype(u))), uprev, uprev,
11341102
integrator.opts.abstol, integrator.opts.reltol,
11351103
integrator.opts.internalnorm, t
11361104
)
11371105

1138-
if repeat_step
1139-
linres = dolinsolve(
1140-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1141-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1142-
solverdata = (; gamma = dtgamma)
1143-
)
1144-
else
1145-
linres = dolinsolve(
1146-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1147-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1148-
solverdata = (; gamma = dtgamma)
1149-
)
1150-
end
1106+
linres = dolinsolve(
1107+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
1108+
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1109+
solverdata = (; gamma = dtgamma)
1110+
)
11511111

11521112
@.. broadcast = false $(_vec(k1)) = -linres.u
11531113

@@ -1479,27 +1439,19 @@ end
14791439
f(cache.fsalfirst, uprev, p, t)
14801440
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
14811441

1482-
calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step)
1442+
new_W = calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step)
14831443

14841444
calculate_residuals!(
14851445
weight, fill!(weight, one(eltype(u))), uprev, uprev,
14861446
integrator.opts.abstol, integrator.opts.reltol,
14871447
integrator.opts.internalnorm, t
14881448
)
14891449

1490-
if repeat_step
1491-
linres = dolinsolve(
1492-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1493-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1494-
solverdata = (; gamma = dtgamma)
1495-
)
1496-
else
1497-
linres = dolinsolve(
1498-
integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1499-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1500-
solverdata = (; gamma = dtgamma)
1501-
)
1502-
end
1450+
linres = dolinsolve(
1451+
integrator, cache.linsolve; A = new_W ? W : nothing, b = _vec(linsolve_tmp),
1452+
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1453+
solverdata = (; gamma = dtgamma)
1454+
)
15031455

15041456
@.. $(_vec(ks[1])) = -linres.u
15051457
integrator.stats.nsolve += 1

lib/OrdinaryDiffEqRosenbrock/test/jacobian_reuse_test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ strict_rosenbrock = [
123123
@test jr.max_jac_age == 50
124124
@test jr.cached_J === nothing
125125
@test jr.cached_dT === nothing
126+
@test jr.cached_W === nothing
126127
end
127128

128129
# ========================================================================

lib/OrdinaryDiffEqRosenbrock/test/ode_rosenbrock_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ end
275275
@test sim.𝒪est[:final] 2 atol = testTol
276276

277277
sol = solve(prob, ROS2S())
278-
@test length(sol) < 20
278+
@test length(sol) < 25
279279
@test SciMLBase.successful_retcode(sol)
280280

281281
### ROS3
@@ -526,7 +526,7 @@ end
526526
@test sim.𝒪est[:final] 4 atol = testTol
527527

528528
sol = solve(prob, ROS34PW3())
529-
@test length(sol) < 20
529+
@test length(sol) < 25
530530
@test SciMLBase.successful_retcode(sol)
531531

532532
prob = prob_ode_2Dlinear
@@ -535,7 +535,7 @@ end
535535
@test sim.𝒪est[:final] 4 atol = testTol
536536

537537
sol = solve(prob, ROS34PW3())
538-
@test length(sol) < 20
538+
@test length(sol) < 25
539539
@test SciMLBase.successful_retcode(sol)
540540

541541
### ROS34PRw

0 commit comments

Comments
 (0)