Skip to content

Commit 2cc6009

Browse files
Merge pull request #530 from Ickaser/typestable-indexing
Achieve type stability for indexing by avoiding variable boxing, including for ranges
2 parents cc8aa48 + 075c18b commit 2cc6009

File tree

2 files changed

+102
-130
lines changed

2 files changed

+102
-130
lines changed

src/vector_of_array.jl

Lines changed: 98 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -576,37 +576,6 @@ function Base.:(:)(start::RaggedEnd, step::Integer, stop::Integer)
576576
end
577577
Base.broadcastable(x::RaggedRange) = Ref(x)
578578

579-
# Specialized method for type stability when last index is RaggedEnd with dim=0 (resolved column index)
580-
# This handles the common case: vec[i, end] where end -> RaggedEnd(0, lastindex)
581-
Base.@propagate_inbounds function Base.getindex(
582-
A::AbstractVectorOfArray, i::Int, re::RaggedEnd
583-
)
584-
if re.dim == 0
585-
# Sentinel case: RaggedEnd(0, offset) means offset is the resolved column index
586-
return A.u[re.offset][i]
587-
else
588-
# Non-sentinel case: resolve the ragged index for the last column
589-
col = lastindex(A.u)
590-
resolved_idx = lastindex(A.u[col], re.dim) + re.offset
591-
return A.u[col][i, resolved_idx]
592-
end
593-
end
594-
595-
# Specialized method for type stability when first index is RaggedEnd (row dimension)
596-
# This handles the common case: vec[end, col] where end -> RaggedEnd(1, 0)
597-
Base.@propagate_inbounds function Base.getindex(
598-
A::AbstractVectorOfArray, re::RaggedEnd, col::Int
599-
)
600-
if re.dim == 0
601-
# Sentinel case: RaggedEnd(0, offset) means offset is a plain index
602-
return A.u[col][re.offset]
603-
else
604-
# Non-sentinel case: resolve the ragged index for the given column
605-
resolved_idx = lastindex(A.u[col], re.dim) + re.offset
606-
return A.u[col][resolved_idx]
607-
end
608-
end
609-
610579
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
611580
length(VA.u) <= 1 && return false
612581
first_size = size(VA.u[1], d)
@@ -740,8 +709,8 @@ Base.@propagate_inbounds function _getindex(
740709
return getindex(A, all_variable_symbols(A), args...)
741710
end
742711

743-
@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ?
744-
eachindex(VA.u) : idx
712+
@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx
713+
@inline _column_indices(VA::AbstractVectorOfArray, idx::Colon) = eachindex(VA.u)
745714
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
746715
return findall(idx)
747716
end
@@ -874,106 +843,115 @@ end
874843
n = ndims(A)
875844
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
876845
if length(I) == n - 1
877-
raw_cols = last(I)
878-
# Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
879-
is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0
880-
false # Inner dimension - don't preserve type
881-
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
882-
true # Inner dimension range converted to column range - DO preserve type
883-
else
884-
true # Column selection (dim == 0 or not ragged)
885-
end
846+
return _ragged_getindex_nm1dims(A, I...)
847+
else
848+
return _ragged_getindex_full(A, I...)
849+
end
850+
end
886851

887-
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
888-
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
889-
lastindex(A.u) + raw_cols.offset
890-
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
891-
# Convert inner-dimension range to column range by resolving bounds
892-
start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start
893-
stop_val = lastindex(A.u) + raw_cols.offset
894-
Base.range(start_val; step = raw_cols.step, stop = stop_val)
895-
else
896-
_column_indices(A, raw_cols)
897-
end
898-
prefix = Base.front(I)
899-
if cols isa Int
900-
resolved_prefix = _resolve_ragged_indices(prefix, A, cols)
901-
inner_nd = ndims(A.u[cols])
902-
n_missing = inner_nd - length(resolved_prefix)
903-
padded = if n_missing > 0
904-
if all(idx -> idx === Colon(), resolved_prefix)
905-
(resolved_prefix..., ntuple(_ -> Colon(), n_missing)...)
906-
else
907-
(
908-
resolved_prefix...,
909-
(lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)...,
910-
)
911-
end
852+
@inline function _ragged_getindex_nm1dims(A::AbstractVectorOfArray, I...)
853+
raw_cols = last(I)
854+
# Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
855+
is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0
856+
false # Inner dimension - don't preserve type
857+
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
858+
true # Inner dimension range converted to column range - DO preserve type
859+
else
860+
true # Column selection (dim == 0 or not ragged)
861+
end
862+
863+
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
864+
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
865+
lastindex(A.u) + raw_cols.offset
866+
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
867+
# Convert inner-dimension range to column range by resolving bounds
868+
start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start
869+
stop_val = lastindex(A.u) + raw_cols.offset
870+
Base.range(start_val; step = raw_cols.step, stop = stop_val)
871+
else
872+
_column_indices(A, raw_cols)
873+
end
874+
prefix = Base.front(I)
875+
if cols isa Int
876+
resolved_prefix = _resolve_ragged_indices(prefix, A, cols)
877+
inner_nd = ndims(A.u[cols])
878+
n_missing = inner_nd - length(resolved_prefix)
879+
padded = if n_missing > 0
880+
if all(idx -> idx === Colon(), resolved_prefix)
881+
(resolved_prefix..., ntuple(_ -> Colon(), n_missing)...)
912882
else
913-
resolved_prefix
883+
(
884+
resolved_prefix...,
885+
(lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)...,
886+
)
914887
end
915-
return A.u[cols][padded...]
916888
else
917-
u_slice = [
918-
begin
919-
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
920-
inner_nd = ndims(A.u[col])
921-
n_missing = inner_nd - length(resolved_prefix)
922-
padded = if n_missing > 0
923-
if all(idx -> idx === Colon(), resolved_prefix)
924-
(
925-
resolved_prefix...,
926-
ntuple(_ -> Colon(), n_missing)...,
927-
)
928-
else
929-
(
930-
resolved_prefix...,
931-
(
932-
lastindex(
933-
A.u[col],
934-
length(resolved_prefix) + i
935-
) for i in 1:n_missing
936-
)...,
937-
)
938-
end
889+
resolved_prefix
890+
end
891+
return A.u[cols][padded...]
892+
else
893+
u_slice = [
894+
begin
895+
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
896+
inner_nd = ndims(A.u[col])
897+
n_missing = inner_nd - length(resolved_prefix)
898+
padded = if n_missing > 0
899+
if all(idx -> idx === Colon(), resolved_prefix)
900+
(
901+
resolved_prefix...,
902+
ntuple(_ -> Colon(), n_missing)...,
903+
)
939904
else
940-
resolved_prefix
941-
end
942-
A.u[col][padded...]
905+
(
906+
resolved_prefix...,
907+
(
908+
lastindex(
909+
A.u[col],
910+
length(resolved_prefix) + i
911+
) for i in 1:n_missing
912+
)...,
913+
)
943914
end
944-
for col in cols
945-
]
946-
# Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
947-
if is_column_selection
948-
return _preserve_array_type(A, u_slice, cols)
949-
else
950-
return VectorOfArray(u_slice)
951-
end
915+
else
916+
resolved_prefix
917+
end
918+
A.u[col][padded...]
919+
end
920+
for col in cols
921+
]
922+
# Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
923+
if is_column_selection
924+
return _preserve_array_type(A, u_slice, cols)
925+
else
926+
return VectorOfArray(u_slice)
952927
end
953928
end
929+
end
954930

931+
@inline function _padded_resolved_indices(prefix, A::AbstractVectorOfArray, col)
932+
resolved = _resolve_ragged_indices(prefix, A, col)
933+
inner_nd = ndims(A.u[col])
934+
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
935+
return padded
936+
end
937+
938+
@inline function _ragged_getindex_full(A::AbstractVectorOfArray, I...)
955939
# Otherwise, use the full-length interpretation (last index is column selector; missing columns default to Colon()).
956-
if length(I) == n
957-
cols = last(I)
958-
prefix = Base.front(I)
940+
n = ndims(A)
941+
cols, prefix = if length(I) == n
942+
last(I), Base.front(I)
959943
else
960-
cols = Colon()
961-
prefix = I
944+
Colon(), I
962945
end
963946
if cols isa Int
964947
if all(idx -> idx === Colon(), prefix)
965948
return A.u[cols]
966949
end
967-
resolved = _resolve_ragged_indices(prefix, A, cols)
968-
inner_nd = ndims(A.u[cols])
969-
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
970-
return A.u[cols][padded...]
950+
return A.u[cols][_padded_resolved_indices(prefix, A, cols)...]
971951
else
972952
col_idxs = _column_indices(A, cols)
973953
# Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection
974-
if col_idxs isa RaggedEnd
975-
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
976-
elseif col_idxs isa RaggedRange
954+
if col_idxs isa RaggedEnd || col_idxs isa RaggedRange
977955
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
978956
end
979957
# If we're selecting whole inner arrays (all leading indices are Colons),
@@ -986,23 +964,14 @@ end
986964
end
987965
end
988966
# If col_idxs resolved to a single Int, handle it directly
989-
if col_idxs isa Int
990-
resolved = _resolve_ragged_indices(prefix, A, col_idxs)
991-
inner_nd = ndims(A.u[col_idxs])
992-
padded = (
993-
resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...,
994-
)
995-
return A.u[col_idxs][padded...]
996-
end
997967
vals = map(col_idxs) do col
998-
resolved = _resolve_ragged_indices(prefix, A, col)
999-
inner_nd = ndims(A.u[col])
1000-
padded = (
1001-
resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...,
1002-
)
1003-
A.u[col][padded...]
968+
A.u[col][_padded_resolved_indices(prefix, A, col)...]
969+
end
970+
if col_idxs isa Int
971+
return vals
972+
else
973+
return stack(vals)
1004974
end
1005-
return stack(vals)
1006975
end
1007976
end
1008977

test/interface_tests.jl

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,16 +67,19 @@ push!(testda, [-1, -2, -3, -4])
6767
@inferred sum(VectorOfArray([VectorOfArray([zeros(4, 4)])]))
6868
@inferred mapreduce(string, *, testva)
6969
# Type stability for `end` indexing (issue #525)
70-
testva_end = VectorOfArray([fill(2.0, 2) for i in 1:10])
70+
testva_end = VectorOfArray(fill(fill(2.0, 2), 10))
7171
# Use lastindex directly since `end` doesn't work in SafeTestsets
7272
last_col = lastindex(testva_end, 2)
7373
@inferred testva_end[1, last_col]
74+
@inferred testva_end[1, 1:last_col]
7475
@test testva_end[1, last_col] == 2.0
7576
last_col = lastindex(testva_end)
7677
@inferred testva_end[1, last_col]
78+
@inferred testva_end[1, 1:last_col]
7779
@test testva_end[1, last_col] == 2.0
7880
last_row = lastindex(testva_end, 1)
7981
@inferred testva_end[last_row, 1]
82+
@inferred testva_end[1:last_row, 1]
8083
@test testva_end[last_row, 1] == 2.0
8184

8285
# mapreduce

0 commit comments

Comments
 (0)