Skip to content

Commit 1167924

Browse files
Add type-stable getindex for vec[end, col] pattern
- Add specialized getindex method for (RaggedEnd, Int) to handle vec[end, col] - Fix test to use lastindex(vec, 1) for row index instead of lastindex(vec) - Fix typo in test assertion (misplaced parenthesis) Addresses review feedback from JoshuaLampert. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 6b8b729 commit 1167924

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/vector_of_array.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,21 @@ Base.@propagate_inbounds function Base.getindex(
580580
end
581581
end
582582

583+
# Specialized method for type stability when first index is RaggedEnd (row dimension)
584+
# This handles the common case: vec[end, col] where end -> RaggedEnd(1, 0)
585+
Base.@propagate_inbounds function Base.getindex(
586+
A::AbstractVectorOfArray, re::RaggedEnd, col::Int
587+
)
588+
if re.dim == 0
589+
# Sentinel case: RaggedEnd(0, offset) means offset is a plain index
590+
return A.u[col][re.offset]
591+
else
592+
# Non-sentinel case: resolve the ragged index for the given column
593+
resolved_idx = lastindex(A.u[col], re.dim) + re.offset
594+
return A.u[col][resolved_idx]
595+
end
596+
end
597+
583598
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
584599
length(VA.u) <= 1 && return false
585600
first_size = size(VA.u[1], d)

test/interface_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ last_col = lastindex(testva_end, 2)
7575
last_col = lastindex(testva_end)
7676
@inferred testva_end[1, last_col]
7777
@test testva_end[1, last_col] == 2.0
78-
last_row = lastindex(testva_end)
78+
last_row = lastindex(testva_end, 1)
7979
@inferred testva_end[last_row, 1]
80-
@test testva_end[last_row, 1 == 2.0]
80+
@test testva_end[last_row, 1] == 2.0
8181

8282
# mapreduce
8383
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

0 commit comments

Comments
 (0)