diff --git a/Project.toml b/Project.toml index 49b8e8ced..67fdc4f82 100644 --- a/Project.toml +++ b/Project.toml @@ -77,7 +77,7 @@ FastBroadcast = "0.3.5, 1" FiniteDiff = "2" ForwardDiff = "0.10, 1" FunctionProperties = "0.1" -FunctionWrappersWrappers = "0.1, 1.0" +FunctionWrappersWrappers = "0.1, 1.4" Functors = "0.4, 0.5" GPUArraysCore = "0.1, 0.2" LinearAlgebra = "1.10" @@ -90,6 +90,7 @@ Mooncake = "0.5" Reactant = "0.2.22" NLsolve = "4.5.1" NonlinearSolve = "3.0.1, 4" +SCCNonlinearSolve = "1" Optimization = "4, 5" OptimizationOptimisers = "0.3" OrdinaryDiffEq = "6.108" @@ -136,6 +137,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" @@ -148,4 +150,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "Reactant", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] +test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "DifferentiationInterface", "Distributed", "ExplicitImports", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "Reactant", "SCCNonlinearSolve", "SafeTestsets", "SparseArrays", "StableRNGs", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 5c88c2572..4877e7002 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -33,12 +33,17 @@ return (AdjointDiffCache, y) function adjointdiffcache( g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f, alg; quad = false, - noiseterm = false, needs_jac = false + noiseterm = false, needs_jac = false, use_full_p = false ) where {G, DG1, DG2} prob = sol.prob u0 = state_values(prob) p = parameter_values(prob) - if p === nothing || p isa SciMLBase.NullParameters + if use_full_p && p !== nothing && !(p isa SciMLBase.NullParameters) + # Use full parameter object (including caches) for VJP computation. + # Required for SCCNonlinearProblem where explicitfuns! write active + # data into non-tunable parameter components. + tunables, repack = p, identity + elseif p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 760da3ad4..f731f5286 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -6,6 +6,14 @@ # Use a narrow union instead of AbstractNonlinearProblem so that composite problem # types like SCCNonlinearProblem (which should differentiate through their individual # NonlinearProblem sub-solves) don't accidentally match these dispatches. +# +# NOTE: SCCNonlinearProblem is intentionally excluded. Its ChainRules rrule +# (in SCCNonlinearSolveChainRulesCoreExt) calls _concrete_solve_adjoint, but no +# dispatch exists here for it. This means reverse-mode AD through SCCNonlinearProblem +# currently falls through to the SciMLBase stub error. See issue #1358. +# A dedicated _concrete_solve_adjoint for SCCNonlinearProblem would need to handle +# the SCC block structure (sub-problem solves + explicitfuns!) rather than treating +# it as a monolithic nonlinear system. const ConcreteNonlinearProblem = Union{ NonlinearProblem, SciMLBase.ImmutableNonlinearProblem, SciMLBase.SteadyStateProblem, } @@ -153,7 +161,7 @@ function automatic_sensealg_choice( SciMLBase.AbstractODEProblem, SciMLBase.AbstractSDEProblem, }, - u0, p, verbose, repack + u0, p, verbose, repack, original_p = p ) # Get verbosity for sensitivity VJP choice warnings _verbose = _get_sensitivity_vjp_verbose(verbose) @@ -174,6 +182,12 @@ function automatic_sensealg_choice( return GaussAdjoint(autojacvec = ZygoteVJP()) end + # Check if the original parameter has non-tunable active components + # (e.g. caches from SCCNonlinearProblem explicitfuns!). + _has_caches = isscimlstructure(original_p) && !(original_p isa AbstractArray) && + hasfield(typeof(original_p), :caches) && !isempty(original_p.caches) + _diff_tunables = _has_caches ? Val(false) : Val(true) + default_sensealg = if p !== SciMLBase.NullParameters() && !(eltype(u0) <: ForwardDiff.Dual) && !(eltype(p) <: ForwardDiff.Dual) && @@ -271,9 +285,9 @@ function automatic_sensealg_choice( if p === nothing || p === SciMLBase.NullParameters() # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters - QuadratureAdjoint(autodiff = false, autojacvec = vjp) + QuadratureAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables) elseif prob isa ODEProblem && !(vjp isa TrackerVJP) - GaussAdjoint(autodiff = false, autojacvec = vjp) + GaussAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) end @@ -281,32 +295,35 @@ function automatic_sensealg_choice( if p === nothing || p === SciMLBase.NullParameters() # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters - QuadratureAdjoint(autojacvec = vjp) + QuadratureAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables) elseif prob isa ODEProblem && !(vjp isa TrackerVJP) - GaussAdjoint(autojacvec = vjp) + GaussAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables) else InterpolatingAdjoint(autojacvec = vjp) end end else vjp = inplace_vjp(prob, u0, p, verbose, repack) + if _diff_tunables isa Val{false} && !supports_structured_vjp(vjp) + vjp = ZygoteVJP() + end if vjp isa Bool if _verbose @warn "Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs" end # If reverse-mode isn't working, just fallback to numerical vjps if p === nothing || p === SciMLBase.NullParameters() - QuadratureAdjoint(autodiff = false, autojacvec = vjp) + QuadratureAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables) elseif prob isa ODEProblem && !(vjp isa TrackerVJP) - GaussAdjoint(autodiff = false, autojacvec = vjp) + GaussAdjoint(autodiff = false, autojacvec = vjp, diff_tunables = _diff_tunables) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) end else if p === nothing || p === SciMLBase.NullParameters() - QuadratureAdjoint(autojacvec = vjp) + QuadratureAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables) elseif prob isa ODEProblem && !(vjp isa TrackerVJP) - GaussAdjoint(autojacvec = vjp) + GaussAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables) else InterpolatingAdjoint(autojacvec = vjp) end @@ -316,17 +333,27 @@ function automatic_sensealg_choice( end function automatic_sensealg_choice( - prob::ConcreteNonlinearProblem, u0, p, - verbose, repack + prob::ConcreteNonlinearProblem, u0, tunables, + verbose, repack, original_p = tunables ) + # Check if the original parameter has non-tunable active components + # (e.g. caches from SCCNonlinearProblem explicitfuns!). + _has_caches = isscimlstructure(original_p) && !(original_p isa AbstractArray) && + hasfield(typeof(original_p), :caches) && !isempty(original_p.caches) + _diff_tunables = _has_caches ? Val(false) : Val(true) + default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob) - # autodiff = false because forwarddiff fails on many GPU kernels - # this only effects the Jacobian calculation and is same computation order - SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP()) + SteadyStateAdjoint( + autodiff = false, autojacvec = ZygoteVJP(), + diff_tunables = _diff_tunables, + ) else - vjp = inplace_vjp(prob, u0, p, verbose, repack) - SteadyStateAdjoint(autojacvec = vjp) + vjp = inplace_vjp(prob, u0, tunables, verbose, repack) + if _diff_tunables isa Val{false} && !supports_structured_vjp(vjp) + vjp = ZygoteVJP() + end + SteadyStateAdjoint(autojacvec = vjp, diff_tunables = _diff_tunables) end return default_sensealg end @@ -363,7 +390,7 @@ function SciMLBase._concrete_solve_adjoint( throw(SciMLStructuresCompatibilityError()) end - default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack) + default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p) if has_cb && default_sensealg isa AbstractAdjointSensitivityAlgorithm && !(typeof(default_sensealg.autojacvec) <: Union{EnzymeVJP, ReverseDiffVJP, ReactantVJP}) default_sensealg = setvjp(default_sensealg, ReverseDiffVJP()) @@ -396,7 +423,7 @@ function SciMLBase._concrete_solve_adjoint( end u0 = state_values(prob) === nothing ? Float64[] : u0 - default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack) + default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack, p) return SciMLBase._concrete_solve_adjoint( prob, alg, default_sensealg, u0, p, originator::SciMLBase.ADOriginator, args...; verbose, @@ -900,7 +927,7 @@ function SciMLBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) - dp = if p === nothing || p === SciMLBase.NullParameters() + dp_full = if p === nothing || p === SciMLBase.NullParameters() nothing elseif dp isa AbstractArray reshape(dp', size(tunables)) @@ -908,35 +935,52 @@ function SciMLBase._concrete_solve_adjoint( dp end - dp = Zygote.accum(dp, igs) - - _, - repack_adjoint = if p === nothing || p === SciMLBase.NullParameters() - nothing, x -> (x,) - elseif isscimlstructure(p) - Zygote.pullback(p) do p - t, _, _ = canonicalize(Tunable(), p) - t + # When diff_tunables=Val(false), dp_full is the full parameter + # gradient (SciMLStructure). For Enzyme, return it directly so + # the reverse rule can accumulate into all shadow components. + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} + dp_tangent = if _use_full_p && + originator isa SciMLBase.EnzymeOriginator && + isscimlstructure(dp_full) + Zygote.accum(dp_full, igs) + else + dp = Zygote.accum(dp_full, igs) + + _, + repack_adjoint = if p === nothing || p === SciMLBase.NullParameters() + nothing, x -> (x,) + elseif isscimlstructure(p) + Zygote.pullback(p) do p + t, _, _ = canonicalize(Tunable(), p) + t + end + elseif isfunctor(p) && supports_functor_params(sensealg) + Zygote.pullback(p) do p + t, _ = Functors.functor(p) + t + end + else + nothing, x -> (x,) end - elseif isfunctor(p) && supports_functor_params(sensealg) - Zygote.pullback(p) do p - t, _ = Functors.functor(p) - t + + if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p) + dp + else + repack_adjoint(dp)[1] end - else - nothing, x -> (x,) end return if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator ( - NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(), + NoTangent(), NoTangent(), du0, dp_tangent, NoTangent(), ntuple(_ -> NoTangent(), length(args))..., ) else ( NoTangent(), NoTangent(), NoTangent(), - du0, repack_adjoint(dp)[1], NoTangent(), + du0, dp_tangent, NoTangent(), ntuple(_ -> NoTangent(), length(args))..., ) end @@ -2368,58 +2412,61 @@ function SciMLBase._concrete_solve_adjoint( end end - dp = adjoint_sensitivities(sol, alg; sensealg, dgdu = df) + dp_full = adjoint_sensitivities(sol, alg; sensealg, dgdu = df) - dp, - Δtunables = if Δ isa AbstractArray || Δ isa Number - # if Δ isa AbstractArray, the gradients correspond to `u` - # this is something that needs changing in the future, but - # this is the applicable till the movement to structuaral - # tangents is completed - dp, Δtunables = if isscimlstructure(dp) - dp, _, _ = canonicalize(Tunable(), dp) - dp, nothing - elseif isfunctor(dp) - dp, _ = Functors.functor(dp) - dp, nothing + # When diff_tunables=Val(false), dp_full is the full parameter + # gradient (SciMLStructure). For Enzyme, return it directly so + # the reverse rule can accumulate into all shadow components + # (including caches for SCCNonlinearProblem). + dp_tangent = if originator isa SciMLBase.EnzymeOriginator && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(dp_full) + dp_full + else + dp = if isscimlstructure(dp_full) + canonicalize(Tunable(), dp_full)[1] + elseif isfunctor(dp_full) + Functors.functor(dp_full)[1] else - dp, nothing + dp_full end - else - dp, Δtunables = if isscimlstructure(p) - if (Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent()) - dp, _, _ = canonicalize(Tunable(), dp) - dp, nothing + + Δtunables = if !(Δ isa AbstractArray || Δ isa Number) + if isscimlstructure(p) && + !(Δ.prob.p == ZeroTangent() || Δ.prob.p == NoTangent()) + Δp = setproperties(dp_full, to_nt(Δ.prob.p)) + canonicalize(Tunable(), Δp)[1] + elseif isfunctor(p) + Functors.functor(Δ.prob.p)[1] else - Δp = setproperties(dp, to_nt(Δ.prob.p)) - Δtunables, _, _ = canonicalize(Tunable(), Δp) - dp, _, _ = canonicalize(Tunable(), dp) - dp, Δtunables + nothing end - elseif isfunctor(p) - dp, _ = Functors.functor(dp) - Δtunables, _ = Functors.functor(Δ.prob.p) - dp, Δtunables else - dp, Δ.prob.p + nothing end - end - dp = Zygote.accum( - dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : - Δtunables - ) + dp = Zygote.accum( + dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : + Δtunables + ) + + if originator isa SciMLBase.EnzymeOriginator && isscimlstructure(p) + dp + else + repack_adjoint(dp)[1] + end + end return if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator ( - NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(), + NoTangent(), NoTangent(), NoTangent(), dp_tangent, NoTangent(), ntuple(_ -> NoTangent(), length(args))..., ) else ( NoTangent(), NoTangent(), NoTangent(), - NoTangent(), repack_adjoint(dp)[1], NoTangent(), + NoTangent(), dp_tangent, NoTangent(), ntuple(_ -> NoTangent(), length(args))..., ) end diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index b5f485253..ed2fc484a 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -93,10 +93,14 @@ function ODEGaussAdjointSensitivityFunction( else nothing end + p = sol.prob.p + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) diffcache, y = adjointdiffcache( g, sensealg, discrete, sol, dgdu, dgdp, f, alg; - quad = true + quad = true, use_full_p = _use_full_p ) return ODEGaussAdjointSensitivityFunction( diffcache, sensealg, discrete, @@ -457,7 +461,12 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) u0 = state_values(prob) p = parameter_values(prob) - if p === nothing || p isa SciMLBase.NullParameters + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) + if _use_full_p + tunables, repack = p, identity + elseif p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) @@ -697,7 +706,12 @@ function _adjoint_sensitivities( throw(SciMLStructuresCompatibilityError()) end - if p === nothing || p isa SciMLBase.NullParameters + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) + if _use_full_p + tunables, repack = p, identity + elseif p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index d1cc04c89..4b111422c 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -16,10 +16,14 @@ function ODEQuadratureAdjointSensitivityFunction( g, sensealg, discrete, sol, dgdu, dgdp, alg ) + p = sol.prob.p + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) diffcache, y = adjointdiffcache( g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; - quad = true + quad = true, use_full_p = _use_full_p ) return ODEQuadratureAdjointSensitivityFunction( diffcache, sensealg, discrete, @@ -236,7 +240,12 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) p = parameter_values(prob) u0 = state_values(prob) - if isscimlstructure(p) && !(p isa AbstractArray) + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) + if _use_full_p + tunables, repack = p, identity + elseif isscimlstructure(p) && !(p isa AbstractArray) tunables, repack, _ = canonicalize(Tunable(), p) else tunables, repack = p, identity @@ -427,8 +436,10 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) ) end - if _shadow_enzyme !== nothing - if isscimlstructure(_shadow_enzyme) + if _shadow_enzyme !== nothing && _shadow_enzyme !== out + _use_full_p_enzyme = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} + if !_use_full_p_enzyme && isscimlstructure(_shadow_enzyme) grad_tunables, _, _ = canonicalize(Tunable(), _shadow_enzyme) else grad_tunables = _shadow_enzyme @@ -595,7 +606,12 @@ function update_p_integrand(integrand::AdjointSensitivityIntegrand, p) sol, adj_sol, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp, ) = integrand - if isscimlstructure(p) && !(p isa AbstractArray) + _use_full_p = hasproperty(sensealg, :diff_tunables) && + sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) + if _use_full_p + tunables, repack = p, identity + elseif isscimlstructure(p) && !(p isa AbstractArray) tunables, repack, _ = canonicalize(Tunable(), p) else tunables, repack = p, identity diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index 0505df497..b115b5552 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -483,28 +483,29 @@ arXiv:1812.01892 Kim, S., Ji, W., Deng, S., Ma, Y., & Rackauckas, C. (2021). Stiff neural ordinary differential equations. Chaos: An Interdisciplinary Journal of Nonlinear Science, 31(9), 093122. """ -struct QuadratureAdjoint{CS, AD, FDT, VJP} <: +struct QuadratureAdjoint{CS, AD, FDT, VJP, DT <: Val} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP abstol::Float64 reltol::Float64 + diff_tunables::DT end Base.@pure function QuadratureAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, abstol = 1.0e-6, - reltol = 1.0e-3 + reltol = 1.0e-3, diff_tunables = Val(true) ) - QuadratureAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}( + QuadratureAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec), typeof(diff_tunables)}( autojacvec, - abstol, reltol + abstol, reltol, diff_tunables ) end -function setvjp(sensealg::QuadratureAdjoint{CS, AD, FDT}, vjp) where {CS, AD, FDT} - return QuadratureAdjoint{CS, AD, FDT, typeof(vjp)}( +function setvjp(sensealg::QuadratureAdjoint{CS, AD, FDT, VJP, DT}, vjp) where {CS, AD, FDT, VJP, DT} + return QuadratureAdjoint{CS, AD, FDT, typeof(vjp), DT}( vjp, sensealg.abstol, - sensealg.reltol + sensealg.reltol, sensealg.diff_tunables ) end @@ -587,24 +588,26 @@ arXiv:1812.01892 Kim, S., Ji, W., Deng, S., Ma, Y., & Rackauckas, C. (2021). Stiff neural ordinary differential equations. Chaos: An Interdisciplinary Journal of Nonlinear Science, 31(9), 093122. """ -struct GaussAdjoint{CS, AD, FDT, VJP} <: +struct GaussAdjoint{CS, AD, FDT, VJP, DT <: Val} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP checkpointing::Bool + diff_tunables::DT end Base.@pure function GaussAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, - checkpointing = false + checkpointing = false, + diff_tunables = Val(true) ) - GaussAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec)}( - autojacvec, checkpointing + GaussAdjoint{chunk_size, autodiff, diff_type, typeof(autojacvec), typeof(diff_tunables)}( + autojacvec, checkpointing, diff_tunables ) end -function setvjp(sensealg::GaussAdjoint{CS, AD, FDT}, vjp) where {CS, AD, FDT} - return GaussAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing) +function setvjp(sensealg::GaussAdjoint{CS, AD, FDT, VJP, DT}, vjp) where {CS, AD, FDT, VJP, DT} + return GaussAdjoint{CS, AD, FDT, typeof(vjp), DT}(vjp, sensealg.checkpointing, sensealg.diff_tunables) end """ @@ -1293,30 +1296,42 @@ documentation page or the docstrings of the vjp types. Johnson, S. G., Notes on Adjoint Methods for 18.336, Online at http://math.mit.edu/stevenj/18.336/adjoint.pdf (2007) """ -struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK} <: +struct SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK, DT <: Val} <: AbstractAdjointSensitivityAlgorithm{CS, AD, FDT} autojacvec::VJP linsolve::LS linsolve_kwargs::LK + diff_tunables::DT end +""" + SteadyStateAdjoint(; autojacvec=nothing, linsolve=nothing, diff_tunables=Val(true), ...) + +When `diff_tunables = Val(true)` (default), the parameter VJP is computed +w.r.t. the tunable portion of `p` only. When `diff_tunables = Val(false)`, +the VJP is computed w.r.t. the full parameter object (including caches, +initials, etc.). This is needed for SCCNonlinearProblem where `explicitfuns!` +write active data into non-tunable components. Requires an `autojacvec` +backend that supports structured parameters (ZygoteVJP, EnzymeVJP, +MooncakeVJP, ReactantVJP). +""" Base.@pure function SteadyStateAdjoint(; chunk_size = 0, autodiff = true, diff_type = Val{:central}, autojacvec = nothing, linsolve = nothing, - linsolve_kwargs = (;) + linsolve_kwargs = (;), diff_tunables = Val(true) ) return SteadyStateAdjoint{ chunk_size, autodiff, diff_type, typeof(autojacvec), - typeof(linsolve), typeof(linsolve_kwargs), - }(autojacvec, linsolve, linsolve_kwargs) + typeof(linsolve), typeof(linsolve_kwargs), typeof(diff_tunables), + }(autojacvec, linsolve, linsolve_kwargs, diff_tunables) end function setvjp( - sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK}, + sensealg::SteadyStateAdjoint{CS, AD, FDT, VJP, LS, LK, DT}, vjp - ) where {CS, AD, FDT, VJP, LS, LK} - return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK}( + ) where {CS, AD, FDT, VJP, LS, LK, DT} + return SteadyStateAdjoint{CS, AD, FDT, typeof(vjp), LS, LK, DT}( vjp, sensealg.linsolve, - sensealg.linsolve_kwargs + sensealg.linsolve_kwargs, sensealg.diff_tunables ) end diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index b636b7e89..dd2b8805f 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -20,10 +20,14 @@ function SteadyStateAdjointSensitivityFunction( ) (; p, u0) = sol.prob + # When diff_tunables = Val(false), use the full parameter object so + # the VJP includes gradients w.r.t. all parameter components. + _use_full_p = sensealg.diff_tunables isa Val{false} && + isscimlstructure(p) && !(p isa AbstractArray) diffcache, y = adjointdiffcache( g, sensealg, false, sol, dgdu, dgdp, f, alg; - quad = false, needs_jac + quad = false, needs_jac, use_full_p = _use_full_p ) λ = zero(y) @@ -169,6 +173,7 @@ end end end + if g !== nothing || dgdp !== nothing # compute del g/del p if dgdp !== nothing diff --git a/test/desauty_dae_mwe.jl b/test/desauty_dae_mwe.jl index 350f8ab85..6693f8017 100644 --- a/test/desauty_dae_mwe.jl +++ b/test/desauty_dae_mwe.jl @@ -95,6 +95,9 @@ eqs = [ @testset "ForwardDiff through init" begin if use_scc + # Broken: SCCNonlinearProblem solver doesn't support ForwardDiff.Dual numbers. + # Error: MethodError: no method matching Float64(::ForwardDiff.Dual{...}) + # in the explicit parameter propagation between sub-problem solves. @test_broken begin fwd_init = ForwardDiff.gradient(init_loss, itunables) isapprox(fwd_init, fd_init_grad, rtol = 0.05) @@ -106,6 +109,9 @@ eqs = [ end @testset "ForwardDiff through ODE solve" begin + # Broken: ForwardDiff through full ODE solve with DAE initialization + # fails for both use_scc cases due to type promotion issues in + # the initialization path. @test_broken begin fwd_grad = ForwardDiff.gradient(loss, tunables) isapprox(fwd_grad, fd_grad, rtol = 0.05) @@ -113,6 +119,13 @@ eqs = [ end @testset "Enzyme through init" begin + # Broken due to multiple upstream Enzyme/NonlinearSolve issues: + # - Julia 1.12: NonlinearSolveBaseEnzymeExt rules disabled (VERSION < v"1.12" guard) + # causing LLVM crash in GC invariant verifier + # - Julia 1.10: EnzymeMutabilityException from MTK's remake (mutable closure), + # MixedReturnException with default PolyAlgorithm, and NamedTuple broadcasting + # error in NonlinearSolveBaseEnzymeExt reverse rule with MTKParameters + # See: NonlinearSolve.jl#869, Enzyme.jl#2699 @test_broken begin igs = Enzyme.gradient(Enzyme.Reverse, init_loss, itunables) !iszero(sum(igs)) @@ -120,16 +133,28 @@ eqs = [ end @testset "Mooncake through init" begin - @test_broken begin + if use_scc + @test_broken begin + rule = Mooncake.build_rrule(init_loss, itunables) + _, (_, igs) = Mooncake.value_and_gradient!!( + rule, init_loss, itunables, + ) + !iszero(sum(igs)) + end + else rule = Mooncake.build_rrule(init_loss, itunables) _, (_, igs) = Mooncake.value_and_gradient!!( rule, init_loss, itunables, ) - !iszero(sum(igs)) + @test !iszero(sum(igs)) + @test isapprox(igs, fd_init_grad, rtol = 0.05) end end @testset "Tracker + GaussAdjoint through ODE solve" begin + # Broken: MTK's GetUpdatedU0 can't handle TrackedReal{Float64} in remake path. + # Error: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64}) + # in copyto_unaliased! when updating u0 from parameter initials. sensealg = SciMLSensitivity.GaussAdjoint( autojacvec = SciMLSensitivity.EnzymeVJP(), ) diff --git a/test/mtk.jl b/test/mtk.jl index 53a56be46..8fbd6e53c 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -159,6 +159,10 @@ setups = [ # Reverse-mode AD through DAE initialization with SCCNonlinearProblem mutation. # Marked as broken until Enzyme/Mooncake fully support this pattern. +# Enzyme blockers (see NonlinearSolve.jl#869, issue #1358): +# - Julia 1.12: LLVM crash (Enzyme rules disabled by VERSION < v"1.12" guard) +# - Julia 1.10: EnzymeMutabilityException in remake, MixedReturnException with +# default PolyAlgorithm, NamedTuple broadcast error with MTKParameters @test_broken begin grads = map(setups) do setup prob, tunables, repack, init = setup diff --git a/test/runtests.jl b/test/runtests.jl index 664a113f6..4df30d35f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -112,6 +112,7 @@ end @time @safetestset "Adjoints through NonlinearProblem" include("parameter_initialization.jl") @time @safetestset "Initialization with MTK" include("desauty_dae_mwe.jl") @time @safetestset "MTK Forward Mode" include("mtk.jl") + @time @safetestset "SCCNonlinearProblem" include("scc_nonlinearsolve.jl") end end diff --git a/test/scc_nonlinearsolve.jl b/test/scc_nonlinearsolve.jl new file mode 100644 index 000000000..1cc8badd6 --- /dev/null +++ b/test/scc_nonlinearsolve.jl @@ -0,0 +1,90 @@ +using Test +using NonlinearSolve, SCCNonlinearSolve +using SciMLSensitivity +using Enzyme +using Mooncake +using ForwardDiff +using FiniteDiff +import SciMLStructures as SS + +# Two-component SCC problem with parameter coupling through explicitfuns! +# +# Component 1: u1^2 - p[1] = 0 → u1 = sqrt(p[1]) +# Component 2: u2 - cache * p[2] = 0 → u2 = sqrt(p[1]) * p[2] +# where cache is set to sol1[1] by explicitfun2! +# +# loss = u1 + u2 = sqrt(p[1]) + sqrt(p[1]) * p[2] +# dloss/dp[1] = (1 + p[2]) / (2*sqrt(p[1])) +# dloss/dp[2] = sqrt(p[1]) + +function f1(du, u, p) + return du[1] = u[1]^2 - p[1] +end +explicitfun1!(p, sols) = nothing + +function f2(du, u, p) + return du[1] = u[1] - p[1] * p[2] +end +function explicitfun2!(p, sols) + p[1] = sols[1].u[1] + return nothing +end + +function make_scc(p_val) + p1 = copy(p_val) + p2 = copy(p_val) + prob1 = NonlinearProblem(f1, [1.0], p1) + prob2 = NonlinearProblem(f2, [1.0], p2) + # Use Tuple (not Vector) for sub-problems and explicitfuns so that + # each element has a concrete type. Enzyme requires concrete types + # to specialize through the SCC dispatch chain. + return SciMLBase.SCCNonlinearProblem( + (prob1, prob2), (explicitfun1!, explicitfun2!), + ) +end + +alg = SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson()) + +function loss(p_val) + sccprob = make_scc(p_val) + sol = solve(sccprob, alg) + return sum(sol.u) +end + +p_test = [4.0, 3.0] + +@testset "SCCNonlinearProblem differentiation" begin + # Forward solve + sol = solve(make_scc(p_test), alg) + @test SciMLBase.successful_retcode(sol) + @test sol.u[1] ≈ 2.0 atol = 1.0e-10 + @test sol.u[2] ≈ 6.0 atol = 1.0e-10 + + # FiniteDiff ground truth + fd = FiniteDiff.finite_difference_gradient(loss, p_test) + @test fd[1] ≈ (1 + 3) / (2 * sqrt(4)) atol = 1.0e-6 + @test fd[2] ≈ sqrt(4) atol = 1.0e-6 + + @testset "ForwardDiff" begin + fwd = ForwardDiff.gradient(loss, p_test) + @test isapprox(fwd, fd, rtol = 0.05) + end + + # Enzyme test skipped: Enzyme produces correct gradients with Tuple-based + # SCCNonlinearProblem (verified manually: [1.0, 2.0] matches FiniteDiff), + # but intermittently segfaults due to GC corruption on Julia 1.10, + # crashing the test process. Vector-based SCC fails because heterogeneous + # function types get erased to Any. + # See Enzyme.jl#3021. + @testset "Enzyme" begin + @test_skip true + end + + @testset "Mooncake" begin + rule = Mooncake.build_rrule(loss, copy(p_test)) + _, (_, dp_mc) = Mooncake.value_and_gradient!!( + rule, loss, copy(p_test), + ) + @test isapprox(collect(dp_mc), fd, rtol = 0.05) + end +end