diff --git a/ext/ComponentArraysMooncakeExt.jl b/ext/ComponentArraysMooncakeExt.jl
index 60d51cce..a7a5de8f 100644
--- a/ext/ComponentArraysMooncakeExt.jl
+++ b/ext/ComponentArraysMooncakeExt.jl
@@ -1,16 +1,117 @@
 module ComponentArraysMooncakeExt
 
 using ComponentArrays, Mooncake
+using Base: IEEEFloat
 
-# ComponentVector handling in @from_rrule
+const _FloatLike = Union{IEEEFloat, Complex{<:IEEEFloat}}
+
+# Return the storage of a cotangent as a plain `Array` so that it can be handed to
+# Mooncake's `Array` accumulation method. An `Array` is passed through as-is; any other
+# `AbstractArray` (e.g. a `SubArray`) is copied with `collect`.
+_as_array(x::Array) = x
+_as_array(x::AbstractArray) = collect(x)
+
+# === Flat-Array-backed ComponentVector fdata ==========================================
+# `Mooncake.FData{@NamedTuple{data::A, axes::NoFData}}` is the fdata layout of a
+# `ComponentArray{T, N, A<:Array, Axes}` — the common "owns its storage" case.
+#
+# We need to handle three incoming ChainRules cotangent shapes that arise from
+# `@from_rrule` / `@from_chainrules` declarations:
+#   (a) a raw `Array{P}` matching the primal underlying storage,
+#   (b) a `ComponentArray` with the same underlying storage type,
+#   (c) a `ComponentArray` whose data field is a different `AbstractArray{P}`
+#       (e.g. a `SubArray` produced by projecting a parent cotangent).
+
+# (a) raw Array cotangent
 function Mooncake.increment_and_get_rdata!(
         f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
         r::Mooncake.NoRData,
         t::A,
-    ) where {P <: Union{Base.IEEEFloat, Complex{<:Base.IEEEFloat}}, A <: Array{P}}
+    ) where {P <: _FloatLike, A <: Array{P}}
     return Mooncake.increment_and_get_rdata!(f.data[:data], r, t)
 end
 
+# (b) / (c) ComponentArray cotangent against a flat-Array-backed primal
+function Mooncake.increment_and_get_rdata!(
+        f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
+        r::Mooncake.NoRData,
+        t::ComponentArray{P, N, <:AbstractArray{P}},
+    ) where {P <: _FloatLike, N, A <: Array{P}}
+    return Mooncake.increment_and_get_rdata!(f.data[:data], r, _as_array(getdata(t)))
+end
+
+# === SubArray-backed ComponentVector fdata ============================================
+# A `ComponentVector` produced by `getproperty(::ComponentVector, ::Symbol)` (and any
+# other view-producing path) wraps a `SubArray` rather than a `Vector`. Its Mooncake
+# fdata accordingly nests an inner `FData` describing the SubArray's fields.
+#
+# We can only aggregate a ChainRules cotangent into this layout when the view fully
+# covers its parent — otherwise the unmodelled indices leave us unable to place the
+# cotangent into the correct slice of the parent tangent. That "full cover" case is
+# however the common one: sub-CVs that land at an `@from_rrule` boundary are usually
+# freshly allocated and own all of their parent storage. Outside of that, we raise a
+# clear error instead of silently corrupting gradients.
+#
+# Full cover is detected by comparing lengths only, since the view's indices are not
+# part of the fdata (`indices::NoFData`).
+
+# fdata layout matched by both SubArray-backed methods below, parameterised on eltype.
+const _SubArrayCVFData{P} = Mooncake.FData{
+    @NamedTuple{
+        data::Mooncake.FData{
+            @NamedTuple{
+                parent::Array{P, 1},
+                indices::Mooncake.NoFData,
+                offset1::Mooncake.NoFData,
+                stride1::Mooncake.NoFData,
+            },
+        },
+        axes::Mooncake.NoFData,
+    },
+}
+
+# Accumulate the cotangent storage `t_data` into the parent tangent stored in the
+# SubArray-backed fdata `f_cv`, and return `Mooncake.NoRData()`.
+# Throws an `ArgumentError` when `t_data` and the parent tangent differ in length.
+function _increment_subarray_fdata!(f_cv, t_data::AbstractArray{P}) where {P <: _FloatLike}
+    # Not called `parent`, so that `Base.parent` is not shadowed inside this function.
+    parent_tangent = f_cv.data[:data].data[:parent]
+    n_cotangent, n_parent = length(t_data), length(parent_tangent)
+    if n_cotangent != n_parent
+        throw(
+            ArgumentError(
+                "ComponentArraysMooncakeExt: cannot aggregate a cotangent of length " *
+                "$(n_cotangent) into a SubArray-backed ComponentVector tangent whose " *
+                "parent has length $(n_parent). This happens when a cotangent " *
+                "flows into a view that does not fully cover its parent; there is no " *
+                "way to recover the view indices from Mooncake fdata alone. Please " *
+                "file an issue against ComponentArrays.jl with a reproducer so the " *
+                "offending rrule can be patched.",
+            ),
+        )
+    end
+    Mooncake.increment_and_get_rdata!(parent_tangent, Mooncake.NoRData(), _as_array(t_data))
+    return Mooncake.NoRData()
+end
+
+# Raw `Array` cotangent against a SubArray-backed primal
+function Mooncake.increment_and_get_rdata!(
+        f::_SubArrayCVFData{P},
+        r::Mooncake.NoRData,
+        t::Array{P},
+    ) where {P <: _FloatLike}
+    return _increment_subarray_fdata!(f, t)
+end
+
+# ComponentArray cotangent against a SubArray-backed primal
+function Mooncake.increment_and_get_rdata!(
+        f::_SubArrayCVFData{P},
+        r::Mooncake.NoRData,
+        t::ComponentArray{P, N, <:AbstractArray{P}},
+    ) where {P <: _FloatLike, N}
+    return _increment_subarray_fdata!(f, getdata(t))
+end
+
 function Mooncake.friendly_tangent_cache(x::ComponentArray)
     Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}(copy(x))
 end
