Skip to content

Commit 6622e19

Browse files
Add CUDA extension to disambiguate CuArray(::AbstractVectorOfArray)
The functor-style `(::Type{<:AbstractGPUArray})(::AbstractVectorOfArray)` cannot resolve the ambiguity with CUDA's `CuArray(::AbstractArray{T,N})` because neither method is uniformly more specific (arg1 vs arg2). Fix: add a CUDA.jl extension that defines the exact disambiguation Julia requests: `CuArray(::AbstractVectorOfArray{T,N}) where {T,N}`. This is loaded only when CUDA is available and resolves the ambiguity by being a direct method on the concrete `CuArray` type. Uses stack(VA.u) to stay on GPU (avoids GPU→CPU→GPU round-trip). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 36bfdb6 commit 6622e19

3 files changed

Lines changed: 20 additions & 6 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1515
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1616

1717
[weakdeps]
18+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1819
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1920
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2021
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -29,6 +30,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[extensions]
33+
RecursiveArrayToolsCUDAExt = "CUDA"
3234
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
3335
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
3436
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
@@ -46,6 +48,7 @@ RecursiveArrayToolsZygoteExt = "Zygote"
4648
Adapt = "4"
4749
Aqua = "0.8"
4850
ArrayInterface = "7.16"
51+
CUDA = "5"
4952
DocStringExtensions = "0.9.3"
5053
FastBroadcast = "0.3.5"
5154
ForwardDiff = "0.10.38, 1"

ext/RecursiveArrayToolsCUDAExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module RecursiveArrayToolsCUDAExt
2+
3+
using RecursiveArrayTools: AbstractVectorOfArray
4+
import CUDA: CuArray
5+
6+
# Disambiguate CuArray(::AbstractVectorOfArray) vs CuArray(::AbstractArray{T,N}) from CUDA.jl.
7+
# This is the exact signature Julia's ambiguity error requests.
8+
# Uses stack to stay on GPU (avoids GPU→CPU→GPU round-trip).
9+
function CuArray(VA::AbstractVectorOfArray{T, N}) where {T, N}
10+
return CuArray{T, N}(stack(VA.u))
11+
end
12+
13+
end

src/RecursiveArrayTools.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,11 @@ module RecursiveArrayTools
176176
# AbstractVectorOfArray uses AbstractArray's show
177177

178178
import GPUArraysCore
179+
# GPU conversion via stack (stays on device, avoids materializing dense Array).
180+
# Only define convert — do NOT define a constructor (::Type{<:AbstractGPUArray})(::AbstractVectorOfArray)
181+
# because it creates an unresolvable ambiguity with CUDA.jl's CuArray(::AbstractArray{T,N}).
182+
# Instead, the KernelAbstractions extension defines the concrete CuArray dispatch.
179183
Base.convert(::Type{T}, VA::AbstractVectorOfArray) where {T <: GPUArraysCore.AnyGPUArray} = T(stack(VA.u))
180-
# Constructor: CuArray(va) etc. Must disambiguate with CuArray(::AbstractArray{T,N})
181-
# from CUDA.jl. AbstractVectorOfArray{T,N} is more specific than AbstractArray{T,N}
182-
# on arg2, matching {T,N} ensures equal specificity to CUDA's method.
183-
function (::Type{GA})(VA::AbstractVectorOfArray{T, N}) where {T, N, GA <: GPUArraysCore.AbstractGPUArray}
184-
return GA(stack(VA.u))
185-
end
186184

187185
export VectorOfArray, VA, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
188186
AllObserved, vecarr_to_vectors, tuples

0 commit comments

Comments
 (0)