Skip to content

[v3-backport] Add Mooncake extension for ArrayPartition cotangents#577

Merged
ChrisRackauckas merged 2 commits intoSciML:v3-backportfrom
ChrisRackauckas-Claude:mooncake-arraypartition-rdata-v3
Apr 12, 2026
Merged

[v3-backport] Add Mooncake extension for ArrayPartition cotangents#577
ChrisRackauckas merged 2 commits intoSciML:v3-backportfrom
ChrisRackauckas-Claude:mooncake-arraypartition-rdata-v3

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Summary

Backport of #575 to the v3 maintenance line.

Adds a new RecursiveArrayToolsMooncakeExt weak-dep extension that registers a Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}}, ::NoRData, ::ArrayPartition{P, T}) method so Mooncake's @from_chainrules/@from_rrule accumulator can handle an ArrayPartition cotangent returned by an upstream ChainRule (e.g. SciMLSensitivity's _concrete_solve_adjoint for a SecondOrderODEProblem). Without this, the call falls through to:

ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
rdata type Mooncake.NoRData, and tangent type
RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
combination is not supported with @from_chainrules or @from_rrule.

Changes

  • ext/RecursiveArrayToolsMooncakeExt.jl — new extension. Walks f.data.x and t.x in lockstep and forwards each leaf to the existing per-array increment_and_get_rdata!, mirroring the ComponentArrays dispatch (Add friendly_tangent_cache function to Mooncake ComponentArrays.jl#350 / ci: explicitly specify token for codecov #351). Returns Mooncake.NoRData().
  • Project.toml — Mooncake added to [weakdeps], [extensions], [compat] (0.5), and to [extras] / [targets.test] so the extension is exercised in CI.
  • test/mooncake.jl — direct unit test for the new dispatch: Float64 two-partition and Float32 three-partition cases, checking in-place accumulation per leaf and the NoRData() return. Wired into the Core testset in runtests.jl.

Test plan

Local Pkg.test on Julia 1.10.11 (v3-backport Project.toml, clean checkout of this branch):

Test Summary:  | Pass  Total  Time
Mooncake Tests |    8      8  1.2s
...
Testing RecursiveArrayTools tests passed

Aqua (including test_stale_deps) passes; full Core testset green.

Related

🤖 Generated with Claude Code

ChrisRackauckas and others added 2 commits April 11, 2026 21:24
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
`_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
such as the one produced by `SecondOrderODEProblem`) returns a
parameter / state cotangent as an `ArrayPartition`, Mooncake's
`@from_chainrules` / `@from_rrule` accumulator looks for an
`increment_and_get_rdata!` method matching

    (FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)

There isn't a default method registered for this combination, so the
call falls through to the generic error path:

    ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
    rdata type Mooncake.NoRData, and tangent type
    RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
    combination is not supported with @from_chainrules or @from_rrule.

Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt`
weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of
inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}`
and the inner tuple positions line up with `t.x`. Walk the tuple
element-by-element and forward each leaf to the existing
`increment_and_get_rdata!` for the leaf's array type, which does the
actual in-place accumulation. Returns `Mooncake.NoRData()` to match the
no-rdata convention used by the equivalent ComponentArrays dispatch
(SciML/ComponentArrays.jl#350 / SciML#351).

Tested end-to-end against the SciMLSensitivity neural-ODE
`SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422,
which adds the matching `df_iip`/`df_oop` cotangent unwrap on the
SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition`
training loop now runs under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Add Mooncake to [extras] and [targets.test] so the new
`RecursiveArrayToolsMooncakeExt` is actually loaded and exercised in
the test suite, and add test/mooncake.jl as a direct unit test for the
new `Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.

Backport of SciML#575 to v3-backport.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ChrisRackauckas ChrisRackauckas merged commit 308d4f8 into SciML:v3-backport Apr 12, 2026
22 of 26 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants