@@ -741,37 +741,56 @@ ODEProblem(func, AP[ [1.,2.,3.], [1. 2.;3. 4.] ], (0, 1)) |> solve
741741"""
742742struct AP end
743743
744- # any/all: provide methods that work partition-by-partition to avoid scalar indexing on GPUs.
745- # Without RecursiveArrayToolsArrayPartitionAnyAll, the AbstractArray fallback iterates
746- # element-by-element, which triggers scalar GPU indexing errors. These methods check for
747- # GPU sub-arrays and give a helpful error, or fall through to the partition-level implementation.
748- const _ANYALL_HINT = """
749- `any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
750- using RecursiveArrayToolsArrayPartitionAnyAll
751- This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
752- """
753-
754- function _check_gpu_anyall (A:: ArrayPartition )
755- for x in A. x
756- if x isa GPUArraysCore. AnyGPUArray
757- error (_ANYALL_HINT)
744+ # any/all on ArrayPartition: partition-level iteration avoids scalar GPU indexing
745+ # and is ~1.5-1.8x faster than element-by-element AbstractArray fallback.
746+ #
747+ # On Julia ≥ 1.13, Base no longer restricts any/all to f::Function, so defining
748+ # any(f, ::ArrayPartition) causes minimal invalidation. On older Julia, the
749+ # f::Function methods caused ~780 invalidations, so the optimized methods were
750+ # separated into RecursiveArrayToolsArrayPartitionAnyAll.
751+ #
752+ # We now define them inline with a GPU check: if sub-arrays are GPU arrays and
753+ # we're on old Julia without the extension, give a helpful error.
754+ @static if VERSION >= v " 1.13.0-DEV.0"
755+ # Julia 1.13+: safe to define directly, minimal invalidation
756+ Base. any (f, A:: ArrayPartition ) = any ((any (f, x) for x in A. x))
757+ Base. any (f:: Function , A:: ArrayPartition ) = any ((any (f, x) for x in A. x))
758+ Base. any (A:: ArrayPartition ) = any (identity, A)
759+ Base. all (f, A:: ArrayPartition ) = all ((all (f, x) for x in A. x))
760+ Base. all (f:: Function , A:: ArrayPartition ) = all ((all (f, x) for x in A. x))
761+ Base. all (A:: ArrayPartition ) = all (identity, A)
762+ else
763+ # Julia < 1.13: only define GPU-check methods that error for GPU arrays
764+ # and fall through to AbstractArray for CPU. The optimized partition-level
765+ # methods are in RecursiveArrayToolsArrayPartitionAnyAll to avoid invalidations.
766+ const _ANYALL_GPU_HINT = """
767+ `any`/`all` on `ArrayPartition` with GPU arrays requires loading the subpackage:
768+ using RecursiveArrayToolsArrayPartitionAnyAll
769+ This provides optimized partition-level `any`/`all` that avoids scalar GPU indexing.
770+ """
771+
772+ function _check_gpu_anyall (A:: ArrayPartition )
773+ for x in A. x
774+ if x isa GPUArraysCore. AnyGPUArray
775+ error (_ANYALL_GPU_HINT)
776+ end
758777 end
759778 end
760- end
761779
762- function Base. any (f:: Function , A:: ArrayPartition )
763- _check_gpu_anyall (A)
764- return Base. invoke (any, Tuple{Function, AbstractArray}, f, A)
765- end
766- function Base. all (f:: Function , A:: ArrayPartition )
767- _check_gpu_anyall (A)
768- return Base. invoke (all, Tuple{Function, AbstractArray}, f, A)
769- end
770- function Base. any (f, A:: ArrayPartition )
771- _check_gpu_anyall (A)
772- return Base. invoke (any, Tuple{Any, AbstractArray}, f, A)
773- end
774- function Base. all (f, A:: ArrayPartition )
775- _check_gpu_anyall (A)
776- return Base. invoke (all, Tuple{Any, AbstractArray}, f, A)
780+ function Base. any (f:: Function , A:: ArrayPartition )
781+ _check_gpu_anyall (A)
782+ return Base. invoke (any, Tuple{Function, AbstractArray}, f, A)
783+ end
784+ function Base. all (f:: Function , A:: ArrayPartition )
785+ _check_gpu_anyall (A)
786+ return Base. invoke (all, Tuple{Function, AbstractArray}, f, A)
787+ end
788+ function Base. any (f, A:: ArrayPartition )
789+ _check_gpu_anyall (A)
790+ return Base. invoke (any, Tuple{Any, AbstractArray}, f, A)
791+ end
792+ function Base. all (f, A:: ArrayPartition )
793+ _check_gpu_anyall (A)
794+ return Base. invoke (all, Tuple{Any, AbstractArray}, f, A)
795+ end
777796end
0 commit comments