Skip to content

Commit 69798c3

Browse files
Fix GPU CuArray ambiguity and type-stable sum for LTS
- Fix CuArray ambiguity: replace (T::Type{<:AnyGPUArray})(VA) = T(Array(VA)) with T(stack(VA.u)) and add N type parameter to disambiguate against CuArray(::AbstractArray{T,N}) from CUDA.jl - Add type-stable Base.sum(VA::AbstractVectorOfArray) that reduces over .u to avoid inference failure on Julia 1.10 with deeply nested type parameters (fixes @inferred sum(VA[VA[zeros(4,4)]]) on LTS) Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4a17e3c commit 69798c3

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ module RecursiveArrayTools
177177

178178
import GPUArraysCore
179179
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
180-
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))
180+
# Disambiguate with CuArray(::AbstractArray{T,N}) by providing the typed method
181+
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray{<:Any, N}) where {N} = T(stack(VA.u))
181182

182183
export VectorOfArray, VA, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
183184
AllObserved, vecarr_to_vectors, tuples

src/vector_of_array.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,16 @@ vecarr_to_vectors(VA::AbstractVectorOfArray) = [VA[i, :] for i in eachindex(VA.u
11181118
# linear algebra
11191119
ArrayInterface.issingular(va::AbstractVectorOfArray) = ArrayInterface.issingular(Matrix(va))
11201120

1121+
# Type-stable sum/mapreduce that avoids inference issues on Julia 1.10
1122+
# with deeply nested VectorOfArray type parameters
1123+
function Base.sum(VA::AbstractVectorOfArray{T}) where {T}
1124+
return sum(sum, VA.u)::T
1125+
end
1126+
1127+
function Base.sum(f::F, VA::AbstractVectorOfArray{T}) where {F, T}
1128+
return sum(u -> sum(f, u), VA.u)
1129+
end
1130+
11211131
# make it show just like its data
11221132
function Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray)
11231133
(println(io, summary(x), ':'); show(io, m, x.u))

0 commit comments

Comments
 (0)