diff --git a/Project.toml b/Project.toml index 0488c683..77955591 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "4.0.1" authors = ["Chris Rackauckas "] +version = "4.0.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -20,6 +20,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -38,6 +39,7 @@ RecursiveArrayToolsForwardDiffExt = "ForwardDiff" RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" +RecursiveArrayToolsMooncakeExt = "Mooncake" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] RecursiveArrayToolsStatisticsExt = "Statistics" @@ -59,6 +61,7 @@ KernelAbstractions = "0.9.36" LinearAlgebra = "1.10" Measurements = "2.11" MonteCarloMeasurements = "1.2" +Mooncake = "0.5" NLsolve = "4.5" Pkg = "1" Polyester = "0.7.16" @@ -86,6 +89,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -102,4 +106,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] +test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "Mooncake", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] diff --git a/ext/RecursiveArrayToolsMooncakeExt.jl b/ext/RecursiveArrayToolsMooncakeExt.jl new file mode 100644 index 00000000..d74ce862 --- /dev/null +++ b/ext/RecursiveArrayToolsMooncakeExt.jl @@ -0,0 +1,44 @@ +module RecursiveArrayToolsMooncakeExt + +using RecursiveArrayTools +using Mooncake + +# `ArrayPartition` cotangent handling for `@from_chainrules` / +# `@from_rrule`-generated rules. +# +# When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's +# `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`, +# such as the one produced by `SecondOrderODEProblem`) returns a parameter +# / state cotangent as an `ArrayPartition`, Mooncake's +# `@from_chainrules`/`@from_rrule` accumulator looks for an +# `increment_and_get_rdata!` method matching +# `(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)` +# — and there isn't one by default, so the call falls through to the +# generic error path: +# +# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}}, +# rdata type Mooncake.NoRData, and tangent type +# RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}} +# combination is not supported with @from_chainrules or @from_rrule. +# +# Add the missing dispatch. An `ArrayPartition`'s only field is `x::Tuple` +# of inner arrays, so the FData layout is +# `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up +# with `t.x`. Walk the tuple element-by-element and forward each leaf to +# the existing `increment_and_get_rdata!` for the leaf's array type, which +# does the actual in-place accumulation. +function Mooncake.increment_and_get_rdata!( + f::Mooncake.FData{@NamedTuple{x::T}}, + r::Mooncake.NoRData, + t::ArrayPartition{P, T}, + ) where {P, T <: Tuple} + fxs = f.data[:x] + txs = t.x + @assert length(fxs) == length(txs) + for i in eachindex(fxs) + Mooncake.increment_and_get_rdata!(fxs[i], Mooncake.NoRData(), txs[i]) + end + return Mooncake.NoRData() +end + +end diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 00000000..6a0ff28e --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,39 @@ +using RecursiveArrayTools, Mooncake, Test + +# Regression test for the `RecursiveArrayToolsMooncakeExt` dispatch that +# lets Mooncake's `@from_chainrules`/`@from_rrule` accumulator handle an +# `ArrayPartition` cotangent returned by an upstream ChainRule (e.g. +# SciMLSensitivity's `_concrete_solve_adjoint` for a `SecondOrderODEProblem`). +# Without the extension, the call below fell through to Mooncake's generic +# error path: +# +# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float64}, Vector{Float64}}}}, +# rdata type Mooncake.NoRData, and tangent type +# RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}} +# combination is not supported with @from_chainrules or @from_rrule. + +@testset "ArrayPartition increment_and_get_rdata!" begin + @test Base.get_extension(RecursiveArrayTools, :RecursiveArrayToolsMooncakeExt) !== + nothing + + # Tangent produced by an upstream ChainRule. + t = ArrayPartition([1.0, 2.0], [3.0, 4.0]) + # Pre-existing FData that the method should accumulate into in place. + f = Mooncake.FData((x = ([10.0, 20.0], [30.0, 40.0]),)) + + r = Mooncake.increment_and_get_rdata!(f, Mooncake.NoRData(), t) + + @test r === Mooncake.NoRData() + @test f.data.x[1] == [11.0, 22.0] + @test f.data.x[2] == [33.0, 44.0] + + # Three-way partition with Float32 leaves — exercises the inner + # per-leaf dispatch on a different eltype and arity. + t32 = ArrayPartition(Float32[1, 2], Float32[3, 4, 5], Float32[6]) + f32 = Mooncake.FData((x = (Float32[10, 20], Float32[30, 40, 50], Float32[60]),)) + r32 = Mooncake.increment_and_get_rdata!(f32, Mooncake.NoRData(), t32) + @test r32 === Mooncake.NoRData() + @test f32.data.x[1] == Float32[11, 22] + @test f32.data.x[2] == Float32[33, 44, 55] + @test f32.data.x[3] == Float32[66] +end diff --git a/test/runtests.jl b/test/runtests.jl index 7bdb76c8..b87c9aa6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,6 +46,7 @@ end @time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl") @time @safetestset "Linear Algebra Tests" include("linalg.jl") @time @safetestset "Adjoint Tests" include("adjoints.jl") + @time @safetestset "Mooncake Tests" include("mooncake.jl") @time @safetestset "Measurement Tests" include("measurements.jl") end