Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions ext/ComponentArraysMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,113 @@
module ComponentArraysMooncakeExt

using ComponentArrays, Mooncake
using Base: IEEEFloat

# ComponentVector handling in @from_rrule
const _FloatLike = Union{IEEEFloat, Complex{<:IEEEFloat}}

# === Flat-Array-backed ComponentVector fdata ==========================================
# `Mooncake.FData{@NamedTuple{data::A, axes::NoFData}}` is the fdata layout of a
# `ComponentArray{T, N, A<:Array, Axes}` — the common "owns its storage" case.
#
# We need to handle three incoming ChainRules cotangent shapes that arise from
# `@from_rrule` / `@from_chainrules` declarations:
# (a) a raw `Array{P}` matching the primal underlying storage,
# (b) a `ComponentArray` with the same underlying storage type,
# (c) a `ComponentArray` whose data field is a different `AbstractArray{P}`
# (e.g. a `SubArray` produced by projecting a parent cotangent).

# (a) raw Array cotangent
function Mooncake.increment_and_get_rdata!(
f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
r::Mooncake.NoRData,
t::A,
) where {P <: Union{Base.IEEEFloat, Complex{<:Base.IEEEFloat}}, A <: Array{P}}
) where {P <: _FloatLike, A <: Array{P}}
return Mooncake.increment_and_get_rdata!(f.data[:data], r, t)
end

# (b) / (c) ComponentArray cotangent against a flat-Array-backed primal
function Mooncake.increment_and_get_rdata!(
f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
r::Mooncake.NoRData,
t::ComponentArray{P, N, <:AbstractArray{P}},
) where {P <: _FloatLike, N, A <: Array{P}}
data_t = getdata(t)
t_vec = data_t isa Array{P} ? data_t : collect(data_t)
return Mooncake.increment_and_get_rdata!(f.data[:data], r, t_vec)
end

# === SubArray-backed ComponentVector fdata ============================================
# A `ComponentVector` produced by `getproperty(::ComponentVector, ::Symbol)` (and any
# other view-producing path) wraps a `SubArray` rather than a `Vector`. Its Mooncake
# fdata accordingly nests an inner `FData` describing the SubArray's fields.
#
# We can only aggregate a ChainRules cotangent into this layout when the view fully
# covers its parent — otherwise the unmodelled indices leave us unable to place the
# cotangent into the correct slice of the parent tangent. That "full cover" case is
# however the common one: sub-CVs that land at an `@from_rrule` boundary are usually
# freshly allocated and own all of their parent storage. Outside of that, we raise a
# clear error instead of silently corrupting gradients.

function _increment_subarray_fdata!(f_cv, t_data::AbstractArray{P}) where {P <: _FloatLike}
parent = f_cv.data[:data].data[:parent]
if length(t_data) != length(parent)
throw(
ArgumentError(
"ComponentArraysMooncakeExt: cannot aggregate a cotangent of length " *
"$(length(t_data)) into a SubArray-backed ComponentVector tangent whose " *
"parent has length $(length(parent)). This happens when a cotangent " *
"flows into a view that does not fully cover its parent; there is no " *
"way to recover the view indices from Mooncake fdata alone. Please " *
"file an issue against ComponentArrays.jl with a reproducer so the " *
"offending rrule can be patched.",
),
)
end
t_vec = t_data isa Array{P} ? t_data : collect(t_data)
Mooncake.increment_and_get_rdata!(parent, Mooncake.NoRData(), t_vec)
return Mooncake.NoRData()
end

function Mooncake.increment_and_get_rdata!(
f::Mooncake.FData{
@NamedTuple{
data::Mooncake.FData{
@NamedTuple{
parent::Array{P, 1},
indices::Mooncake.NoFData,
offset1::Mooncake.NoFData,
stride1::Mooncake.NoFData,
},
},
axes::Mooncake.NoFData,
},
},
r::Mooncake.NoRData,
t::Array{P},
) where {P <: _FloatLike}
return _increment_subarray_fdata!(f, t)
end

function Mooncake.increment_and_get_rdata!(
f::Mooncake.FData{
@NamedTuple{
data::Mooncake.FData{
@NamedTuple{
parent::Array{P, 1},
indices::Mooncake.NoFData,
offset1::Mooncake.NoFData,
stride1::Mooncake.NoFData,
},
},
axes::Mooncake.NoFData,
},
},
r::Mooncake.NoRData,
t::ComponentArray{P, N, <:AbstractArray{P}},
) where {P <: _FloatLike, N}
return _increment_subarray_fdata!(f, getdata(t))
end

