diff --git a/Project.toml b/Project.toml index 0913e442..82fcfe48 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -27,6 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ComponentArraysGPUArraysExt = "GPUArrays" ComponentArraysKernelAbstractionsExt = "KernelAbstractions" +ComponentArraysMooncakeExt = "Mooncake" ComponentArraysOptimisersExt = "Optimisers" ComponentArraysReactantExt = "Reactant" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" @@ -45,6 +47,7 @@ GPUArrays = "10.3.1, 11" KernelAbstractions = "0.9.29" LinearAlgebra = "1.10" Optimisers = "0.3, 0.4" +Mooncake = "0.5" Reactant = "0.2.15" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" diff --git a/ext/ComponentArraysMooncakeExt.jl b/ext/ComponentArraysMooncakeExt.jl new file mode 100644 index 00000000..8fccd8f9 --- /dev/null +++ b/ext/ComponentArraysMooncakeExt.jl @@ -0,0 +1,14 @@ +module ComponentArraysMooncakeExt + +using ComponentArrays, Mooncake + +# ComponentVector handling in @from_rrule +function Mooncake.increment_and_get_rdata!( + f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}}, + r::Mooncake.NoRData, + t::A, + ) where {P <: Union{Base.IEEEFloat, Complex{<:Base.IEEEFloat}}, A <: Array{P}} + return Mooncake.increment_and_get_rdata!(f.data[:data], r, t) +end + +end