Skip to content

Commit c3b3a8d

Browse files
committed
Fix type instability for vec[1, end] indexing
Fixes #525 The issue was that when `end` is used for the last dimension, `lastindex(VA, d)` returns `RaggedEnd(0, lastindex(VA.u))` as a sentinel for an already-resolved index. However, the `getindex` dispatch path couldn't infer the return type because `cols.dim == 0` is a runtime check. This PR adds: 1. A specialized `getindex(A, i::Int, re::RaggedEnd)` method that handles the common case `vec[i, end]` directly with type stability 2. A helper function `_ragged_getindex_int_col` as a function barrier for the Int column case 3. An early check in `_ragged_getindex` for the `RaggedEnd` sentinel case The fix ensures that `@code_warntype` shows the proper return type (Float64) instead of Any. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 6519ecf commit c3b3a8d

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

src/vector_of_array.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,22 @@ function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
564564
end
565565
Base.broadcastable(x::RaggedRange) = Ref(x)
566566

567+
# Specialized method for type stability when last index is RaggedEnd with dim=0 (resolved column index)
568+
# This handles the common case: vec[i, end] where end -> RaggedEnd(0, lastindex)
569+
Base.@propagate_inbounds function Base.getindex(
570+
A::AbstractVectorOfArray, i::Int, re::RaggedEnd
571+
)
572+
if re.dim == 0
573+
# Sentinel case: RaggedEnd(0, offset) means offset is the resolved column index
574+
return A.u[re.offset][i]
575+
else
576+
# Non-sentinel case: resolve the ragged index for the last column
577+
col = lastindex(A.u)
578+
resolved_idx = lastindex(A.u[col], re.dim) + re.offset
579+
return A.u[col][i, resolved_idx]
580+
end
581+
end
582+
567583
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
568584
length(VA.u) <= 1 && return false
569585
first_size = size(VA.u[1], d)
@@ -767,6 +783,18 @@ end
767783
return args
768784
end
769785

786+
# Helper function for type-stable getindex when column index is an Int
787+
# This function barrier ensures the compiler can fully infer the return type
788+
@inline function _ragged_getindex_int_col(A::AbstractVectorOfArray, prefix, cols::Int)
789+
if all(idx -> idx === Colon(), prefix)
790+
return A.u[cols]
791+
end
792+
resolved = _resolve_ragged_indices(prefix, A, cols)
793+
inner_nd = ndims(A.u[cols])
794+
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
795+
return A.u[cols][padded...]
796+
end
797+
770798
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
771799
n = ndims(A)
772800
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
@@ -842,14 +870,12 @@ end
842870
cols = Colon()
843871
prefix = I
844872
end
873+
# Handle RaggedEnd sentinel (dim=0) early for type stability - this represents an already-resolved Int
874+
if cols isa RaggedEnd && cols.dim == 0
875+
return _ragged_getindex_int_col(A, prefix, cols.offset)
876+
end
845877
if cols isa Int
846-
if all(idx -> idx === Colon(), prefix)
847-
return A.u[cols]
848-
end
849-
resolved = _resolve_ragged_indices(prefix, A, cols)
850-
inner_nd = ndims(A.u[cols])
851-
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
852-
return A.u[cols][padded...]
878+
return _ragged_getindex_int_col(A, prefix, cols)
853879
else
854880
col_idxs = _column_indices(A, cols)
855881
# Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection

test/interface_tests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ push!(testda, [-1, -2, -3, -4])
6666
@inferred sum(testva)
6767
@inferred sum(VectorOfArray([VectorOfArray([zeros(4, 4)])]))
6868
@inferred mapreduce(string, *, testva)
69+
# Type stability for `end` indexing (issue #525)
70+
testva_end = VectorOfArray([fill(2.0, 2) for i in 1:10])
71+
# Use lastindex directly since `end` doesn't work in SafeTestsets
72+
last_col = lastindex(testva_end, 2)
73+
@inferred testva_end[1, last_col]
74+
@test testva_end[1, last_col] == 2.0
6975

7076
# mapreduce
7177
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

0 commit comments

Comments
 (0)