Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ FastBroadcast = "0.3.5, 1"
FiniteDiff = "2"
ForwardDiff = "0.10, 1"
FunctionProperties = "0.1"
FunctionWrappersWrappers = "0.1, 1.0"
FunctionWrappersWrappers = "0.1, 1.4"
Functors = "0.4, 0.5"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
Expand All @@ -90,6 +90,7 @@ Mooncake = "0.5"
Reactant = "0.2.22"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
SCCNonlinearSolve = "1"
Optimization = "4, 5"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.108"
Expand Down Expand Up @@ -136,6 +137,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Expand All @@ -148,4 +150,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "Reactant", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
9 changes: 7 additions & 2 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ return (AdjointDiffCache, y)
function adjointdiffcache(
g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f, alg;
quad = false,
noiseterm = false, needs_jac = false
noiseterm = false, needs_jac = false, use_full_p = false
) where {G, DG1, DG2}
prob = sol.prob
u0 = state_values(prob)
p = parameter_values(prob)
if p === nothing || p isa SciMLBase.NullParameters
if use_full_p && p !== nothing && !(p isa SciMLBase.NullParameters)
# Use full parameter object (including caches) for VJP computation.
# Required for SCCNonlinearProblem where explicitfuns! write active
# data into non-tunable parameter components.
tunables, repack = p, identity
elseif p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
Expand Down
189 changes: 118 additions & 71 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
# Use a narrow union instead of AbstractNonlinearProblem so that composite problem
# types like SCCNonlinearProblem (which should differentiate through their individual
# NonlinearProblem sub-solves) don't accidentally match these dispatches.
#
# NOTE: SCCNonlinearProblem is intentionally excluded. Its ChainRules rrule
# (in SCCNonlinearSolveChainRulesCoreExt) calls _concrete_solve_adjoint, but no
# dispatch exists here for it. This means reverse-mode AD through SCCNonlinearProblem
# currently falls through to the SciMLBase stub error. See issue #1358.
# A dedicated _concrete_solve_adjoint for SCCNonlinearProblem would need to handle
# the SCC block structure (sub-problem solves + explicitfuns!) rather than treating
# it as a monolithic nonlinear system.
const ConcreteNonlinearProblem = Union{
NonlinearProblem, SciMLBase.ImmutableNonlinearProblem, SciMLBase.SteadyStateProblem,
}
Expand Down Expand Up @@ -153,7 +161,7 @@ function automatic_sensealg_choice(
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractSDEProblem,
},
u0, p, verbose, repack
u0, p, verbose, repack, original_p = p
)
# Get verbosity for sensitivity VJP choice warnings
_verbose = _get_sensitivity_vjp_verbose(verbose)
Expand All @@ -174,6 +182,12 @@ function automatic_sensealg_choice(
return GaussAdjoint(autojacvec = ZygoteVJP())
end

# Check if the original parameter has non-tunable active components
# (e.g. caches from SCCNonlinearProblem explicitfuns!).
_has_caches = isscimlstructure(original_p) && !(original_p isa AbstractArray) &&
hasfield(typeof(original_p), :caches) && !isempty(original_p.caches)
_diff_tunables = _has_caches ? Val(false) : Val(true)

default_sensealg = if p !== SciMLBase.NullParameters() &&
!(eltype(u0) <: ForwardDiff.Dual) &&
!(eltype(p) <: ForwardDiff.Dual) &&
Expand Down Expand Up @@ -271,42 +285,45 @@ function automatic_sensealg_choice(
if p === nothing || p === SciMLBase.NullParameters()
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
QuadratureAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables)
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autodiff = false, autojacvec = vjp)
GaussAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
end
else
if p === nothing || p === SciMLBase.NullParameters()
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autojacvec = vjp)
QuadratureAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autojacvec = vjp)
GaussAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
else
InterpolatingAdjoint(autojacvec = vjp)
end
end
else
vjp = inplace_vjp(prob, u0, p, verbose, repack)
if _diff_tunables isa Val{false} && !supports_structured_vjp(vjp)
vjp = ZygoteVJP()
end
if vjp isa Bool
if _verbose
@warn "Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs"
end
# If reverse-mode isn't working, just fallback to numerical vjps
if p === nothing || p === SciMLBase.NullParameters()
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
QuadratureAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables)
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autodiff = false, autojacvec = vjp)
GaussAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
end
else
if p === nothing || p === SciMLBase.NullParameters()
QuadratureAdjoint(autojacvec = vjp)
QuadratureAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autojacvec = vjp)
GaussAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
else
InterpolatingAdjoint(autojacvec = vjp)
end
Expand All @@ -316,17 +333,27 @@ function automatic_sensealg_choice(
end

function automatic_sensealg_choice(
prob::ConcreteNonlinearProblem, u0, p,
verbose, repack
prob::ConcreteNonlinearProblem, u0, tunables,
verbose, repack, original_p = tunables
)
# Check if the original parameter has non-tunable active components
# (e.g. caches from SCCNonlinearProblem explicitfuns!).
_has_caches = isscimlstructure(original_p) && !(original_p isa AbstractArray) &&
hasfield(typeof(original_p), :caches) && !isempty(original_p.caches)
_diff_tunables = _has_caches ? Val(false) : Val(true)

default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray ||
!DiffEqBase.isinplace(prob)
# autodiff = false because forwarddiff fails on many GPU kernels
# this only effects the Jacobian calculation and is same computation order
SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP())
SteadyStateAdjoint(
autodiff = false, autojacvec = ZygoteVJP(),
diff_tunables = _diff_tunables,
)
else
vjp = inplace_vjp(prob, u0, p, verbose, repack)
SteadyStateAdjoint(autojacvec = vjp)
vjp = inplace_vjp(prob, u0, tunables, verbose, repack)
if _diff_tunables isa Val{false} && !supports_structured_vjp(vjp)
vjp = ZygoteVJP()
end
SteadyStateAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables)
end
return default_sensealg
end
Expand Down Expand Up @@ -363,7 +390,7 @@ function SciMLBase._concrete_solve_adjoint(
throw(SciMLStructuresCompatibilityError())
end

default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p)
if has_cb && default_sensealg isa AbstractAdjointSensitivityAlgorithm &&
!(typeof(default_sensealg.autojacvec) <: Union{EnzymeVJP, ReverseDiffVJP, ReactantVJP})
default_sensealg = setvjp(default_sensealg, ReverseDiffVJP())
Expand Down Expand Up @@ -396,7 +423,7 @@ function SciMLBase._concrete_solve_adjoint(
end

u0 = state_values(prob) === nothing ? Float64[] : u0
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p)
return SciMLBase._concrete_solve_adjoint(
prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
Expand Down Expand Up @@ -900,43 +927,60 @@ function SciMLBase._concrete_solve_adjoint(

du0 = reshape(du0, size(u0))

dp = if p === nothing || p === SciMLBase.NullParameters()
dp_full = if p === nothing || p === SciMLBase.NullParameters()
nothing
elseif dp isa AbstractArray
reshape(dp', size(tunables))
else
dp
end

dp = Zygote.accum(dp, igs)

_,
repack_adjoint = if p === nothing || p === SciMLBase.NullParameters()
nothing, x -> (x,)
elseif isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
# When diff_tunables=Val(false), dp_full is the full parameter
# gradient (SciMLStructure). For Enzyme, return it directly so
# the reverse rule can accumulate into all shadow components.
_use_full_p = hasproperty(sensealg, :diff_tunables) &&
sensealg.diff_tunables isa Val{false}
dp_tangent = if _use_full_p &&
originator isa SciMLBase.EnzymeOriginator &&
isscimlstructure(dp_full)
Zygote.accum(dp_full, igs)
else
dp = Zygote.accum(dp_full, igs)

_,
repack_adjoint = if p === nothing || p === SciMLBase.NullParameters()
nothing, x -> (x,)
elseif isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
elseif isfunctor(p) && supports_functor_params(sensealg)
Zygote.pullback(p) do p
t, _ = Functors.functor(p)
t
end
else
nothing, x -> (x,)
end
elseif isfunctor(p) && supports_functor_params(sensealg)
Zygote.pullback(p) do p
t, _ = Functors.functor(p)
t

if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p)
dp
else
repack_adjoint(dp)[1]
end
else
nothing, x -> (x,)
end

return if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(
NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(),
NoTangent(), NoTangent(), du0, dp_tangent, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
else
(
NoTangent(), NoTangent(), NoTangent(),
du0, repack_adjoint(dp)[1], NoTangent(),
du0, dp_tangent, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
end
Expand Down Expand Up @@ -2368,58 +2412,61 @@ function SciMLBase._concrete_solve_adjoint(
end
end

dp = adjoint_sensitivities(sol, alg; sensealg, dgdu = df)
dp_full = adjoint_sensitivities(sol, alg; sensealg, dgdu = df)

dp,
Δtunables = if Δ isa AbstractArray || Δ isa Number
# if Δ isa AbstractArray, the gradients correspond to `u`
# this is something that needs changing in the future, but
# this is the applicable till the movement to structuaral
# tangents is completed
dp, Δtunables = if isscimlstructure(dp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
elseif isfunctor(dp)
dp, _ = Functors.functor(dp)
dp, nothing
# When diff_tunables=Val(false), dp_full is the full parameter
# gradient (SciMLStructure). For Enzyme, return it directly so
# the reverse rule can accumulate into all shadow components
# (including caches for SCCNonlinearProblem).
dp_tangent = if originator isa SciMLBase.EnzymeOriginator &&
sensealg.diff_tunables isa Val{false} &&
isscimlstructure(dp_full)
dp_full
else
dp = if isscimlstructure(dp_full)
canonicalize(Tunable(), dp_full)[1]
elseif isfunctor(dp_full)
Functors.functor(dp_full)[1]
else
dp, nothing
dp_full
end
else
dp, Δtunables = if isscimlstructure(p)
if (Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent())
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing

Δtunables = if !(Δ isa AbstractArray || Δ isa Number)
if isscimlstructure(p) &&
!(Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent())
Δp = setproperties(dp_full, to_nt(Δ.prob.p))
canonicalize(Tunable(), Δp)[1]
elseif isfunctor(p)
Functors.functor(Δ.prob.p)[1]
else
Δp = setproperties(dp, to_nt(Δ.prob.p))
Δtunables, _, _ = canonicalize(Tunable(), Δp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, Δtunables
nothing
end
elseif isfunctor(p)
dp, _ = Functors.functor(dp)
Δtunables, _ = Functors.functor(Δ.prob.p)
dp, Δtunables
else
dp, Δ.prob.p
nothing
end
end

dp = Zygote.accum(
dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing :
Δtunables
)
dp = Zygote.accum(
dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing :
Δtunables
)

if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p)
dp
else
repack_adjoint(dp)[1]
end
end

return if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(
NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
NoTangent(), NoTangent(), NoTangent(), dp_tangent, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
else
(
NoTangent(), NoTangent(), NoTangent(),
NoTangent(), repack_adjoint(dp)[1], NoTangent(),
NoTangent(), dp_tangent, NoTangent(),
ntuple(_ -> NoTangent(), length(args))...,
)
end
Expand Down
Loading
Loading