diff --git a/test/autodiff/Project.toml b/test/autodiff/Project.toml
index cec53860..0d382e23 100644
--- a/test/autodiff/Project.toml
+++ b/test/autodiff/Project.toml
@@ -1,8 +1,10 @@
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -11,8 +13,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ArrayInterface = "7.22.0"
+ChainRulesCore = "1"
 FiniteDiff = "2.29.0"
 ForwardDiff = "1.3.1"
+Mooncake = "0.5.26"
 Optimisers = "0.4.7"
 ReverseDiff = "1.16.2"
 Tracker = "0.2.38"
diff --git a/test/autodiff/autodiff_tests.jl b/test/autodiff/autodiff_tests.jl
index 119b6e9c..d412abac 100644
--- a/test/autodiff/autodiff_tests.jl
+++ b/test/autodiff/autodiff_tests.jl
@@ -1,5 +1,5 @@
 using ComponentArrays
-import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote
+import ChainRulesCore, FiniteDiff, ForwardDiff, Mooncake, ReverseDiff, Tracker, Zygote
 using Optimisers, ArrayInterface
 using Test
 
@@ -134,3 +134,98 @@ end
     @test ArrayInterface.restructure(ps, ps_tracked) isa
         ComponentVector{<:Any, <:Tracker.TrackedArray}
 end
