Skip to content

Commit 6a779d4

Browse files
Revert 1.13 any/all inlining — keep subpackage split, add comment
The 1.13 inlining of optimized any/all should be a separate PR. Keep the clean split: RecursiveArrayToolsArrayPartitionAnyAll provides the optimized methods, Base's AbstractArray fallback is used without it. Tests verify dispatch goes to Base without extension, and the Subpackages test group verifies loading the extension overrides correctly. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a9edc84 commit 6a779d4

File tree

2 files changed

+12
-67
lines changed

2 files changed

+12
-67
lines changed

src/array_partition.jl

Lines changed: 6 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -741,56 +741,9 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
741741
"""
742742
struct AP end
743743

744-
# any/all on ArrayPartition: partition-level iteration avoids scalar GPU indexing
745-
# and is ~1.5-1.8x faster than element-by-element AbstractArray fallback.
746-
#
747-
# On Julia ≥ 1.13, Base no longer restricts any/all to f::Function, so defining
748-
# any(f, ::ArrayPartition) causes minimal invalidation. On older Julia, the
749-
# f::Function methods caused ~780 invalidations, so the optimized methods were
750-
# separated into RecursiveArrayToolsArrayPartitionAnyAll.
751-
#
752-
# We now define them inline with a GPU check: if sub-arrays are GPU arrays and
753-
# we're on old Julia without the extension, give a helpful error.
754-
@static if VERSION >= v"1.13.0-DEV.0"
755-
# Julia 1.13+: safe to define directly, minimal invalidation
756-
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))
757-
Base.any(f::Function, A::ArrayPartition) = any((any(f, x) for x in A.x))
758-
Base.any(A::ArrayPartition) = any(identity, A)
759-
Base.all(f, A::ArrayPartition) = all((all(f, x) for x in A.x))
760-
Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x))
761-
Base.all(A::ArrayPartition) = all(identity, A)
762-
else
763-
# Julia < 1.13: only define GPU-check methods that error for GPU arrays
764-
# and fall through to AbstractArray for CPU. The optimized partition-level
765-
# methods are in RecursiveArrayToolsArrayPartitionAnyAll to avoid invalidations.
766-
const _ANYALL_GPU_HINT = """
767-
`any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
768-
using RecursiveArrayToolsArrayPartitionAnyAll
769-
This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
770-
"""
771-
772-
function _check_gpu_anyall(A::ArrayPartition)
773-
for x in A.x
774-
if x isa GPUArraysCore.AnyGPUArray
775-
error(_ANYALL_GPU_HINT)
776-
end
777-
end
778-
end
779-
780-
function Base.any(f::Function, A::ArrayPartition)
781-
_check_gpu_anyall(A)
782-
return Base.invoke(any, Tuple{Function, AbstractArray}, f, A)
783-
end
784-
function Base.all(f::Function, A::ArrayPartition)
785-
_check_gpu_anyall(A)
786-
return Base.invoke(all, Tuple{Function, AbstractArray}, f, A)
787-
end
788-
function Base.any(f, A::ArrayPartition)
789-
_check_gpu_anyall(A)
790-
return Base.invoke(any, Tuple{Any, AbstractArray}, f, A)
791-
end
792-
function Base.all(f, A::ArrayPartition)
793-
_check_gpu_anyall(A)
794-
return Base.invoke(all, Tuple{Any, AbstractArray}, f, A)
795-
end
796-
end
744+
# Optimized partition-level any/all for ArrayPartition lives in
745+
# RecursiveArrayToolsArrayPartitionAnyAll to avoid ~780 invalidations.
746+
# Without the extension, any/all uses the AbstractArray element-by-element
747+
# fallback, which triggers scalar indexing errors on GPU arrays.
748+
# Load the subpackage to fix:
749+
# using RecursiveArrayToolsArrayPartitionAnyAll

test/partitions_test.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,13 @@ recursivecopy!(dest_ap, src_ap)
178178
@inferred mapreduce(string, *, x)
179179
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"
180180

181-
# any/all — on Julia ≥ 1.13, optimized partition-level methods are defined inline.
182-
# On older Julia, they require `using RecursiveArrayToolsArrayPartitionAnyAll`.
183-
@static if VERSION >= v"1.13.0-DEV.0"
184-
# Optimized methods are always active on 1.13+
185-
@test which(any, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
186-
@test which(all, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
187-
else
188-
# On older Julia, optimized methods require the extension
189-
@test_broken occursin("ArrayPartitionAnyAll",
190-
string(which(any, Tuple{Function, ArrayPartition}).module))
191-
@test_broken occursin("ArrayPartitionAnyAll",
192-
string(which(all, Tuple{Function, ArrayPartition}).module))
193-
end
181+
# any/all — optimized partition-level iteration requires RecursiveArrayToolsArrayPartitionAnyAll
182+
# to avoid ~780 invalidations. Without the extension, Base's AbstractArray fallback is used.
183+
# On GPU arrays, the fallback triggers scalar indexing errors — load the subpackage to fix.
184+
@test which(any, Tuple{Function, ArrayPartition}).module === Base
185+
@test which(all, Tuple{Function, ArrayPartition}).module === Base
194186

195-
# Correctness tests
187+
# Correctness tests (work via AbstractArray fallback on CPU)
196188
@test !any(isnan, AP[[1, 2], [3.0, 4.0]])
197189
@test !any(isnan, AP[[3.0, 4.0]])
198190
@test any(isnan, AP[[NaN], [3.0, 4.0]])

0 commit comments

Comments
 (0)