Commit 59a96eb

Add Mooncake extension for ArrayPartition cotangents

When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`, such as the one produced by `SecondOrderODEProblem`) returns a parameter / state cotangent as an `ArrayPartition`, Mooncake's `@from_chainrules` / `@from_rrule` accumulator looks for an `increment_and_get_rdata!` method matching

    (FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)

There isn't a default method registered for this combination, so the call falls through to the generic error path:

    ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
    rdata type Mooncake.NoRData, and tangent type
    RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
    combination is not supported with @from_chainrules or @from_rrule.

Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt` weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up with `t.x`. Walk the tuple element-by-element and forward each leaf to the existing `increment_and_get_rdata!` for the leaf's array type, which does the actual in-place accumulation. Returns `Mooncake.NoRData()` to match the no-rdata convention used by the equivalent ComponentArrays dispatch (SciML/ComponentArrays.jl#350 / #351).

Tested end-to-end against the SciMLSensitivity neural-ODE `SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422, which adds the matching `df_iip`/`df_oop` cotangent unwrap on the SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition` training loop now runs under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>

1 parent 3f59673 commit 59a96eb

2 files changed: +47 -0 lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -37,6 +38,7 @@ RecursiveArrayToolsFastBroadcastPolyesterExt = ["FastBroadcast", "Polyester"]
 RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
 RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
 RecursiveArrayToolsMeasurementsExt = "Measurements"
+RecursiveArrayToolsMooncakeExt = "Mooncake"
 RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
 RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
 RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
@@ -58,6 +60,7 @@ GPUArraysCore = "0.2"
 KernelAbstractions = "0.9.36"
 LinearAlgebra = "1.10"
 Measurements = "2.11"
+Mooncake = "0.5"
 MonteCarloMeasurements = "1.2"
 NLsolve = "4.5"
 Pkg = "1"

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+module RecursiveArrayToolsMooncakeExt
+
+using RecursiveArrayTools
+using Mooncake
+
+# `ArrayPartition` cotangent handling for `@from_chainrules` /
+# `@from_rrule`-generated rules.
+#
+# When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
+# `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
+# such as the one produced by `SecondOrderODEProblem`) returns a parameter
+# / state cotangent as an `ArrayPartition`, Mooncake's
+# `@from_chainrules`/`@from_rrule` accumulator looks for an
+# `increment_and_get_rdata!` method matching
+# `(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)`,
+# and there isn't one by default, so the call falls through to the
+# generic error path:
+#
+#     ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
+#     rdata type Mooncake.NoRData, and tangent type
+#     RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
+#     combination is not supported with @from_chainrules or @from_rrule.
+#
+# Add the missing dispatch. An `ArrayPartition`'s only field is `x::Tuple`
+# of inner arrays, so the FData layout is
+# `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up
+# with `t.x`. Walk the tuple element-by-element and forward each leaf to
+# the existing `increment_and_get_rdata!` for the leaf's array type, which
+# does the actual in-place accumulation.
+function Mooncake.increment_and_get_rdata!(
+        f::Mooncake.FData{@NamedTuple{x::T}},
+        r::Mooncake.NoRData,
+        t::ArrayPartition{P, T},
+    ) where {P, T <: Tuple}
+    fxs = f.data[:x]
+    txs = t.x
+    @assert length(fxs) == length(txs)
+    for i in eachindex(fxs)
+        Mooncake.increment_and_get_rdata!(fxs[i], Mooncake.NoRData(), txs[i])
+    end
+    return Mooncake.NoRData()
+end
+
+end
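
For readers unfamiliar with the pattern, the element-by-element walk described in the commit message can be sketched without Mooncake or RecursiveArrayTools loaded. Everything in this sketch is a hypothetical stand-in: `ToyPartition` mimics only the single `x::Tuple` field of `ArrayPartition`, and `toy_increment!` plays the role of `increment_and_get_rdata!` with the `FData` / `NoRData` wrappers left out. It is an illustration of the recursion under those assumptions, not the extension's code.

```julia
# Hypothetical stand-in for ArrayPartition: a single field `x` holding a tuple of arrays.
struct ToyPartition{T <: Tuple}
    x::T
end

# Leaf rule (hypothetical name): accumulate a cotangent array into a buffer in place.
toy_increment!(buf::AbstractArray, t::AbstractArray) = (buf .+= t; nothing)

# Container rule: walk both tuples in lockstep and defer to the leaf rule.
function toy_increment!(bufs::Tuple, t::ToyPartition)
    length(bufs) == length(t.x) ||
        throw(DimensionMismatch("buffer and partition tuple lengths differ"))
    foreach(toy_increment!, bufs, t.x)
    return nothing
end

# Two gradient buffers, matching a two-block partition such as (velocity, position).
bufs = ([1.0, 2.0], [10.0])
toy_increment!(bufs, ToyPartition(([0.5, 0.5], [1.0])))
```

The container method does no arithmetic of its own. It only checks that the two tuples have the same length and forwards each pair to whichever leaf method matches, which is why the real extension can rely on the existing array methods for the in-place accumulation.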