+
+@testset "Mooncake" begin
+    # Native Mooncake rules — gradient through `getproperty` on a flat ComponentVector
+    # and on a ComponentVector with nested axes.
+    flat = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
+    loss_flat(p) = sum(abs2, p.a) + 0.5 * p.b + p.c^2
+    let
+        cache = Mooncake.prepare_gradient_cache(loss_flat, flat)
+        _, g = Mooncake.value_and_gradient!!(cache, loss_flat, flat)
+        @test g[2].fields.data ≈ [2.0, 0.5, 6.0]
+    end
+
+    u0 = ComponentArray(x = 1.0, y = 2.0)
+    p_all = ComponentArray(a = 3.0, b = 4.0, c = 5.0)
+    nested = ComponentArray(; u0, p_all)
+    loss_nested(θ) = sum(abs2, θ.u0) + 0.5 * sum(abs2, θ.p_all)
+    let
+        cache = Mooncake.prepare_gradient_cache(loss_nested, nested)
+        _, g = Mooncake.value_and_gradient!!(cache, loss_nested, nested)
+        @test g[2].fields.data ≈ [2.0, 4.0, 3.0, 4.0, 5.0]
+    end
+
+    # @from_rrule round-trip — this is the path that fails without the extension,
+    # because ComponentArrays' `ChainRulesCore.rrule` for `getdata`/`getproperty`/
+    # `Type{ComponentArray}(...)` returns `ComponentArray` cotangents that Mooncake's
+    # `increment_and_get_rdata!` dispatch rejects by default. Downstream SciML
+    # packages hit this whenever they declare an `@from_rrule` with a ComponentArray
+    # argument, which is how the SciMLSensitivity tutorials exercised it.
+    #
+    # Two primals compute the same value but differ in the cotangent their pullback
+    # hands back, so that each cotangent shape gets its own `@from_rrule` round-trip
+    # against a flat-Array-backed ComponentVector:
+    #   * `sum_abs2`     — the pullback broadcasts over `x` itself,
+    #   * `sum_abs2_raw` — the pullback broadcasts over `getdata(x)`.
+    sum_abs2(x::AbstractArray) = sum(abs2, x)
+    function ChainRulesCore.rrule(::typeof(sum_abs2), x::AbstractVector)
+        y = sum_abs2(x)
+        sum_abs2_pb(Δy) = (ChainRulesCore.NoTangent(), 2 .* Δy .* x)
+        return y, sum_abs2_pb
+    end
+    Mooncake.@from_rrule(
+        Mooncake.DefaultCtx,
+        Tuple{typeof(sum_abs2), ComponentVector{Float64, Vector{Float64}}},
+    )
+
+    sum_abs2_raw(x::AbstractArray) = sum(abs2, x)
+    function ChainRulesCore.rrule(::typeof(sum_abs2_raw), x::ComponentVector)
+        y = sum_abs2_raw(x)
+        sum_abs2_raw_pb(Δy) = (ChainRulesCore.NoTangent(), 2 .* Δy .* getdata(x))
+        return y, sum_abs2_raw_pb
+    end
+    Mooncake.@from_rrule(
+        Mooncake.DefaultCtx,
+        Tuple{typeof(sum_abs2_raw), ComponentVector{Float64, Vector{Float64}}},
+    )
+
+    # Pin the cotangent type each pullback produces, so that the cases below keep
+    # exercising the cotangent shape they are named after.
+    let
+        v = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
+        @test ChainRulesCore.rrule(sum_abs2, v)[2](1.0)[2] isa ComponentVector
+        @test ChainRulesCore.rrule(sum_abs2_raw, v)[2](1.0)[2] isa Vector{Float64}
+    end
+
+    # ComponentArray cotangent against a flat-Array-backed CV fdata
+    let
+        v = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
+        cache = Mooncake.prepare_gradient_cache(sum_abs2, v)
+        val, g = Mooncake.value_and_gradient!!(cache, sum_abs2, v)
+        @test val ≈ 14.0
+        @test g[2].fields.data ≈ [2.0, 4.0, 6.0]
+    end
+
+    # Same, for a nested ComponentArray constructed with `ComponentArray(; u0, p_all)`
+    # — the "feedback_control.md" layout from SciMLSensitivity#1419.
+    let
+        nested2 = ComponentArray(; u0 = [1.0, 2.0], p_all = ComponentArray(a = 3.0, b = 4.0))
+        cache = Mooncake.prepare_gradient_cache(sum_abs2, nested2)
+        val, g = Mooncake.value_and_gradient!!(cache, sum_abs2, nested2)
+        @test val ≈ 30.0
+        @test g[2].fields.data ≈ [2.0, 4.0, 6.0, 8.0]
+    end
+
+    # Raw `Vector` cotangent against a flat-Array-backed CV fdata
+    let
+        v = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
+        cache = Mooncake.prepare_gradient_cache(sum_abs2_raw, v)
+        val, g = Mooncake.value_and_gradient!!(cache, sum_abs2_raw, v)
+        @test val ≈ 14.0
+        @test g[2].fields.data ≈ [2.0, 4.0, 6.0]
+    end
+
+    @test Mooncake.friendly_tangent_cache(flat) isa
+        Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}
+end