
Add friendly_tangent_cache function to Mooncake #350

Merged

ChrisRackauckas merged 1 commit into SciML:main from yebai:patch-1 on Apr 11, 2026
Conversation

@yebai (Contributor) commented on Apr 10, 2026

Register ComponentArray's friendly tangent type. See chalk-lab/Mooncake.jl#1137 (comment)
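
For orientation, here is a minimal sketch of the shape of such an override, assuming `friendly_tangent_cache` takes the primal and returns the buffer the friendly tangent is written into. The exact signature and the merged method are in this PR's diff and in chalk-lab/Mooncake.jl#1137; everything below is illustrative:

```julia
# Hypothetical sketch, not the merged code.
using ComponentArrays, Mooncake

function Mooncake.friendly_tangent_cache(x::ComponentArray)
    # The friendly tangent of a ComponentArray is another ComponentArray
    # with the same axes: build the cache for the wrapped data array and
    # rewrap it, so cotangents come back ComponentArray-shaped rather
    # than as a structural Mooncake.Tangent.
    return ComponentArray(Mooncake.friendly_tangent_cache(getdata(x)), getaxes(x))
end
```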

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes was updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API


@yebai (Contributor, Author) commented on Apr 10, 2026

cc @ChrisRackauckas

@ChrisRackauckas (Member) commented

Okay perfect. I was just looking into this because it was the last step to change all tutorials in SciMLSensitivity to Mooncake.

ChrisRackauckas merged commit 410d4b7 into SciML:main on Apr 11, 2026
17 of 18 checks passed
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request Apr 11, 2026
After SciML/ComponentArrays.jl#350 (released as ComponentArrays
v0.15.34) registers a `friendly_tangent_cache` override for
`ComponentArray`, the `OPT.AutoMooncake(; config = Mooncake.Config(;
friendly_tangents = true))` form now uses the friendly-tangent
unwrap path inside Mooncake itself, which solves the same
`copyto!(::ComponentVector, ::Mooncake.Tangent)` crash that
JuliaDiff/DifferentiationInterface.jl#989 fixed at the DI layer for
the `config = nothing` default.

I re-tested the migration with **stock DI 0.7.16** plus
**ComponentArrays from main (0.15.34)** and confirmed the migrated
tutorials still pass end-to-end (LV+CA BFGS, multiple_nn Lux+CA Adam,
local_minima Lux+CA Adam, parameter_estimation_ode PolyOpt,
getting_started + GaussAdjoint, bouncing_ball with the
`last(sol.u)[1]` workaround, divergence, exogenous_input, etc.).

The reverted tutorials (`EnsembleProblem`, `MethodOfSteps` DDE,
`SecondOrderODEProblem`, nested CV, `ReverseDiffAdjoint` inner,
`SimpleChains`+`StaticArrays`, FBDF stiff PDE) are still blocked
on independent upstream issues that SciML/ComponentArrays.jl#350 does
not address; I reverified each one with friendly_tangents + CA-main and
they still fail with the same errors recorded in the `!!! note` callouts.

This commit:

1. Switches every migrated `OPT.AutoMooncake(; config = nothing)` /
   `SMS.AutoMooncake(...)` / `DI.AutoMooncake(...)` / `ADTypes.AutoMooncake(...)`
   to `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`
   (and the equivalent for the other prefixes).
2. Updates the recommended pattern shown in every `!!! note` callout
   on the still-Zygote tutorials to match.
3. Bumps the `ComponentArrays` compat in `docs/Project.toml` from
   `0.15` to `0.15.34` so the docs build picks up the friendly-tangent
   support.

With this change the SMS docs PR no longer hard-depends on
JuliaDiff/DifferentiationInterface.jl#989. That DI patch is still an
independently useful improvement (it makes the default
`config = nothing` form work without the user having to know about the
flag, and also fixes the `MVector`/`SVector` cases), but it is no
longer a blocker for this migration.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
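
For concreteness, the switch described in item 1 of the commit message above, sketched with `sum` standing in for a real loss (illustrative, not tutorial code; the tutorials use the `OPT.`/`SMS.`/`DI.`/`ADTypes.` prefixes verbatim):

```julia
using ADTypes, ComponentArrays, Mooncake
import DifferentiationInterface as DI

x = ComponentArray(a = [1.0, 2.0], b = 3.0)

# Before: the default config relied on the DI-layer unwrap from
# JuliaDiff/DifferentiationInterface.jl#989 to handle Mooncake.Tangent.
old_backend = AutoMooncake(; config = nothing)

# After: friendly tangents unwrap inside Mooncake itself, so the
# cotangent reaches copyto! already ComponentArray-shaped.
new_backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))

grad = DI.gradient(sum, new_backend, x)  # ComponentArray-shaped gradient
```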
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/RecursiveArrayTools.jl that referenced this pull request Apr 11, 2026
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
`_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
such as the one produced by `SecondOrderODEProblem`) returns a
parameter / state cotangent as an `ArrayPartition`, Mooncake's
`@from_chainrules` / `@from_rrule` accumulator looks for an
`increment_and_get_rdata!` method matching

    (FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)

There isn't a default method registered for this combination, so the
call falls through to the generic error path:

    ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
    rdata type Mooncake.NoRData, and tangent type
    RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
    combination is not supported with @from_chainrules or @from_rrule.

Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt`
weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of
inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}`
and the inner tuple positions line up with `t.x`. Walk the tuple
element-by-element and forward each leaf to the existing
`increment_and_get_rdata!` for the leaf's array type, which does the
actual in-place accumulation, and return `Mooncake.NoRData()` to match
the no-rdata convention used by the equivalent ComponentArrays dispatch
(SciML/ComponentArrays.jl#350 / SciML#351).

Tested end-to-end against the SciMLSensitivity neural-ODE
`SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422,
which adds the matching `df_iip`/`df_oop` cotangent unwrap on the
SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition`
training loop now runs under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
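
A minimal sketch of the dispatch this commit message describes, assuming `FData`'s payload is reachable through a `data` field and that the leaf array types already have their own `increment_and_get_rdata!` methods; the merged extension is the authoritative version:

```julia
# Hypothetical sketch of the RecursiveArrayToolsMooncakeExt method.
module RecursiveArrayToolsMooncakeExt

using RecursiveArrayTools: ArrayPartition
using Mooncake: Mooncake, NoRData

function Mooncake.increment_and_get_rdata!(
        f::Mooncake.FData{<:NamedTuple{(:x,)}}, ::NoRData, t::ArrayPartition)
    # An ArrayPartition's only field is x::Tuple of inner arrays, so the
    # FData tuple positions line up with t.x: forward each leaf to the
    # existing method for its array type, which accumulates in place.
    foreach(f.data.x, t.x) do f_leaf, t_leaf
        Mooncake.increment_and_get_rdata!(f_leaf, NoRData(), t_leaf)
    end
    return NoRData()
end

end # module
```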