diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 30d3ed79c..ff9629f2e 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,12 +1,62 @@ module SciMLBaseChainRulesCoreExt using SciMLBase -using SciMLBase: getobserved +using SciMLBase: getobserved, ODEProblem, _remake_ode_inner import ChainRulesCore -import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad +import ChainRulesCore: NoTangent, ZeroTangent, AbstractZero, Tangent, + @non_differentiable, zero_tangent, rrule_via_ad, backing using SymbolicIndexingInterface using RecursiveArrayTools: AbstractVectorOfArray +@inline function _remake_ode_inner_split_cotangent(Δ, f, u0, tspan, p) + Δ_nt = if Δ isa Tangent + b = backing(Δ) + b isa NamedTuple ? b : + (b isa Tuple && length(b) == 1 && b[1] isa NamedTuple ? b[1] : NamedTuple()) + elseif Δ isa NamedTuple + Δ + else + NamedTuple() + end + get_cot(field::Symbol) = (Δ_nt isa NamedTuple && haskey(Δ_nt, field)) ? + Δ_nt[field] : nothing + f_cot = (f === missing) ? NoTangent() : get_cot(:f) + u0_cot = (u0 === missing) ? NoTangent() : get_cot(:u0) + tspan_cot = (tspan === missing) ? NoTangent() : get_cot(:tspan) + p_cot = (p === missing) ? NoTangent() : get_cot(:p) + prob_cot = ( + f = (f === missing) ? get_cot(:f) : nothing, + u0 = (u0 === missing) ? get_cot(:u0) : nothing, + tspan = (tspan === missing) ? get_cot(:tspan) : nothing, + p = (p === missing) ? get_cot(:p) : nothing, + kwargs = nothing, + problem_type = nothing, + ) + return prob_cot, f_cot, u0_cot, tspan_cot, p_cot +end + +function ChainRulesCore.rrule( + ::typeof(_remake_ode_inner), + prob::ODEProblem, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) + new_prob = _remake_ode_inner( + prob, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) + function _remake_ode_inner_pullback(Δ) + prob_cot, f_cot, u0_cot, tspan_cot, p_cot = _remake_ode_inner_split_cotangent( + Δ, f, u0, tspan, p + ) + return (NoTangent(), prob_cot, f_cot, u0_cot, tspan_cot, p_cot, + NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), + NoTangent()) + end + return new_prob, _remake_ode_inner_pullback +end + @non_differentiable SciMLBase.checkkwargs(kwargshandle) # numargs and isinplace use `methods()` for runtime reflection and are not differentiable. @@ -163,51 +213,16 @@ function ChainRulesCore.rrule( RODESolutionAdjoint end -# EnsembleSolution rrule with full support for various gradient types -# Matches the Zygote extension implementation for consistency function ChainRulesCore.rrule( ::Type{EnsembleSolution}, sim, time, converged, stats = nothing ) out = EnsembleSolution(sim, time, converged, stats) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [ - [ - p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1] - ] for i in 1:size(p̄)[end] - ] - return ( - NoTangent(), - EnsembleSolution(arrarr, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) - end - function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - return ( - NoTangent(), - EnsembleSolution(p̄, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) - end - function EnsembleSolution_adjoint(p̄::AbstractVectorOfArray) - return ( - NoTangent(), - EnsembleSolution(p̄, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent()) - end function EnsembleSolution_adjoint(p̄::NamedTuple) return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) end + function EnsembleSolution_adjoint(p̄::Tangent) + return (NoTangent(), backing(p̄).u, NoTangent(), NoTangent(), NoTangent()) + end return out, EnsembleSolution_adjoint end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index b61678678..5f3939fb1 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -28,24 +28,6 @@ end @adjoint function EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged, stats) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [ - [ - p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1] - ] for i in 1:size(p̄)[end] - ] - (EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::RecursiveArrayTools.AbstractVectorOfArray) - (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (p̄, nothing, nothing, nothing) - end function EnsembleSolution_adjoint(p̄::NamedTuple) (p̄.u, nothing, nothing, nothing) end @@ -190,24 +172,6 @@ end VA[sym], NonlinearSolution_getindex_pullback end -@adjoint function ODESolution{ - T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, - }( - u, - args... - ) where { - T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12, T13, T14, T15, - } - function ODESolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) - end - - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}( - u, args... - ), - ODESolutionAdjoint -end @adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}( u, diff --git a/src/remake.jl b/src/remake.jl index acd722909..8f88ebb0d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -335,6 +335,18 @@ function remake( lazy_initialization = nothing, _kwargs... ) + return _remake_ode_inner( + prob, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, values(_kwargs) + ) +end + +function _remake_ode_inner( + prob::ODEProblem, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) if tspan === missing tspan = prob.tspan end @@ -382,9 +394,11 @@ function remake( ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -600,9 +614,11 @@ function remake( SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -667,9 +683,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -764,9 +782,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -826,9 +846,11 @@ function remake( DAEProblem{iip}(f, du0, newu0, tspan, newp; differential_vars, kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -953,9 +975,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -1000,9 +1024,11 @@ function remake( SteadyStateProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp; kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -1048,9 +1074,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end diff --git a/test/downstream/ensemble_remake_reverse_mode_adjoints.jl b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl new file mode 100644 index 000000000..193e9536d --- /dev/null +++ b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl @@ -0,0 +1,85 @@ +using SciMLBase, OrdinaryDiffEq, Test +using SciMLBase: EnsembleProblem, EnsembleSerial, EnsembleSolution +using Zygote, ForwardDiff +import ChainRulesCore + +@testset "EnsembleSolution constructor pulls NamedTuple cotangent" begin + f(u, p, t) = -u + prob = ODEProblem(f, [1.0], (0.0, 1.0)) + sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3] + arrarr = [[copy(s.u[j]) for j in eachindex(s.u)] for s in sols] + + _, back = Zygote.pullback(EnsembleSolution, sols, 0.0, true, nothing) + sim_cot, t_cot, c_cot, s_cot = back((u = arrarr,)) + @test sim_cot == arrarr + @test t_cot === nothing && c_cot === nothing && s_cot === nothing + + _, pb = ChainRulesCore.rrule(EnsembleSolution, sols, 0.0, true, nothing) + cot = pb((u = arrarr,)) + @test cot[2] == arrarr + @test cot[1] === ChainRulesCore.NoTangent() + @test all(cot[i] === ChainRulesCore.NoTangent() for i in 3:5) + + cot_t = pb(ChainRulesCore.Tangent{Any}(; u = arrarr)) + @test cot_t[2] == arrarr +end + +@testset "remake(::ODEProblem; u0) gradient parity" begin + f(u, p, t) = u + base_prob = ODEProblem(f, [0.0, 0.0], (0.0, 1.0), [1.0]) + loss(p) = (q = remake(base_prob, u0 = [p[1] * 2, p[1] + 5]); sum(abs2, q.u0)) + p0 = [3.0] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "remake(::ODEProblem; p) gradient parity" begin + f(u, p, t) = p[1] * u + base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [0.5]) + loss(p) = (q = remake(base_prob, p = [p[1] * 3]); sum(abs2, q.p)) + p0 = [2.0] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "remake field-pass-through gradient parity" begin + f(u, p, t) = u + base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.0]) + loss(p) = (q = remake(base_prob, u0 = [p[1]]); sum(abs2, q.u0)) + p0 = [2.5] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "_remake_ode_inner rrule cotangent distribution" begin + base_prob = ODEProblem((u, p, t) -> u, [1.0, 2.0], (0.0, 1.0), [3.0]) + Δ_u0 = [10.0, 20.0] + Δ = ChainRulesCore.Tangent{Any}(; + f = ChainRulesCore.NoTangent(), + u0 = Δ_u0, + tspan = ChainRulesCore.NoTangent(), + p = ChainRulesCore.NoTangent(), + kwargs = ChainRulesCore.NoTangent(), + problem_type = ChainRulesCore.NoTangent(), + ) + + # u0 supplied → cotangent flows to the u0 positional. + _, pb = ChainRulesCore.rrule( + SciMLBase._remake_ode_inner, + base_prob, missing, [9.9, 8.8], missing, missing, missing, + true, Val{true}, false, nothing, NamedTuple() + ) + cot = pb(Δ) + @test length(cot) == 12 + @test cot[1] === ChainRulesCore.NoTangent() + @test cot[4] == Δ_u0 + @test cot[2].u0 === nothing + @test all(cot[i] === ChainRulesCore.NoTangent() for i in 7:12) + + # u0 not supplied → cotangent accumulates onto prob.u0. + _, pb2 = ChainRulesCore.rrule( + SciMLBase._remake_ode_inner, + base_prob, missing, missing, missing, [99.0], missing, + true, Val{true}, false, nothing, NamedTuple() + ) + cot2 = pb2(Δ) + @test cot2[4] === ChainRulesCore.NoTangent() + @test cot2[2].u0 == Δ_u0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 72a5867d1..98bf0ff5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,9 @@ end @time @safetestset "Autodiff Remake" begin include("downstream/remake_autodiff.jl") end + @time @safetestset "Ensemble + remake reverse-mode adjoints" begin + include("downstream/ensemble_remake_reverse_mode_adjoints.jl") + end @time @safetestset "Partial Functions" begin include("downstream/partial_functions.jl") end