Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "4.0.1"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "4.0.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -20,6 +20,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -38,6 +39,7 @@ RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsMooncakeExt = "Mooncake"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
RecursiveArrayToolsStatisticsExt = "Statistics"
Expand All @@ -59,6 +61,7 @@ KernelAbstractions = "0.9.36"
LinearAlgebra = "1.10"
Measurements = "2.11"
MonteCarloMeasurements = "1.2"
Mooncake = "0.5"
NLsolve = "4.5"
Pkg = "1"
Polyester = "0.7.16"
Expand Down Expand Up @@ -86,6 +89,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -102,4 +106,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "Mooncake", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
44 changes: 44 additions & 0 deletions ext/RecursiveArrayToolsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module RecursiveArrayToolsMooncakeExt

using RecursiveArrayTools
using Mooncake

# `ArrayPartition` cotangent handling for `@from_chainrules` /
# `@from_rrule`-generated rules.
#
# When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
# `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
# such as the one produced by `SecondOrderODEProblem`) returns a parameter
# / state cotangent as an `ArrayPartition`, Mooncake's
# `@from_chainrules`/`@from_rrule` accumulator looks for an
# `increment_and_get_rdata!` method matching
# `(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)`
# — and there isn't one by default, so the call falls through to the
# generic error path:
#
# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
# rdata type Mooncake.NoRData, and tangent type
# RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
# combination is not supported with @from_chainrules or @from_rrule.
#
# Add the missing dispatch. An `ArrayPartition`'s only field is `x::Tuple`
# of inner arrays, so the FData layout is
# `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up
# with `t.x`. Walk the tuple element-by-element and forward each leaf to
# the existing `increment_and_get_rdata!` for the leaf's array type, which
# does the actual in-place accumulation.
function Mooncake.increment_and_get_rdata!(
f::Mooncake.FData{@NamedTuple{x::T}},
r::Mooncake.NoRData,
t::ArrayPartition{P, T},
) where {P, T <: Tuple}
fxs = f.data[:x]
txs = t.x
@assert length(fxs) == length(txs)
for i in eachindex(fxs)
Mooncake.increment_and_get_rdata!(fxs[i], Mooncake.NoRData(), txs[i])
end
return Mooncake.NoRData()
end

end
39 changes: 39 additions & 0 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using RecursiveArrayTools, Mooncake, Test

# Regression test for the `RecursiveArrayToolsMooncakeExt` dispatch that
# lets Mooncake's `@from_chainrules`/`@from_rrule` accumulator handle an
# `ArrayPartition` cotangent returned by an upstream ChainRule (e.g.
# SciMLSensitivity's `_concrete_solve_adjoint` for a `SecondOrderODEProblem`).
# Without the extension, the call below fell through to Mooncake's generic
# error path:
#
# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float64}, Vector{Float64}}}},
# rdata type Mooncake.NoRData, and tangent type
# RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}
# combination is not supported with @from_chainrules or @from_rrule.

@testset "ArrayPartition increment_and_get_rdata!" begin
@test Base.get_extension(RecursiveArrayTools, :RecursiveArrayToolsMooncakeExt) !==
nothing

# Tangent produced by an upstream ChainRule.
t = ArrayPartition([1.0, 2.0], [3.0, 4.0])
# Pre-existing FData that the method should accumulate into in place.
f = Mooncake.FData((x = ([10.0, 20.0], [30.0, 40.0]),))

r = Mooncake.increment_and_get_rdata!(f, Mooncake.NoRData(), t)

@test r === Mooncake.NoRData()
@test f.data.x[1] == [11.0, 22.0]
@test f.data.x[2] == [33.0, 44.0]

# Three-way partition with Float32 leaves — exercises the inner
# per-leaf dispatch on a different eltype and arity.
t32 = ArrayPartition(Float32[1, 2], Float32[3, 4, 5], Float32[6])
f32 = Mooncake.FData((x = (Float32[10, 20], Float32[30, 40, 50], Float32[60]),))
r32 = Mooncake.increment_and_get_rdata!(f32, Mooncake.NoRData(), t32)
@test r32 === Mooncake.NoRData()
@test f32.data.x[1] == Float32[11, 22]
@test f32.data.x[2] == Float32[33, 44, 55]
@test f32.data.x[3] == Float32[66]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end
@time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl")
@time @safetestset "Linear Algebra Tests" include("linalg.jl")
@time @safetestset "Adjoint Tests" include("adjoints.jl")
@time @safetestset "Mooncake Tests" include("mooncake.jl")
@time @safetestset "Measurement Tests" include("measurements.jl")
end

Expand Down
Loading