function Mooncake.friendly_tangent_cache(x::ComponentArray)
Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}(copy(x))
end
Expand Down
4 changes: 4 additions & 0 deletions test/autodiff/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -11,8 +13,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ArrayInterface = "7.22.0"
ChainRulesCore = "1"
FiniteDiff = "2.29.0"
ForwardDiff = "1.3.1"
Mooncake = "0.5.26"
Optimisers = "0.4.7"
ReverseDiff = "1.16.2"
Tracker = "0.2.38"
Expand Down
73 changes: 72 additions & 1 deletion test/autodiff/autodiff_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ComponentArrays
import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote
import ChainRulesCore, FiniteDiff, ForwardDiff, Mooncake, ReverseDiff, Tracker, Zygote
using Optimisers, ArrayInterface
using Test

Expand Down Expand Up @@ -134,3 +134,74 @@ end
@test ArrayInterface.restructure(ps, ps_tracked) isa
ComponentVector{<:Any, <:Tracker.TrackedArray}
end

@testset "Mooncake" begin
# Native Mooncake rules — gradient through `getproperty` on a flat ComponentVector
# and on a ComponentVector with nested axes.
flat = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
loss_flat(p) = sum(abs2, p.a) + 0.5 * p.b + p.c^2
let
cache = Mooncake.prepare_gradient_cache(loss_flat, flat)
_, g = Mooncake.value_and_gradient!!(cache, loss_flat, flat)
@test g[2].fields.data ≈ [2.0, 0.5, 6.0]
end

u0 = ComponentArray(x = 1.0, y = 2.0)
p_all = ComponentArray(a = 3.0, b = 4.0, c = 5.0)
nested = ComponentArray(; u0, p_all)
loss_nested(θ) = sum(abs2, θ.u0) + 0.5 * sum(abs2, θ.p_all)
let
cache = Mooncake.prepare_gradient_cache(loss_nested, nested)
_, g = Mooncake.value_and_gradient!!(cache, loss_nested, nested)
@test g[2].fields.data ≈ [2.0, 4.0, 3.0, 4.0, 5.0]
end

# @from_rrule round-trip — this is the path that fails without the extension,
# because ComponentArrays' `ChainRulesCore.rrule` for `getdata`/`getproperty`/
# `Type{ComponentArray}(...)` returns `ComponentArray` cotangents that Mooncake's
# `increment_and_get_rdata!` dispatch rejects by default. Downstream SciML
# packages hit this whenever they declare an `@from_rrule` with a ComponentArray
# argument, which is how the SciMLSensitivity tutorials exercised it.
sum_abs2(x::AbstractArray) = sum(abs2, x)
function ChainRulesCore.rrule(::typeof(sum_abs2), x::AbstractArray)
y = sum_abs2(x)
function sum_abs2_pb(Δy)
return (
ChainRulesCore.NoTangent(),
ComponentArray(2 .* Δy .* getdata(x), getaxes(x)),
)
end
return y, sum_abs2_pb
end
function ChainRulesCore.rrule(::typeof(sum_abs2), x::AbstractVector)
y = sum_abs2(x)
sum_abs2_pb(Δy) = (ChainRulesCore.NoTangent(), 2 .* Δy .* x)
return y, sum_abs2_pb
end
Mooncake.@from_rrule(
Mooncake.DefaultCtx,
Tuple{typeof(sum_abs2), ComponentVector{Float64, Vector{Float64}}},
)

# (a) ComponentArray cotangent against a flat-Array-backed CV fdata
let
v = ComponentArray(a = 1.0, b = 2.0, c = 3.0)
cache = Mooncake.prepare_gradient_cache(sum_abs2, v)
val, g = Mooncake.value_and_gradient!!(cache, sum_abs2, v)
@test val ≈ 14.0
@test g[2].fields.data ≈ [2.0, 4.0, 6.0]
end

# (b) Nested ComponentArray constructed with `ComponentArray(; u0, p_all)`
# — the "feedback_control.md" layout from SciMLSensitivity#1419.
let
nested2 = ComponentArray(; u0 = [1.0, 2.0], p_all = ComponentArray(a = 3.0, b = 4.0))
cache = Mooncake.prepare_gradient_cache(sum_abs2, nested2)
val, g = Mooncake.value_and_gradient!!(cache, sum_abs2, nested2)
@test val ≈ 30.0
@test g[2].fields.data ≈ [2.0, 4.0, 6.0, 8.0]
end

@test Mooncake.friendly_tangent_cache(flat) isa
Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}
end
Loading