From b1a3308640add835ccc77f04318bdef7afdd7cd2 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Fri, 8 May 2026 14:24:52 -0400 Subject: [PATCH 1/2] fix(ad): EnsembleSolution cotangent + remake reverse-mode plumbing - EnsembleSolution constructor adjoint/rrule: return cotangent for `sim` as a plain Vector matching the constructor argument's shape, not re-wrapped in another EnsembleSolution. Add `Tangent` overload for cotangents arriving from the upstream `Array(::AbstractVectorOfArray)` adjoint. - Delete the stale `@adjoint ODESolution{T1...T15}` (15 type params, ODESolution has 16 since `saved_subsystem` was added; the matching ChainRulesCore rrule for the 16-param case already covers it). - `remake(::ODEProblem; ...)` and 7 sibling overloads: guard the `maybe_eager_initialize_problem` + `@reset prob.u0/p` block on `initialization_data !== nothing` so the no-init-data path stays at a single ODEProblem construction. - Lower `remake(prob::ODEProblem; ...)` into a positional `_remake_ode_inner(prob, f, u0, tspan, p, kwargs, ...)` helper and attach the ChainRulesCore.rrule to the helper. Reverse-mode AD now flows kwarg cotangents naturally without going through Zygote's kwarg-cotangent plumbing. Cotangent distribution: if a kwarg was passed (not `missing`), its cotangent goes to that positional; otherwise it accumulates onto the corresponding `prob` field. Tests in `test/downstream/ensemble_remake_reverse_mode_adjoints.jl`: - Direct cotangent-shape assertions on the EnsembleSolution constructor adjoint (Zygote and ChainRulesCore paths). - Zygote-vs-ForwardDiff parity for `remake(::ODEProblem; u0)`, `remake(::ODEProblem; p)`, and field-pass-through. - Direct shape assertions on the `_remake_ode_inner` rrule's 12-tuple return covering both kwarg-supplied and prob-fall-through routing. Co-Authored-By: Chris Rackauckas --- ext/SciMLBaseChainRulesCoreExt.jl | 85 +++++++++++------ ext/SciMLBaseZygoteExt.jl | 26 +----- src/remake.jl | 76 ++++++++++----- .../ensemble_remake_reverse_mode_adjoints.jl | 93 +++++++++++++++++++ test/runtests.jl | 3 + 5 files changed, 211 insertions(+), 72 deletions(-) create mode 100644 test/downstream/ensemble_remake_reverse_mode_adjoints.jl diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 30d3ed79c..920924994 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,12 +1,62 @@ module SciMLBaseChainRulesCoreExt using SciMLBase -using SciMLBase: getobserved +using SciMLBase: getobserved, ODEProblem, _remake_ode_inner import ChainRulesCore -import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad +import ChainRulesCore: NoTangent, ZeroTangent, AbstractZero, Tangent, + @non_differentiable, zero_tangent, rrule_via_ad, backing using SymbolicIndexingInterface using RecursiveArrayTools: AbstractVectorOfArray +@inline function _remake_ode_inner_split_cotangent(Δ, f, u0, tspan, p) + Δ_nt = if Δ isa Tangent + b = backing(Δ) + b isa NamedTuple ? b : + (b isa Tuple && length(b) == 1 && b[1] isa NamedTuple ? b[1] : NamedTuple()) + elseif Δ isa NamedTuple + Δ + else + NamedTuple() + end + get_cot(field::Symbol) = (Δ_nt isa NamedTuple && haskey(Δ_nt, field)) ? + Δ_nt[field] : nothing + f_cot = (f === missing) ? NoTangent() : get_cot(:f) + u0_cot = (u0 === missing) ? NoTangent() : get_cot(:u0) + tspan_cot = (tspan === missing) ? NoTangent() : get_cot(:tspan) + p_cot = (p === missing) ? NoTangent() : get_cot(:p) + prob_cot = ( + f = (f === missing) ? get_cot(:f) : nothing, + u0 = (u0 === missing) ? get_cot(:u0) : nothing, + tspan = (tspan === missing) ? get_cot(:tspan) : nothing, + p = (p === missing) ? get_cot(:p) : nothing, + kwargs = nothing, + problem_type = nothing, + ) + return prob_cot, f_cot, u0_cot, tspan_cot, p_cot +end + +function ChainRulesCore.rrule( + ::typeof(_remake_ode_inner), + prob::ODEProblem, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) + new_prob = _remake_ode_inner( + prob, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) + function _remake_ode_inner_pullback(Δ) + prob_cot, f_cot, u0_cot, tspan_cot, p_cot = _remake_ode_inner_split_cotangent( + Δ, f, u0, tspan, p + ) + return (NoTangent(), prob_cot, f_cot, u0_cot, tspan_cot, p_cot, + NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), + NoTangent()) + end + return new_prob, _remake_ode_inner_pullback +end + @non_differentiable SciMLBase.checkkwargs(kwargshandle) # numargs and isinplace use `methods()` for runtime reflection and are not differentiable. @@ -163,8 +213,6 @@ function ChainRulesCore.rrule( RODESolutionAdjoint end -# EnsembleSolution rrule with full support for various gradient types -# Matches the Zygote extension implementation for consistency function ChainRulesCore.rrule( ::Type{EnsembleSolution}, sim, time, converged, stats = nothing ) @@ -176,38 +224,23 @@ function ChainRulesCore.rrule( for j in 1:size(p̄)[end - 1] ] for i in 1:size(p̄)[end] ] - return ( - NoTangent(), - EnsembleSolution(arrarr, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) + return (NoTangent(), arrarr, NoTangent(), NoTangent(), NoTangent()) end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - return ( - NoTangent(), - EnsembleSolution(p̄, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) + return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent()) end function EnsembleSolution_adjoint(p̄::AbstractVectorOfArray) - return ( - NoTangent(), - EnsembleSolution(p̄, 0.0, true, stats), - NoTangent(), - NoTangent(), - NoTangent(), - ) + return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) end function EnsembleSolution_adjoint(p̄::EnsembleSolution) - return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent()) + return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) end function EnsembleSolution_adjoint(p̄::NamedTuple) return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) end + function EnsembleSolution_adjoint(p̄::Tangent) + return (NoTangent(), backing(p̄).u, NoTangent(), NoTangent(), NoTangent()) + end return out, EnsembleSolution_adjoint end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index b61678678..1bd1021fa 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -35,16 +35,16 @@ end for j in 1:size(p̄)[end - 1] ] for i in 1:size(p̄)[end] ] - (EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing) + (arrarr, nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) + (p̄, nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::RecursiveArrayTools.AbstractVectorOfArray) - (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) + (p̄.u, nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (p̄, nothing, nothing, nothing) + (p̄.u, nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::NamedTuple) (p̄.u, nothing, nothing, nothing) @@ -190,24 +190,6 @@ end VA[sym], NonlinearSolution_getindex_pullback end -@adjoint function ODESolution{ - T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, - }( - u, - args... - ) where { - T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12, T13, T14, T15, - } - function ODESolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) - end - - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}( - u, args... - ), - ODESolutionAdjoint -end @adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}( u, diff --git a/src/remake.jl b/src/remake.jl index acd722909..8f88ebb0d 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -335,6 +335,18 @@ function remake( lazy_initialization = nothing, _kwargs... ) + return _remake_ode_inner( + prob, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, values(_kwargs) + ) +end + +function _remake_ode_inner( + prob::ODEProblem, f, u0, tspan, p, kwargs, + interpret_symbolicmap, build_initializeprob, use_defaults, + lazy_initialization, _kwargs + ) if tspan === missing tspan = prob.tspan end @@ -382,9 +394,11 @@ function remake( ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -600,9 +614,11 @@ function remake( SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -667,9 +683,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -764,9 +782,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -826,9 +846,11 @@ function remake( DAEProblem{iip}(f, du0, newu0, tspan, newp; differential_vars, kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -953,9 +975,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -1000,9 +1024,11 @@ function remake( SteadyStateProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp; kwargs...) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end @@ -1048,9 +1074,11 @@ function remake( ) end - u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) - @reset prob.u0 = u0 - @reset prob.p = p + if initialization_data !== nothing + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p + end return prob end diff --git a/test/downstream/ensemble_remake_reverse_mode_adjoints.jl b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl new file mode 100644 index 000000000..82f037766 --- /dev/null +++ b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl @@ -0,0 +1,93 @@ +using SciMLBase, OrdinaryDiffEq, Test +using SciMLBase: EnsembleProblem, EnsembleSerial, EnsembleSolution +using Zygote, ForwardDiff +import ChainRulesCore + +@testset "EnsembleSolution constructor cotangent shape" begin + f(u, p, t) = -u + prob = ODEProblem(f, [1.0], (0.0, 1.0)) + sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3] + es = EnsembleSolution(sols, 0.0, true, nothing) + arr = Array(es) + @test ndims(arr) == 3 + + _, back = Zygote.pullback(EnsembleSolution, sols, 0.0, true, nothing) + sim_cot, = back(ones(eltype(arr), size(arr))) + + @test !(sim_cot isa EnsembleSolution) + @test length(sim_cot) == length(sols) + @test all(eltype(c) <: AbstractArray for c in sim_cot) +end + +@testset "EnsembleSolution rrule cotangent shape" begin + f(u, p, t) = -u + prob = ODEProblem(f, [1.0], (0.0, 1.0)) + sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3] + es = EnsembleSolution(sols, 0.0, true, nothing) + arr = Array(es) + + _, pb = ChainRulesCore.rrule(EnsembleSolution, sols, 0.0, true, nothing) + sim_cot = pb(ones(eltype(arr), size(arr)))[2] + @test !(sim_cot isa EnsembleSolution) + @test length(sim_cot) == length(sols) +end + +@testset "remake(::ODEProblem; u0) gradient parity" begin + f(u, p, t) = u + base_prob = ODEProblem(f, [0.0, 0.0], (0.0, 1.0), [1.0]) + loss(p) = (q = remake(base_prob, u0 = [p[1] * 2, p[1] + 5]); sum(abs2, q.u0)) + p0 = [3.0] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "remake(::ODEProblem; p) gradient parity" begin + f(u, p, t) = p[1] * u + base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [0.5]) + loss(p) = (q = remake(base_prob, p = [p[1] * 3]); sum(abs2, q.p)) + p0 = [2.0] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "remake field-pass-through gradient parity" begin + f(u, p, t) = u + base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.0]) + loss(p) = (q = remake(base_prob, u0 = [p[1]]); sum(abs2, q.u0)) + p0 = [2.5] + @test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6 +end + +@testset "_remake_ode_inner rrule cotangent distribution" begin + base_prob = ODEProblem((u, p, t) -> u, [1.0, 2.0], (0.0, 1.0), [3.0]) + Δ_u0 = [10.0, 20.0] + Δ = ChainRulesCore.Tangent{Any}(; + f = ChainRulesCore.NoTangent(), + u0 = Δ_u0, + tspan = ChainRulesCore.NoTangent(), + p = ChainRulesCore.NoTangent(), + kwargs = ChainRulesCore.NoTangent(), + problem_type = ChainRulesCore.NoTangent(), + ) + + # u0 supplied → cotangent flows to the u0 positional. + _, pb = ChainRulesCore.rrule( + SciMLBase._remake_ode_inner, + base_prob, missing, [9.9, 8.8], missing, missing, missing, + true, Val{true}, false, nothing, NamedTuple() + ) + cot = pb(Δ) + @test length(cot) == 12 + @test cot[1] === ChainRulesCore.NoTangent() + @test cot[4] == Δ_u0 + @test cot[2].u0 === nothing + @test all(cot[i] === ChainRulesCore.NoTangent() for i in 7:12) + + # u0 not supplied → cotangent accumulates onto prob.u0. + _, pb2 = ChainRulesCore.rrule( + SciMLBase._remake_ode_inner, + base_prob, missing, missing, missing, [99.0], missing, + true, Val{true}, false, nothing, NamedTuple() + ) + cot2 = pb2(Δ) + @test cot2[4] === ChainRulesCore.NoTangent() + @test cot2[2].u0 == Δ_u0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 72a5867d1..98bf0ff5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,9 @@ end @time @safetestset "Autodiff Remake" begin include("downstream/remake_autodiff.jl") end + @time @safetestset "Ensemble + remake reverse-mode adjoints" begin + include("downstream/ensemble_remake_reverse_mode_adjoints.jl") + end @time @safetestset "Partial Functions" begin include("downstream/partial_functions.jl") end From d86532b43dfc72c6b620bcea18c7ac7a20990d84 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Tue, 12 May 2026 08:09:20 -0400 Subject: [PATCH 2/2] trim EnsembleSolution constructor adjoints to load-bearing branches Drop the AbstractArray{T,N} reshape, AbstractArray{<:AbstractArray,1}, AbstractVectorOfArray, and EnsembleSolution-self dispatches. The realistic AD chain feeds a partial NamedTuple cotangent (u = ...) from the upstream Array(::AbstractVectorOfArray) adjoint; only the NamedTuple (Zygote ext) and NamedTuple/Tangent (ChainRules ext) branches are reachable from the realistic chain. The dropped branches were defensive coverage for direct callers no production path uses. Test rewritten to feed a NamedTuple cotangent directly, matching what the upstream adjoint produces. Co-Authored-By: Chris Rackauckas --- ext/SciMLBaseChainRulesCoreExt.jl | 18 ----------- ext/SciMLBaseZygoteExt.jl | 18 ----------- .../ensemble_remake_reverse_mode_adjoints.jl | 32 +++++++------------ 3 files changed, 12 insertions(+), 56 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 920924994..ff9629f2e 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -217,24 +217,6 @@ function ChainRulesCore.rrule( ::Type{EnsembleSolution}, sim, time, converged, stats = nothing ) out = EnsembleSolution(sim, time, converged, stats) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [ - [ - p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1] - ] for i in 1:size(p̄)[end] - ] - return (NoTangent(), arrarr, NoTangent(), NoTangent(), NoTangent()) - end - function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent()) - end - function EnsembleSolution_adjoint(p̄::AbstractVectorOfArray) - return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) - end function EnsembleSolution_adjoint(p̄::NamedTuple) return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent()) end diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 1bd1021fa..5f3939fb1 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -28,24 +28,6 @@ end @adjoint function EnsembleSolution(sim, time, converged, stats) out = EnsembleSolution(sim, time, converged, stats) - function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} - arrarr = [ - [ - p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] - for j in 1:size(p̄)[end - 1] - ] for i in 1:size(p̄)[end] - ] - (arrarr, nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (p̄, nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::RecursiveArrayTools.AbstractVectorOfArray) - (p̄.u, nothing, nothing, nothing) - end - function EnsembleSolution_adjoint(p̄::EnsembleSolution) - (p̄.u, nothing, nothing, nothing) - end function EnsembleSolution_adjoint(p̄::NamedTuple) (p̄.u, nothing, nothing, nothing) end diff --git a/test/downstream/ensemble_remake_reverse_mode_adjoints.jl b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl index 82f037766..193e9536d 100644 --- a/test/downstream/ensemble_remake_reverse_mode_adjoints.jl +++ b/test/downstream/ensemble_remake_reverse_mode_adjoints.jl @@ -3,33 +3,25 @@ using SciMLBase: EnsembleProblem, EnsembleSerial, EnsembleSolution using Zygote, ForwardDiff import ChainRulesCore -@testset "EnsembleSolution constructor cotangent shape" begin +@testset "EnsembleSolution constructor pulls NamedTuple cotangent" begin f(u, p, t) = -u prob = ODEProblem(f, [1.0], (0.0, 1.0)) sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3] - es = EnsembleSolution(sols, 0.0, true, nothing) - arr = Array(es) - @test ndims(arr) == 3 + arrarr = [[copy(s.u[j]) for j in eachindex(s.u)] for s in sols] _, back = Zygote.pullback(EnsembleSolution, sols, 0.0, true, nothing) - sim_cot, = back(ones(eltype(arr), size(arr))) - - @test !(sim_cot isa EnsembleSolution) - @test length(sim_cot) == length(sols) - @test all(eltype(c) <: AbstractArray for c in sim_cot) -end - -@testset "EnsembleSolution rrule cotangent shape" begin - f(u, p, t) = -u - prob = ODEProblem(f, [1.0], (0.0, 1.0)) - sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3] - es = EnsembleSolution(sols, 0.0, true, nothing) - arr = Array(es) + sim_cot, t_cot, c_cot, s_cot = back((u = arrarr,)) + @test sim_cot == arrarr + @test t_cot === nothing && c_cot === nothing && s_cot === nothing _, pb = ChainRulesCore.rrule(EnsembleSolution, sols, 0.0, true, nothing) - sim_cot = pb(ones(eltype(arr), size(arr)))[2] - @test !(sim_cot isa EnsembleSolution) - @test length(sim_cot) == length(sols) + cot = pb((u = arrarr,)) + @test cot[2] == arrarr + @test cot[1] === ChainRulesCore.NoTangent() + @test all(cot[i] === ChainRulesCore.NoTangent() for i in 3:5) + + cot_t = pb(ChainRulesCore.Tangent{Any}(; u = arrarr)) + @test cot_t[2] == arrarr end @testset "remake(::ODEProblem; u0) gradient parity" begin