From d20f2136b1c7eb114caae3c2d438da106f6d2aa4 Mon Sep 17 00:00:00 2001 From: Arn97 Date: Fri, 6 Mar 2026 20:06:30 -0500 Subject: [PATCH] Second attempt at issue #995 --- src/gauss_adjoint.jl | 48 +++++++++++++++++++++++++++++++++----- src/parameters_handling.jl | 16 +++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 02a79b1e6..4fe20702f 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -457,8 +457,17 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) u0 = state_values(prob) p = parameter_values(prob) - if p === nothing || p isa SciMLBase.NullParameters - tunables, repack = p, identity + # if p === nothing || p isa SciMLBase.NullParameters + # tunables, repack = p, identity + # ---------------------------------------------- + # fix 0: + if p === nothing + tunables = SciMLBase.NullParameters() + repack = _ -> nothing + elseif p isa SciMLBase.NullParameters + tunables = p + repack = identity + # ---------------------------------------------- elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) elseif isfunctor(p) @@ -477,7 +486,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) ) end - numparams = length(tunables) + #-------------------------------------------- + # fix 1: + numparams = (tunables === nothing || tunables isa SciMLBase.NullParameters) ? 0 : length(tunables) + #-------------------------------------------- + # numparams = length(tunables) y = zero(state_values(prob)) λ = zero(state_values(prob)) # we need to alias `y` @@ -558,6 +571,12 @@ end # out = λ df(u, p, t)/dp at u=y, p=p, t=t function vec_pjac!(out, λ, y, t, S::GaussIntegrand) + #--------------------------------------------- + # fix 5: + if S.tunables isa SciMLBase.NullParameters + return out # should already be length-0 / zeros + end + #--------------------------------------------- (; pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, tunables, repack) = S _odef = sol.prob.f f = unwrapped_f(_odef) @@ -695,8 +714,17 @@ function _adjoint_sensitivities( throw(SciMLStructuresCompatibilityError()) end - if p === nothing || p isa SciMLBase.NullParameters - tunables, repack = p, identity + # if p === nothing || p isa SciMLBase.NullParameters + # tunables, repack = p, identity + # ---------------------------------------------- + # fix 0: + if p === nothing + tunables = SciMLBase.NullParameters() + repack = _ -> nothing + elseif p isa SciMLBase.NullParameters + tunables = p + repack = identity + # ---------------------------------------------- elseif isscimlstructure(p) tunables, repack, _ = canonicalize(Tunable(), p) elseif isfunctor(p) @@ -714,15 +742,23 @@ function _adjoint_sensitivities( integrand = GaussIntegrand(sol, sensealg, checkpoints, dgdp_continuous) integrand_values = IntegrandValuesSum(allocate_zeros(tunables)) if sensealg isa GaussAdjoint + # cb = IntegratingSumCallback( + # (out, u, t, integrator) -> integrand(out, t, u), + # integrand_values, allocate_vjp(tunables) + # ) + #---------------------------------------------- + # fix 4: cb = IntegratingSumCallback( - (out, u, t, integrator) -> integrand(out, t, u), + (out, t, integrator) -> integrand(out, t, integrator.u), integrand_values, allocate_vjp(tunables) ) + #---------------------------------------------- elseif sensealg isa GaussKronrodAdjoint cb = IntegratingGKSumCallback( (out, u, t, integrator) -> integrand(out, t, u), integrand_values, allocate_vjp(tunables) ) + end rcb = nothing cb2 = nothing diff --git a/src/parameters_handling.jl b/src/parameters_handling.jl index b59d80e90..7ccbe5332 100644 --- a/src/parameters_handling.jl +++ b/src/parameters_handling.jl @@ -73,6 +73,13 @@ function allocate_vjp(λ::AbstractArray, x::NamedTuple{F}) where {F} end allocate_vjp(λ::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x) +# --------------------------------------------- +# fix 3: make allocate_vjp safe on "no params" +allocate_vjp(x::Nothing) = nothing +allocate_vjp(x::SciMLBase.NullParameters) = nothing +allocate_vjp(::AbstractArray, ::Nothing) = nothing +allocate_vjp(::AbstractArray, ::SciMLBase.NullParameters) = nothing +# --------------------------------------------- allocate_vjp(x::AbstractArray) = zero(x) # similar(x) allocate_vjp(x::Tuple) = allocate_vjp.(x) allocate_vjp(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp.(values(x))) @@ -83,6 +90,11 @@ allocate_vjp(x) = fmap(allocate_vjp, x) `zero.(x)` for generic `x`. This is used to handle non-array parameters! """ +#--------------------------------------------- +# fix 2: +allocate_zeros(::Nothing) = nothing +allocate_zeros(::SciMLBase.NullParameters) = SciMLBase.NullParameters() +#--------------------------------------------- allocate_zeros(x::AbstractArray) = zero.(x) allocate_zeros(x::Tuple) = allocate_zeros.(x) allocate_zeros(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros.(values(x))) @@ -93,6 +105,10 @@ allocate_zeros(x) = fmap(allocate_zeros, x) `adjoint(y)` for generic `y`. This is used to handle non-array parameters! """ +#--------------------------------------------- +# fix 3: make recursive_adjoint safe on "no params" +recursive_adjoint(::Nothing) = nothing +#--------------------------------------------- recursive_adjoint(y::AbstractArray) = adjoint(y) recursive_adjoint(y::Tuple) = recursive_adjoint.(y) recursive_adjoint(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_adjoint.(values(y)))