Skip to content

Commit 186d502

Browse files
Add diff_tunables flag to SteadyStateAdjoint for full-parameter VJP
Add `diff_tunables::Val` field to SteadyStateAdjoint (default Val(true)). When Val(false), the parameter VJP via vecjacobian! is computed w.r.t. the full parameter object (including caches) instead of just tunables. This is needed for SCCNonlinearProblem where explicitfuns! write active data into non-tunable parameter components (caches). The automatic sensealg choice detects non-empty caches and sets diff_tunables=Val(false) with a structured-VJP-compatible backend. Changes: - sensitivity_algorithms.jl: Add diff_tunables field to SteadyStateAdjoint - adjoint_common.jl: Add use_full_p kwarg to adjointdiffcache - steadystate_adjoint.jl: Gate use_full_p on diff_tunables flag - concrete_solve.jl: Pass original_p to automatic_sensealg_choice for cache detection; return full gradient for EnzymeOriginator when diff_tunables=Val(false) - test/scc_enzyme.jl: Direct SCC differentiation test with Enzyme Companion PR: NonlinearSolve.jl#884 Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7c6770e commit 186d502

5 files changed

Lines changed: 169 additions & 61 deletions

File tree

src/adjoint_common.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,17 @@ return (AdjointDiffCache, y)
3333
function adjointdiffcache(
3434
g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f, alg;
3535
quad = false,
36-
noiseterm = false, needs_jac = false
36+
noiseterm = false, needs_jac = false, use_full_p = false
3737
) where {G, DG1, DG2}
3838
prob = sol.prob
3939
u0 = state_values(prob)
4040
p = parameter_values(prob)
41-
if p === nothing || p isa SciMLBase.NullParameters
41+
if use_full_p && p !== nothing && !(p isa SciMLBase.NullParameters)
42+
# Use full parameter object (including caches) for VJP computation.
43+
# Required for SCCNonlinearProblem where explicitfuns! write active
44+
# data into non-tunable parameter components.
45+
tunables, repack = p, identity
46+
elseif p === nothing || p isa SciMLBase.NullParameters
4247
tunables, repack = p, identity
4348
elseif isscimlstructure(p)
4449
tunables, repack, _ = canonicalize(Tunable(), p)

