Skip to content

Commit da56aa8

Browse files
Eliminate all invalidation trees (8 → 0)
- NamedArrayPartition: narrow setindex!(x, args...) to setindex!(x, v, i::Int) instead of catching all signatures. AbstractVector only needs the Int method. - Remove setindex! CartesianIndex from Union in multi-index method (Base handles it) - Remove dedicated (Int, CartesianIndex) getindex/setindex! methods; flatten inside the multi-arg dispatcher instead - Fix NamedArrayPartition test: x[1:end] now preserves type (correct behavior) Result: `using RecursiveArrayTools` causes 0 invalidation trees. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9e09463 commit da56aa8

File tree

3 files changed

+11
-19
lines changed

3 files changed

+11
-19
lines changed

src/named_array_partition.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ end
9292

9393
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
9494
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
95-
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
96-
97-
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
95+
# Delegate indexing to the underlying ArrayPartition.
96+
# Use concrete index types to avoid invalidating AbstractArray's generic setindex!.
97+
Base.@propagate_inbounds Base.getindex(x::NamedArrayPartition, i::Int) = ArrayPartition(x)[i]
98+
Base.@propagate_inbounds Base.setindex!(x::NamedArrayPartition, v, i::Int) = (ArrayPartition(x)[i] = v)
9899
function Base.map(f, x::NamedArrayPartition)
99100
return NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
100101
end

src/vector_of_array.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -735,21 +735,12 @@ Base.@propagate_inbounds function Base.setindex!(
735735
return u_col[CartesianIndex(inner_I)] = x
736736
end
737737

738-
## Mixed Int + CartesianIndex (needed for sum(A; dims=d) etc.)
739-
## Use @inline to avoid invalidation issues with overly broad signatures
740-
@inline Base.@propagate_inbounds function Base.getindex(
741-
A::AbstractVectorOfArray{T, N}, i::Int, ci::CartesianIndex
742-
) where {T, N}
743-
return A[i, Tuple(ci)...]
744-
end
745-
746-
@inline Base.@propagate_inbounds function Base.setindex!(
747-
A::AbstractVectorOfArray{T, N}, v, i::Int, ci::CartesianIndex
748-
) where {T, N}
749-
return A[i, Tuple(ci)...] = v
750-
end
751-
752738
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
739+
# Flatten CartesianIndex arguments (e.g. from sum(A; dims=d)) to plain Ints
740+
# so they hit the N-ary getindex method instead of the symbolic dispatch.
741+
if _arg isa Int && length(args) == 1 && args[1] isa CartesianIndex
742+
return A[_arg, Tuple(args[1])...]
743+
end
753744
symtype = symbolic_type(_arg)
754745
elsymtype = symbolic_type(eltype(_arg))
755746

@@ -818,7 +809,7 @@ end
818809
Base.@propagate_inbounds function Base.setindex!(
819810
VA::AbstractVectorOfArray{T, N},
820811
x,
821-
idxs::Union{Int, Colon, CartesianIndex, AbstractArray{Int}, AbstractArray{Bool}}...
812+
idxs::Union{Int, Colon, AbstractArray{Int}, AbstractArray{Bool}}...
822813
) where {
823814
T, N,
824815
}

test/named_array_partition_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using RecursiveArrayTools, ArrayInterface, Test
77
@test typeof(similar(x)) <: NamedArrayPartition
88
@test typeof(similar(x, Int)) <: NamedArrayPartition
99
@test x.a ones(10)
10-
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
10+
@test typeof(x .+ x[1:end]) <: NamedArrayPartition # x[1:end] preserves type
1111
@test all(x .== x[1:end])
1212
@test ArrayInterface.zeromatrix(x) isa Matrix
1313
@test size(ArrayInterface.zeromatrix(x)) == (30, 30)

0 commit comments

Comments
 (0)