ext/Mooncake: handle ComponentArray cotangents at @from_rrule boundaries#352
Merged
ChrisRackauckas merged 1 commit intoSciML:mainfrom Apr 11, 2026
Conversation
The existing `increment_and_get_rdata!` method only matched a raw
`Array{P}` tangent against a flat-`Array`-backed ComponentVector fdata.
In practice the tangent coming out of a `ChainRulesCore.rrule` for a
ComponentArray primal is usually *another* ComponentArray (e.g. via
`ComponentArray(Δ, getaxes(x))`), so downstream packages that declare a
`@from_rrule` / `@from_chainrules` boundary with a ComponentArray
argument hit
ArgumentError: The fdata type ... ComponentVector{...} combination
is not supported with @from_chainrules or @from_rrule.
This is what blocked the Mooncake migration of the SciMLSensitivity.jl
tutorials in SciML/SciMLSensitivity.jl#1419 (the `feedback_control.md`
and `second_order_neural.md` notes). Widen the dispatch to cover:
- flat-`Array`-backed ComponentVector fdata with an incoming
`ComponentArray` cotangent (unwrap to the underlying storage),
- SubArray-backed ComponentVector fdata (produced by
`getproperty(::ComponentVector, ::Symbol)`) with either an `Array`
or a `ComponentArray` cotangent — handled for the common
full-parent-coverage case, with a clear `ArgumentError` for the
partial-view case that would otherwise silently misplace gradient
mass.
Tests: exercise both native Mooncake (`prepare_gradient_cache` +
`value_and_gradient!!` over nested `ComponentArray(; u0, p_all)`) and
the `@from_rrule` round-trip path that the new methods target. Adds
Mooncake to `test/autodiff/Project.toml` (pinned to `0.5.26` to match
the `friendly_tangent_cache` symbol the extension already references).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
ComponentArraysMooncakeExt.increment_and_get_rdata!to cover the cotangent shapes that realChainRulesCore.rrules produce forComponentArrayprimals, so downstream packages that declare an@from_rrule/@from_chainrulesboundary with aComponentArrayargument actually work.Mooncaketotest/autodiffwith a focused@testsetthat exercises both native Mooncake (prepare_gradient_cache/value_and_gradient!!over nestedComponentArray(; u0, p_all)) and the@from_rruleround-trip that the new dispatch targets.Motivation
SciML/SciMLSensitivity.jl#1419migrates the docs tutorials from Zygote to Mooncake, but several of them (flagged as!!! notein that PR —second_order_neural.md,brusselator.md,feedback_control.md, …) have to stay on Zygote because they hitThe root cause is that the existing extension method
only matches when the ChainRules tangent
tis a rawArray{P}. In practice the rrules defined insrc/compat/chainrulescore.jl(forgetproperty,getdata,Type{ComponentArray}(data, axes),Type{CA}(nt::NamedTuple)) return cotangents that are themselvesComponentArrays (either flat-Array-backed or, when a view is involved,SubArray-backed). Any SciML package that declares a Mooncake primitive with aComponentVectorargument therefore funnels throughincrement_and_get_rdata!with aComponentArraytangent and hits the fallback error above.What this PR adds
Three additional methods on
Mooncake.increment_and_get_rdata!:Array-backed CV fdata +ComponentArraycotangent. Unwrap viagetdata(t)and delegate to the underlying storage. This is the common case — a loss function takes aComponentVector{Float64, Vector{Float64}}, the rrule returnsComponentArray(Δ, getaxes(x)).SubArray-backed CV fdata +Arraycotangent. Produced whenevergetproperty(::ComponentVector, ::Symbol)or any other view-producing operation crosses an@from_rruleboundary. We aggregate into the parent-array slot of theSubArray's structural tangent for the full-parent-coverage case, which is what actually lands at these boundaries in practice.SubArray-backed CV fdata +ComponentArraycotangent. Same as (2), but firstgetdata(t).Cases (2) and (3) raise a clear
ArgumentErrorfor the partial-view case (where the view's linear indices can't be recovered from fdata alone), so we never silently misplace gradient mass. Opening an issue with a reproducer is straightforward if anyone hits that path.The existing raw-
Arraymethod andMooncake.friendly_tangent_cachedefinition are preserved verbatim.Tests
New
@testset \"Mooncake\"intest/autodiff/autodiff_tests.jl:prepare_gradient_cache/value_and_gradient!!on a flatComponentVectorand on a nestedComponentArray(; u0, p_all)layout (matches thefeedback_control.mdshape fromSciML/SciMLSensitivity.jl#1419).sum_abs2with a hand-writtenChainRulesCore.rrulewhose pullback returns aComponentArraycotangent, declared as a Mooncake primitive via@from_rrule. Two cases: a flatComponentVector, and a nestedComponentArray(; u0 = Vector, p_all = ComponentArray). Both paths fail onmainand pass after this patch.Mooncake.friendly_tangent_cache(::ComponentArray)still returns aFriendlyTangentCache{AsPrimal}.test/autodiff/Project.tomlgrowsMooncake = \"0.5.26\"andChainRulesCore = \"1\". The0.5.26pin matches thefriendly_tangent_cachesymbol the extension already references (it doesn't exist in 0.5.24 and earlier — that precompile failure is how I discovered the existing extension has an implicit floor that the main Project.toml'sMooncake = \"0.5\"doesn't encode; worth tightening in a follow-up but orthogonal to this PR).Local results on Julia 1.12:
GROUP=Autodiff→ 56/56 pass (49 prior + 7 new asserts), 3m24sGROUP=Core→ 459 pass / 9 pre-existing broken, 1m45sNot in scope
The
MooncakeRuleCompilationErrormentioned in thefeedback_control.md/brusselator.mdnotes ofSciML/SciMLSensitivity.jl#1419is a compile-time failure inside Mooncake's rule builder rather than anincrement_and_get_rdata!dispatch gap — I couldn't reproduce it from any standalone ComponentArrays snippet, so it appears to originate in SciMLBase/SciMLSensitivity's adjoint stack rather than in ComponentArrays itself. This PR is targeted at the runtimeincrement_and_get_rdata!gap the notes explicitly attribute to "ComponentArrays' Mooncake extension".Test plan
GROUP=Autodiff julia --project=test/autodiff test/runtests.jl— 56/56 passGROUP=Core julia --project=test test/runtests.jl— 459/459 (+9 pre-existing broken)🤖 Generated with Claude Code