src/concrete_solve.jl

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,27 @@ function automatic_sensealg_choice(
324324
end
325325

326326
function automatic_sensealg_choice(
327-
prob::ConcreteNonlinearProblem, u0, p,
328-
verbose, repack
327+
prob::ConcreteNonlinearProblem, u0, tunables,
328+
verbose, repack, original_p = tunables
329329
)
330+
# Check if the original parameter has non-tunable active components
331+
# (e.g. caches from SCCNonlinearProblem explicitfuns!).
332+
_has_caches = isscimlstructure(original_p) && !(original_p isa AbstractArray) &&
333+
hasfield(typeof(original_p), :caches) && !isempty(original_p.caches)
334+
_diff_tunables = _has_caches ? Val(false) : Val(true)
335+
330336
default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray ||
331337
!DiffEqBase.isinplace(prob)
332-
# autodiff = false because forwarddiff fails on many GPU kernels
333-
# this only effects the Jacobian calculation and is same computation order
334-
SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP())
338+
SteadyStateAdjoint(
339+
autodiff = false, autojacvec = ZygoteVJP(),
340+
diff_tunables = _diff_tunables,
341+
)
335342
else
336-
vjp = inplace_vjp(prob, u0, p, verbose, repack)
337-
SteadyStateAdjoint(autojacvec = vjp)
343+
vjp = inplace_vjp(prob, u0, tunables, verbose, repack)
344+
if _diff_tunables isa Val{false} && !supports_structured_vjp(vjp)
345+
vjp = ZygoteVJP()
346+
end
347+
SteadyStateAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
338348
end
339349
return default_sensealg
340350
end
@@ -371,7 +381,7 @@ function SciMLBase._concrete_solve_adjoint(
371381
throw(SciMLStructuresCompatibilityError())
372382
end
373383

374-
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
384+
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p)
375385
if has_cb && default_sensealg isa AbstractAdjointSensitivityAlgorithm &&
376386
!(typeof(default_sensealg.autojacvec) <: Union{EnzymeVJP, ReverseDiffVJP, ReactantVJP})
377387
default_sensealg = setvjp(default_sensealg, ReverseDiffVJP())
@@ -404,7 +414,7 @@ function SciMLBase._concrete_solve_adjoint(
404414
end
405415

406416
u0 = state_values(prob) === nothing ? Float64[] : u0
407-
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
417+
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p)
408418
return SciMLBase._concrete_solve_adjoint(
409419
prob, alg, default_sensealg, u0, p,
410420
originator::SciMLBase.ADOriginator, args...; verbose,
@@ -2376,56 +2386,49 @@ function SciMLBase._concrete_solve_adjoint(
23762386
end
23772387
end
23782388

2379-
dp = adjoint_sensitivities(sol, alg; sensealg, dgdu = df)
2389+
dp_full = adjoint_sensitivities(sol, alg; sensealg, dgdu = df)
23802390

2381-
dp,
2382-
Δtunables = if Δ isa AbstractArray || Δ isa Number
2383-
# if Δ isa AbstractArray, the gradients correspond to `u`
2384-
# this is something that needs changing in the future, but
2385-
# this is the applicable till the movement to structuaral
2386-
# tangents is completed
2387-
dp, Δtunables = if isscimlstructure(dp)
2388-
dp, _, _ = canonicalize(Tunable(), dp)
2389-
dp, nothing
2390-
elseif isfunctor(dp)
2391-
dp, _ = Functors.functor(dp)
2392-
dp, nothing
2391+
# When diff_tunables=Val(false), dp_full is the full parameter
2392+
# gradient (SciMLStructure). For Enzyme, return it directly so
2393+
# the reverse rule can accumulate into all shadow components
2394+
# (including caches for SCCNonlinearProblem).
2395+
dp_tangent = if originator isa SciMLBase.EnzymeOriginator &&
2396+
sensealg.diff_tunables isa Val{false} &&
2397+
isscimlstructure(dp_full)
2398+
dp_full
2399+
else
2400+
dp = if isscimlstructure(dp_full)
2401+
canonicalize(Tunable(), dp_full)[1]
2402+
elseif isfunctor(dp_full)
2403+
Functors.functor(dp_full)[1]
23932404
else
2394-
dp, nothing
2405+
dp_full
23952406
end
2396-
else
2397-
dp, Δtunables = if isscimlstructure(p)
2398-
if.prob.p == ZeroTangent() || Δ.prob.p == NoTangent())
2399-
dp, _, _ = canonicalize(Tunable(), dp)
2400-
dp, nothing
2407+
2408+
Δtunables = if !isa AbstractArray || Δ isa Number)
2409+
if isscimlstructure(p) &&
2410+
!.prob.p == ZeroTangent() || Δ.prob.p == NoTangent())
2411+
Δp = setproperties(dp_full, to_nt.prob.p))
2412+
canonicalize(Tunable(), Δp)[1]
2413+
elseif isfunctor(p)
2414+
Functors.functor.prob.p)[1]
24012415
else
2402-
Δp = setproperties(dp, to_nt.prob.p))
2403-
Δtunables, _, _ = canonicalize(Tunable(), Δp)
2404-
dp, _, _ = canonicalize(Tunable(), dp)
2405-
dp, Δtunables
2416+
nothing
24062417
end
2407-
elseif isfunctor(p)
2408-
dp, _ = Functors.functor(dp)
2409-
Δtunables, _ = Functors.functor.prob.p)
2410-
dp, Δtunables
24112418
else
2412-
dp, Δ.prob.p
2419+
nothing
24132420
end
2414-
end
24152421

