Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,22 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
"""
struct AP end

# Optimized partition-level any/all for ArrayPartition lives in
# RecursiveArrayToolsArrayPartitionAnyAll to avoid ~780 invalidations.
# Without the extension, any/all uses the AbstractArray element-by-element
# fallback, which triggers scalar indexing errors on GPU arrays.
# Load the subpackage to fix:
# Optimized partition-level any/all for ArrayPartition.
#
# On Julia ≥ 1.13 (JuliaLang/julia#61184), Base removes the f::Function
# restriction from any/all, so defining any(f, ::ArrayPartition) causes
# only 1 invalidation (down from ~780). Safe to inline directly.
#
# On Julia < 1.13, the methods live in RecursiveArrayToolsArrayPartitionAnyAll
# to avoid the invalidations. Without the extension, any/all uses the
# AbstractArray element-by-element fallback, which triggers scalar indexing
# errors on GPU arrays. Load the subpackage to fix:
# using RecursiveArrayToolsArrayPartitionAnyAll
@static if VERSION >= v"1.13.0-DEV.0"
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))
Base.any(f::Function, A::ArrayPartition) = any((any(f, x) for x in A.x))
Base.any(A::ArrayPartition) = any(identity, A)
Base.all(f, A::ArrayPartition) = all((all(f, x) for x in A.x))
Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x))
Base.all(A::ArrayPartition) = all(identity, A)
end
16 changes: 10 additions & 6 deletions test/partitions_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,17 @@ recursivecopy!(dest_ap, src_ap)
@inferred mapreduce(string, *, x)
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"

# any/all — optimized partition-level iteration requires RecursiveArrayToolsArrayPartitionAnyAll
# to avoid ~780 invalidations. Without the extension, Base's AbstractArray fallback is used.
# On GPU arrays, the fallback triggers scalar indexing errors — load the subpackage to fix.
@test which(any, Tuple{Function, ArrayPartition}).module === Base
@test which(all, Tuple{Function, ArrayPartition}).module === Base
# any/all — on Julia ≥ 1.13, optimized methods are inlined (1 invalidation).
# On older Julia, they require RecursiveArrayToolsArrayPartitionAnyAll (~780 invalidations).
@static if VERSION >= v"1.13.0-DEV.0"
@test which(any, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
@test which(all, Tuple{Function, ArrayPartition}).module === RecursiveArrayTools
else
@test which(any, Tuple{Function, ArrayPartition}).module === Base
@test which(all, Tuple{Function, ArrayPartition}).module === Base
end

# Correctness tests (work via AbstractArray fallback on CPU)
# Correctness tests
@test !any(isnan, AP[[1, 2], [3.0, 4.0]])
@test !any(isnan, AP[[3.0, 4.0]])
@test any(isnan, AP[[NaN], [3.0, 4.0]])
Expand Down
Loading