Skip to content

Commit a9edc84

Browse files
Inline optimized any/all for ArrayPartition on Julia 1.13+
Julia 1.13 (JuliaLang/julia#61184) removes the f::Function restriction from any/all, so defining any(f, ::ArrayPartition) no longer causes ~780 invalidations. On 1.13+, the optimized partition-level methods are defined directly in the main package. On older Julia, the GPU-check fallback remains and the RecursiveArrayToolsArrayPartitionAnyAll subpackage is needed for optimized methods. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 56ec0b3 commit a9edc84

2 files changed

Lines changed: 64 additions & 40 deletions

File tree

src/array_partition.jl

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -741,37 +741,56 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
741741
"""
742742
struct AP end
743743

744-
# any/all: provide methods that work partition-by-partition to avoid scalar indexing on GPUs.
745-
# Without RecursiveArrayToolsArrayPartitionAnyAll, the AbstractArray fallback iterates
746-
# element-by-element, which triggers scalar GPU indexing errors. These methods check for
747-
# GPU sub-arrays and give a helpful error, or fall through to the partition-level implementation.
748-
const _ANYALL_HINT = """
749-
`any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
750-
using RecursiveArrayToolsArrayPartitionAnyAll
751-
This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
752-
"""
753-
754-
function _check_gpu_anyall(A::ArrayPartition)
755-
for x in A.x
756-
if x isa GPUArraysCore.AnyGPUArray
757-
error(_ANYALL_HINT)
744+
# any/all on ArrayPartition: partition-level iteration avoids scalar GPU indexing
745+
# and is ~1.5-1.8x faster than element-by-element AbstractArray fallback.
746+
#
747+
# On Julia ≥ 1.13, Base no longer restricts any/all to f::Function, so defining
748+
# any(f, ::ArrayPartition) causes minimal invalidation. On older Julia, the
749+
# f::Function methods caused ~780 invalidations, so the optimized methods were
750+
# separated into RecursiveArrayToolsArrayPartitionAnyAll.
751+
#
752+
# We now define them inline with a GPU check: if sub-arrays are GPU arrays and
753+
# we're on old Julia without the extension, give a helpful error.
754+
@static if VERSION >= v"1.13.0-DEV.0"
755+
# Julia 1.13+: safe to define directly, minimal invalidation
756+
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))
757+
Base.any(f::Function, A::ArrayPartition) = any((any(f, x) for x in A.x))
758+
Base.any(A::ArrayPartition) = any(identity, A)
759+
Base.all(f, A::ArrayPartition) = all((all(f, x) for x in A.x))
760+
Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x))
761+
Base.all(A::ArrayPartition) = all(identity, A)
762+
else
763+
# Julia < 1.13: only define GPU-check methods that error for GPU arrays
764+
# and fall through to AbstractArray for CPU. The optimized partition-level
765+
# methods are in RecursiveArrayToolsArrayPartitionAnyAll to avoid invalidations.
766+
const _ANYALL_GPU_HINT = """
767+
`any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
768+
using RecursiveArrayToolsArrayPartitionAnyAll
769+
This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
770+
"""
771+
772+
function _check_gpu_anyall(A::ArrayPartition)
773+
for x in A.x
774+
if x isa GPUArraysCore.AnyGPUArray
775+
error(_ANYALL_GPU_HINT)
776+
end
758777
end
759778
end
760-
end
761779

762-
function Base.any(f::Function, A::ArrayPartition)
763-
_check_gpu_anyall(A)
764-
return Base.invoke(any, Tuple{Function, AbstractArray}, f, A)
765-
end
766-
function Base.all(f::Function, A::ArrayPartition)
767-
_check_gpu_anyall(A)
768-
return Base.invoke(all, Tuple{Function, AbstractArray}, f, A)
769-
end
770-
function Base.any(f, A::ArrayPartition)
771-
_check_gpu_anyall(A)
772-
return Base.invoke(any, Tuple{Any, AbstractArray}, f, A)
773-
end
774-
function Base.all(f, A::ArrayPartition)
775-
_check_gpu_anyall(A)
776-
return Base.invoke(all, Tuple{Any, AbstractArray}, f, A)
780+
function Base.any(f::Function, A::ArrayPartition)
781+
_check_gpu_anyall(A)
782+
return Base.invoke(any, Tuple{Function, AbstractArray}, f, A)
783+
end
784+
function Base.all(f::Function, A::ArrayPartition)
785+
_check_gpu_anyall(A)
786+
return Base.invoke(all, Tuple{Function, AbstractArray}, f, A)
787+
end
788+
function Base.any(f, A::ArrayPartition)
789+
_check_gpu_anyall(A)
790+
return Base.invoke(any, Tuple{Any, AbstractArray}, f, A)
791+
end
792+
function Base.all(f, A::ArrayPartition)
793+
_check_gpu_anyall(A)
794+
return Base.invoke(all, Tuple{Any, AbstractArray}, f, A)
795+
end
777796
end

test/partitions_test.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,21 @@ recursivecopy!(dest_ap, src_ap)
178178
@inferred mapreduce(string, *, x)
179179
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"
180180

181-
# any/all — the base package defines GPU-aware methods that error for GPU sub-arrays
182-
# and fall through to AbstractArray iteration for CPU arrays. The optimized partition-level
183-
# iteration (1.5-1.8x faster) requires `using RecursiveArrayToolsArrayPartitionAnyAll`.
184-
# These @test_broken verify the optimized extension is NOT loaded.
185-
@test_broken occursin("ArrayPartitionAnyAll",
186-
string(which(any, Tuple{Function, ArrayPartition}).module))
187-
@test_broken occursin("ArrayPartitionAnyAll",
188-
string(which(all, Tuple{Function, ArrayPartition}).module))
189-
190-
# Correctness tests pass via the base package methods (element-by-element on CPU)
181+
# any/all — on Julia ≥ 1.13, optimized partition-level methods are defined inline.
182+
# On older Julia, they require `using RecursiveArrayToolsArrayPartitionAnyAll`.
183+
@static if VERSION >= v"1.13.0-DEV.0"
184+
# Optimized methods are always active on 1.13+
185+
@test which(any, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
186+
@test which(all, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
187+
else
188+
# On older Julia, optimized methods require the extension
189+
@test_broken occursin("ArrayPartitionAnyAll",
190+
string(which(any, Tuple{Function, ArrayPartition}).module))
191+
@test_broken occursin("ArrayPartitionAnyAll",
192+
string(which(all, Tuple{Function, ArrayPartition}).module))
193+
end
194+
195+
# Correctness tests
191196
@test !any(isnan, AP[[1, 2], [3.0, 4.0]])
192197
@test !any(isnan, AP[[3.0, 4.0]])
193198
@test any(isnan, AP[[NaN], [3.0, 4.0]])

0 commit comments

Comments
 (0)