Skip to content

Commit 56ec0b3

Browse files
Add GPU error hint for any/all on ArrayPartition, test subpackage loading
Without RecursiveArrayToolsArrayPartitionAnyAll, any/all on ArrayPartition with GPU sub-arrays would trigger scalar indexing errors. The base package now defines any/all methods for ArrayPartition that: - Check for GPU sub-arrays and throw a helpful error directing users to load RecursiveArrayToolsArrayPartitionAnyAll - Fall through to AbstractArray element-by-element iteration for CPU arrays Also adds a Subpackages test group that verifies loading RecursiveArrayToolsArrayPartitionAnyAll correctly overrides the methods with optimized partition-level iteration. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f91d1cc commit 56ec0b3

3 files changed

Lines changed: 64 additions & 8 deletions

File tree

src/array_partition.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,38 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
740740
741741
"""
742742
struct AP end
743+
744+
# any/all: provide methods that work partition-by-partition to avoid scalar indexing on GPUs.
745+
# Without RecursiveArrayToolsArrayPartitionAnyAll, the AbstractArray fallback iterates
746+
# element-by-element, which triggers scalar GPU indexing errors. These methods check for
747+
# GPU sub-arrays and give a helpful error, or fall through to the partition-level implementation.
748+
const _ANYALL_HINT = """
749+
`any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
750+
using RecursiveArrayToolsArrayPartitionAnyAll
751+
This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
752+
"""
753+
754+
function _check_gpu_anyall(A::ArrayPartition)
755+
for x in A.x
756+
if x isa GPUArraysCore.AnyGPUArray
757+
error(_ANYALL_HINT)
758+
end
759+
end
760+
end
761+
762+
function Base.any(f::Function, A::ArrayPartition)
763+
_check_gpu_anyall(A)
764+
return Base.invoke(any, Tuple{Function, AbstractArray}, f, A)
765+
end
766+
function Base.all(f::Function, A::ArrayPartition)
767+
_check_gpu_anyall(A)
768+
return Base.invoke(all, Tuple{Function, AbstractArray}, f, A)
769+
end
770+
function Base.any(f, A::ArrayPartition)
771+
_check_gpu_anyall(A)
772+
return Base.invoke(any, Tuple{Any, AbstractArray}, f, A)
773+
end
774+
function Base.all(f, A::ArrayPartition)
775+
_check_gpu_anyall(A)
776+
return Base.invoke(all, Tuple{Any, AbstractArray}, f, A)
777+
end

test/partitions_test.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,16 @@ recursivecopy!(dest_ap, src_ap)
178178
@inferred mapreduce(string, *, x)
179179
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"
180180

181-
# any/all — optimized partition-level iteration requires RecursiveArrayToolsArrayPartitionAnyAll.
182-
# Without the extension, these use the AbstractArray element-by-element fallback.
183-
# `using RecursiveArrayToolsArrayPartitionAnyAll` enables ~1.5-1.8x faster partition-level
184-
# short-circuiting. These @test_broken verify the optimized methods are NOT active.
185-
@test_broken which(any, Tuple{Function, ArrayPartition}).module !== Base
186-
@test_broken which(all, Tuple{Function, ArrayPartition}).module !== Base
187-
188-
# Correctness tests still pass via AbstractArray fallback
181+
# any/all — the base package defines GPU-aware methods that error for GPU sub-arrays
182+
# and fall through to AbstractArray iteration for CPU arrays. The optimized partition-level
183+
# iteration (1.5-1.8x faster) requires `using RecursiveArrayToolsArrayPartitionAnyAll`.
184+
# These @test_broken verify the optimized extension is NOT loaded.
185+
@test_broken occursin("ArrayPartitionAnyAll",
186+
string(which(any, Tuple{Function, ArrayPartition}).module))
187+
@test_broken occursin("ArrayPartitionAnyAll",
188+
string(which(all, Tuple{Function, ArrayPartition}).module))
189+
190+
# Correctness tests pass via the base package methods (element-by-element on CPU)
189191
@test !any(isnan, AP[[1, 2], [3.0, 4.0]])
190192
@test !any(isnan, AP[[3.0, 4.0]])
191193
@test any(isnan, AP[[NaN], [3.0, 4.0]])

test/runtests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,25 @@ end
4545
@time @safetestset "SymbolicIndexingInterface API test" include("symbolic_indexing_interface_test.jl")
4646
end
4747

48+
if GROUP == "Subpackages" || GROUP == "All"
49+
# Test that loading RecursiveArrayToolsArrayPartitionAnyAll overrides any/all
50+
Pkg.develop(PackageSpec(
51+
path = joinpath(dirname(@__DIR__), "lib", "RecursiveArrayToolsArrayPartitionAnyAll")))
52+
@time @safetestset "ArrayPartition AnyAll Subpackage" begin
53+
using RecursiveArrayTools, RecursiveArrayToolsArrayPartitionAnyAll, Test
54+
# Verify optimized methods are active
55+
m_any = which(any, Tuple{Function, ArrayPartition})
56+
m_all = which(all, Tuple{Function, ArrayPartition})
57+
@test occursin("ArrayPartitionAnyAll", string(m_any.module))
58+
@test occursin("ArrayPartitionAnyAll", string(m_all.module))
59+
# Verify correctness
60+
@test any(isnan, ArrayPartition([NaN], [1.0]))
61+
@test !any(isnan, ArrayPartition([1.0], [2.0]))
62+
@test all(isnan, ArrayPartition([NaN], [NaN]))
63+
@test !all(isnan, ArrayPartition([NaN], [1.0]))
64+
end
65+
end
66+
4867
if GROUP == "Downstream"
4968
activate_downstream_env()
5069
@time @safetestset "ODE Solve Tests" include("downstream/odesolve.jl")

0 commit comments

Comments
 (0)