11using SciMLOperators: StaticWOperator, WOperator
22
3+ """
4+ get_jac_reuse(cache)
5+
6+ Duck-typed accessor for the `jac_reuse` field. Returns `nothing` if the cache
7+ does not have a `jac_reuse` field.
8+ """
9+ get_jac_reuse (cache) = hasproperty (cache, :jac_reuse ) ? cache. jac_reuse : nothing
10+
11+ """
12+ _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma) -> Union{Nothing, NTuple{2,Bool}}
13+
14+ Decide whether to recompute the Jacobian and/or W matrix for Rosenbrock methods.
15+ For W-methods (where `isWmethod(alg) == true`), implements CVODE-inspired reuse:
16+ - Always recompute on first iteration
17+ - Recompute J on error test failure (EEst > 1)
18+ - Recompute when gamma ratio changes too much: |dtgamma/last_dtgamma - 1| > 0.3
19+ - Recompute every `max_jac_age` accepted steps (default 50)
20+ - Recompute when u_modified (callback modification)
21+ - W is always rebuilt (since W = J - M/(dt*gamma) depends on current dt)
22+
23+ For strict Rosenbrock methods, returns `nothing` to delegate to `do_newJW`.
24+ """
25+ function _rosenbrock_jac_reuse_decision (integrator, cache, dtgamma)
26+ alg = OrdinaryDiffEqCore. unwrap_alg (integrator, true )
27+
28+ # Non-W-methods: delegate to do_newJW (preserves linear problem optimization etc.)
29+ if ! isWmethod (alg)
30+ return nothing
31+ end
32+
33+ jac_reuse = get_jac_reuse (cache)
34+ # If no reuse state (e.g. OOP cache without jac_reuse), delegate to do_newJW
35+ if jac_reuse === nothing
36+ return nothing
37+ end
38+
39+ # Linear problems: delegate to do_newJW (which returns (false, false) for islin)
40+ islin, _ = islinearfunction (integrator)
41+ if islin
42+ return nothing
43+ end
44+
45+ # First iteration: always compute J and W.
46+ if integrator. iter <= 1
47+ return (true , true )
48+ end
49+
50+ # Commit pending_dtgamma from previous step if it was accepted.
51+ # This ensures rejected steps don't pollute last_dtgamma, keeping
52+ # IIP-adaptive and OOP-non-adaptive reuse decisions synchronized.
53+ naccept = integrator. stats. naccept
54+ if naccept > jac_reuse. last_naccept
55+ jac_reuse. last_dtgamma = jac_reuse. pending_dtgamma
56+ jac_reuse. last_naccept = naccept
57+ end
58+
59+ # Fresh cache (e.g., algorithm switch where iter > 1 but the Rosenbrock
60+ # cache is freshly created with cached_J = nothing).
61+ if iszero (jac_reuse. last_dtgamma)
62+ return (true , true )
63+ end
64+
65+ # Callback modification: recompute
66+ if integrator. u_modified
67+ return (true , true )
68+ end
69+
70+ # Gamma ratio check (uses only accepted-step dtgamma)
71+ last_dtg = jac_reuse. last_dtgamma
72+ if ! iszero (last_dtg) && abs (dtgamma / last_dtg - 1 ) > 0.3
73+ return (true , true )
74+ end
75+
76+ # Age check: recompute J after max_jac_age accepted steps.
77+ # Uses naccept (not a local counter) so rejected steps don't desynchronize
78+ # IIP-adaptive and OOP-non-adaptive solves.
79+ if (naccept - jac_reuse. last_naccept) >= jac_reuse. max_jac_age
80+ return (true , true )
81+ end
82+
83+ # Reuse J, but always rebuild W (since dtgamma changes)
84+ return (false , true )
85+ end
86+
387function calc_tderivative! (integrator, cache, dtd1, repeat_step)
488 return @inbounds begin
589 (; t, dt, uprev, u, f, p) = integrator
@@ -689,15 +773,101 @@ function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repe
689773 # we need to skip calculating `J` and `W` when a step is repeated
690774 new_jac = new_W = false
691775 if ! repeat_step
692- new_jac, new_W = calc_W! (
693- cache. W, integrator, nlsolver, cache, dtgamma, repeat_step
776+ # For W-methods, use reuse logic; for strict Rosenbrock, always recompute
777+ newJW = _rosenbrock_jac_reuse_decision (integrator, cache, dtgamma)
778+ new_jac,
779+ new_W = calc_W! (
780+ cache. W, integrator, nlsolver, cache, dtgamma, repeat_step, newJW
694781 )
782+ # Record pending dtgamma only when J was freshly computed; it will be
783+ # committed as last_dtgamma when the step is accepted (checked in
784+ # _rosenbrock_jac_reuse_decision). This tracks the dtgamma at the
785+ # last J computation for the gamma ratio heuristic.
786+ jac_reuse = get_jac_reuse (cache)
787+ if jac_reuse != = nothing && new_jac
788+ jac_reuse. pending_dtgamma = dtgamma
789+ end
695790 end
696791 # If the Jacobian is not updated, we won't have to update ∂/∂t either.
697792 calc_tderivative! (integrator, cache, dtd1, repeat_step || ! new_jac)
698793 return new_W
699794end
700795
796+ """
797+ calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step)
798+
799+ Non-mutating (OOP) version of `calc_rosenbrock_differentiation!`.
800+ Returns `(dT, W)` where `dT` is the time derivative and `W` is the factorized
801+ system matrix. Supports Jacobian reuse for W-methods via `jac_reuse` in the cache.
802+ """
803+ function calc_rosenbrock_differentiation (integrator, cache, dtgamma, repeat_step)
804+ jac_reuse = get_jac_reuse (cache)
805+
806+ # If no reuse support or repeat step, use standard path
807+ if repeat_step || jac_reuse === nothing
808+ dT = calc_tderivative (integrator, cache)
809+ W = calc_W (integrator, cache, dtgamma, repeat_step)
810+ return dT, W
811+ end
812+
813+ newJW = _rosenbrock_jac_reuse_decision (integrator, cache, dtgamma)
814+
815+ if newJW === nothing
816+ # Delegate to standard path (linear problems, non-W-methods, etc.)
817+ dT = calc_tderivative (integrator, cache)
818+ W = calc_W (integrator, cache, dtgamma, repeat_step)
819+ return dT, W
820+ end
821+
822+ new_jac, _ = newJW
823+
824+ # For complex W types (operators), delegate to standard calc_W
825+ if cache. W isa StaticWOperator || cache. W isa WOperator ||
826+ cache. W isa AbstractSciMLOperator
827+ dT = calc_tderivative (integrator, cache)
828+ W = calc_W (integrator, cache, dtgamma, repeat_step)
829+ jac_reuse. pending_dtgamma = dtgamma
830+ return dT, W
831+ end
832+
833+ mass_matrix = integrator. f. mass_matrix
834+ update_coefficients! (mass_matrix, integrator. uprev, integrator. p, integrator. t)
835+
836+ # Safety: if cached_J is nothing (e.g. first use after algorithm switch),
837+ # force a fresh computation regardless of the decision.
838+ if ! new_jac && jac_reuse. cached_J === nothing
839+ new_jac = true
840+ end
841+
842+ if new_jac
843+ J = calc_J (integrator, cache)
844+ dT = calc_tderivative (integrator, cache)
845+
846+ # Cache for future reuse
847+ jac_reuse. cached_J = J
848+ jac_reuse. cached_dT = dT
849+ else
850+ # Reuse cached J and dT
851+ J = jac_reuse. cached_J
852+ dT = jac_reuse. cached_dT
853+ end
854+
855+ # Record pending dtgamma only when J was freshly computed;
856+ # committed as last_dtgamma on the next accepted step.
857+ if new_jac
858+ jac_reuse. pending_dtgamma = dtgamma
859+ end
860+
861+ # Build W from J
862+ W = J - mass_matrix * inv (dtgamma)
863+ if ! isa (W, Number)
864+ W = DiffEqBase. default_factorize (W)
865+ end
866+ integrator. stats. nw += 1
867+
868+ return dT, W
869+ end
870+
701871# update W matrix (only used in Newton method)
702872function update_W! (integrator, cache, dtgamma, repeat_step, newJW = nothing )
703873 return update_W! (cache. nlsolver, integrator, cache, dtgamma, repeat_step, newJW)
@@ -709,7 +879,8 @@ function update_W!(
709879 repeat_step:: Bool , newJW = nothing
710880 )
711881 if isnewton (nlsolver)
712- new_jac, new_W = calc_W! (
882+ new_jac,
883+ new_W = calc_W! (
713884 get_W (nlsolver), integrator, nlsolver, cache, dtgamma, repeat_step,
714885 newJW
715886 )
@@ -894,7 +1065,8 @@ function build_J_W(
8941065 elseif f. jac_prototype === nothing
8951066 if alg_autodiff (alg) isa AutoSparse
8961067 if isnothing (f. sparsity)
897- ! isnothing (jac_config) ? convert .(eltype (u), sparsity_pattern (jac_config[1 ])) :
1068+ ! isnothing (jac_config) ?
1069+ convert .(eltype (u), sparsity_pattern (jac_config[1 ])) :
8981070 spzeros (eltype (u), length (u), length (u))
8991071 elseif eltype (f. sparsity) == Bool
9001072 convert .(eltype (u), f. sparsity)
0 commit comments