2416-
dp = Zygote.accum(
2417-
dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing :
2418-
Δtunables
2419-
)
2422+
dp = Zygote.accum(
2423+
dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing :
2424+
Δtunables
2425+
)
24202426

2421-
# For Enzyme with SciMLStructure parameters, return the tunable gradient
2422-
# vector directly instead of the Zygote-repacked NamedTuple. The Enzyme
2423-
# reverse rule in NonlinearSolveBaseEnzymeExt uses
2424-
# SciMLStructures.replace! to accumulate it into the parameter shadow.
2425-
dp_tangent = if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p)
2426-
dp
2427-
else
2428-
repack_adjoint(dp)[1]
2427+
if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p)
2428+
dp
2429+
else
2430+
repack_adjoint(dp)[1]
2431+
end
24292432
end
24302433

24312434
return if originator isa SciMLBase.TrackerOriginator ||

src/sensitivity_algorithms.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,30 +1293,42 @@ documentation page or the docstrings of the vjp types.
12931293
Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at
12941294
http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007)
12951295
"""
1296-
struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK} <:
1296+
struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK, DT <: Val} <:
12971297
AbstractAdjointSensitivityAlgorithm{CS, AD, FDT}
12981298
autojacvec::VJP
12991299
linsolve::LS
13001300
linsolve_kwargs::LK
1301+
diff_tunables::DT
13011302
end
13021303

1304+
"""
1305+
SteadyStateAdjoint(; autojacvec=nothing, linsolve=nothing, diff_tunables=Val(true), ...)
1306+
1307+
When `diff_tunables = Val(true)` (default), the parameter VJP is computed
1308+
w.r.t. the tunable portion of `p` only. When `diff_tunables = Val(false)`,
1309+
the VJP is computed w.r.t. the full parameter object (including caches,
1310+
initials, etc.). This is needed for SCCNonlinearProblem where `explicitfuns!`
1311+
write active data into non-tunable components. Requires an `autojacvec`
1312+
backend that supports structured parameters (ZygoteVJP, EnzymeVJP,
1313+
MooncakeVJP, ReactantVJP).
1314+
"""
13031315
Base.@pure function SteadyStateAdjoint(;
13041316
chunk_size = 0, autodiff = true,
13051317
diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing,
1306-
linsolve_kwargs = (;)
1318+
linsolve_kwargs = (;), diff_tunables = Val(true)
13071319
)
13081320
return SteadyStateAdjoint{
13091321
chunk_size, autodiff, diff_type, typeof(autojacvec),
1310-
typeof(linsolve), typeof(linsolve_kwargs),
1311-
}(autojacvec, linsolve, linsolve_kwargs)
1322+
typeof(linsolve), typeof(linsolve_kwargs), typeof(diff_tunables),
1323+
}(autojacvec, linsolve, linsolve_kwargs, diff_tunables)
13121324
end
13131325
function setvjp(
1314-
sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK},
1326+
sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK, DT},
13151327
vjp
1316-
) where {CS, AD, FDT, VJP, LS, LK}
1317-
return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK}(
1328+
) where {CS, AD, FDT, VJP, LS, LK, DT}
1329+
return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, DT}(
13181330
vjp, sensealg.linsolve,
1319-
sensealg.linsolve_kwargs
1331+
sensealg.linsolve_kwargs, sensealg.diff_tunables
13201332
)
13211333
end
13221334

src/steadystate_adjoint.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ function SteadyStateAdjointSensitivityFunction(
2020
)
2121
(; p, u0) = sol.prob
2222

23+
# When diff_tunables = Val(false), use the full parameter object so
24+
# the VJP includes gradients w.r.t. all parameter components.
25+
_use_full_p = sensealg.diff_tunables isa Val{false} &&
26+
isscimlstructure(p) && !(p isa AbstractArray)
2327
diffcache,
2428
y = adjointdiffcache(
2529
g, sensealg, false, sol, dgdu, dgdp, f, alg;
26-
quad = false, needs_jac
30+
quad = false, needs_jac, use_full_p = _use_full_p
2731
)
2832

2933
λ = zero(y)
@@ -161,6 +165,7 @@ end
161165
end
162166
end
163167

168+
164169
if g !== nothing || dgdp !== nothing
165170
# compute del g/del p
166171
if dgdp !== nothing

test/scc_enzyme.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
using Test
2+
using NonlinearSolve, SCCNonlinearSolve
3+
using SciMLSensitivity
4+
using Enzyme
5+
using FiniteDiff
6+
import SciMLStructures as SS
7+
8+
# Two-component SCC problem with parameter coupling through caches.
9+
# Component 1: u1^2 - p[1] = 0 (root: u1 = sqrt(p[1]))
10+
# Component 2: u2 - p[2]*u1 = 0 (root: u2 = p[2]*u1)
11+
# where u1 from component 1 is passed via explicitfun into component 2's
12+
# parameter cache.
13+
14+
@testset "SCCNonlinearProblem Enzyme differentiation" begin
15+
# Sub-problem 1: u^2 - p = 0
16+
function f1(du, u, p)
17+
du[1] = u[1]^2 - p[1]
18+
end
19+
explicitfun1!(p, sols) = nothing
20+
21+
# Sub-problem 2: u - cache[1] * p[2] = 0
22+
# cache[1] will be set to sol1[1] (= sqrt(p[1])) by explicitfun2
23+
function f2(du, u, p)
24+
du[1] = u[1] - p[1] * p[2] # p[1] is cache, p[2] is tunable
25+
end
26+
function explicitfun2!(p, sols)
27+
p[1] = sols[1].u[1] # transfer u1 from component 1 into cache
28+
return nothing
29+
end
30+
31+
p_shared = [0.0, 2.0] # p[1] = cache (written by explicitfun2), p[2] = tunable
32+
prob1 = NonlinearProblem(
33+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1), [1.0], p_shared,
34+
)
35+
prob2 = NonlinearProblem(
36+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2), [1.0], p_shared,
37+
)
38+
39+
sccprob = SciMLBase.SCCNonlinearProblem(
40+
[prob1, prob2],
41+
SciMLBase.Void{Any}.([explicitfun1!, explicitfun2!]),
42+
)
43+
44+
alg = SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson())
45+
46+
# Forward solve works
47+
p_test = [4.0, 3.0]
48+
p_shared .= p_test
49+
sol = solve(sccprob, alg)
50+
@test SciMLBase.successful_retcode(sol)
51+
@test sol.u[1] 2.0 atol = 1.0e-10 # sqrt(4)
52+
@test sol.u[2] 6.0 atol = 1.0e-10 # 3 * 2
53+
54+
# FiniteDiff ground truth
55+
function loss(p_val)
56+
p_shared .= p_val
57+
sol = solve(sccprob, alg)
58+
sum(sol.u)
59+
end
60+
fd = FiniteDiff.finite_difference_gradient(loss, p_test)
61+
@test any(!iszero, fd)
62+
63+
# Enzyme gradient
64+
@testset "Enzyme through SCC" begin
65+
loss_enzyme = let sccprob = sccprob, alg = alg, p_shared = p_shared
66+
p_val -> begin
67+
p_shared .= p_val
68+
sol = solve(sccprob, alg)
69+
sum(sol.u)
70+
end
71+
end
72+
73+
dloss = Enzyme.make_zero(loss_enzyme)
74+
dp = zeros(length(p_test))
75+
Enzyme.autodiff(
76+
Enzyme.set_runtime_activity(Enzyme.Reverse),
77+
Enzyme.Duplicated(loss_enzyme, dloss),
78+
Enzyme.Active,
79+
Enzyme.Duplicated(copy(p_test), dp),
80+
)
81+
@test isapprox(dp, fd, rtol = 0.05)
82+
end
83+
end

0 commit comments

Comments
 (0)