Skip to content

Commit ecd6794

Browse files
Merge pull request #527 from JoshuaLampert/fix-range-raggedend
Fix `RaggedEnd` in range and returning `DiffEqArray`
2 parents 6519ecf + 66ded44 commit ecd6794

File tree

2 files changed

+161
-36
lines changed

2 files changed

+161
-36
lines changed

src/vector_of_array.jl

Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,18 @@ end
562562
function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
563563
return RaggedRange(stop.dim, Int(start), Int(step), stop.offset)
564564
end
565+
function Base.:(:)(start::RaggedEnd, stop::RaggedEnd)
566+
return RaggedRange(stop.dim, start.offset, 1, stop.offset)
567+
end
568+
function Base.:(:)(start::RaggedEnd, step::Integer, stop::RaggedEnd)
569+
return RaggedRange(stop.dim, start.offset, Int(step), stop.offset)
570+
end
571+
function Base.:(:)(start::RaggedEnd, stop::Integer)
572+
return RaggedRange(start.dim, start.offset, 1, Int(stop))
573+
end
574+
function Base.:(:)(start::RaggedEnd, step::Integer, stop::Integer)
575+
return RaggedRange(start.dim, start.offset, Int(step), Int(stop))
576+
end
565577
Base.broadcastable(x::RaggedRange) = Ref(x)
566578

567579
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
@@ -579,6 +591,12 @@ Base.@propagate_inbounds function _getindex(
579591
return A.u[I]
580592
end
581593

594+
Base.@propagate_inbounds function _getindex(
595+
A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Int
596+
)
597+
return A.u[I]
598+
end
599+
582600
Base.@propagate_inbounds function _getindex(
583601
A::AbstractVectorOfArray, ::NotSymbolic,
584602
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
@@ -589,6 +607,33 @@ Base.@propagate_inbounds function _getindex(
589607
stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...))
590608
end
591609
end
610+
611+
Base.@propagate_inbounds function _getindex(
612+
A::AbstractDiffEqArray, ::NotSymbolic,
613+
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
614+
)
615+
return if last(I) isa Int
616+
A.u[last(I)][Base.front(I)...]
617+
else
618+
col_idxs = last(I)
619+
# Only preserve DiffEqArray type if all prefix indices are Colons (selecting whole inner arrays)
620+
if all(idx -> idx isa Colon, Base.front(I))
621+
# For Colon, select all columns
622+
if col_idxs isa Colon
623+
col_idxs = eachindex(A.u)
624+
end
625+
# For DiffEqArray, we need to preserve the time values and type
626+
# Create a vector of sliced arrays instead of stacking into higher-dim array
627+
u_slice = [A.u[col][Base.front(I)...] for col in col_idxs]
628+
# Return as DiffEqArray with sliced time values
629+
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
630+
else
631+
# Prefix indices are not all Colons - do the same as VectorOfArray
632+
# (stack the results into a higher-dimensional array)
633+
return stack(getindex.(A.u[col_idxs], tuple.(Base.front(I))...))
634+
end
635+
end
636+
end
592637
Base.@propagate_inbounds function _getindex(
593638
VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex
594639
)
@@ -674,6 +719,17 @@ end
674719
return idx.dim == 0 ? idx.offset : idx
675720
end
676721

722+
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedRange)
723+
# RaggedRange with dim=0 means it's a column range with pre-resolved indices
724+
if idx.dim == 0
725+
# Create a range with the offset as the stop value
726+
return Base.range(idx.start; step = idx.step, stop = idx.offset)
727+
else
728+
# dim != 0 means it's an inner-dimension range that needs column expansion
729+
return idx
730+
end
731+
end
732+
677733
@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
678734
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
679735
if idx.dim == 0
@@ -757,27 +813,54 @@ end
757813
return (Base.front(args)..., resolved_last)
758814
end
759815
elseif args[end] isa RaggedRange
760-
resolved_last = _resolve_ragged_index(args[end], A, 1)
761-
if length(args) == 1
762-
return (resolved_last,)
816+
# Only pre-resolve if it's an inner-dimension range (dim != 0)
817+
# Column ranges (dim == 0) are handled later by _column_indices
818+
if args[end].dim == 0
819+
# Column range - let _column_indices handle it
820+
return args
763821
else
764-
return (Base.front(args)..., resolved_last)
822+
resolved_last = _resolve_ragged_index(args[end], A, 1)
823+
if length(args) == 1
824+
return (resolved_last,)
825+
else
826+
return (Base.front(args)..., resolved_last)
827+
end
765828
end
766829
end
767830
return args
768831
end
769832

833+
# Helper function to preserve DiffEqArray type when slicing
834+
@inline function _preserve_array_type(A::AbstractVectorOfArray, u_slice, col_idxs)
835+
return VectorOfArray(u_slice)
836+
end
837+
838+
@inline function _preserve_array_type(A::AbstractDiffEqArray, u_slice, col_idxs)
839+
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
840+
end
841+
770842
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
771843
n = ndims(A)
772844
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
773845
if length(I) == n - 1
774846
raw_cols = last(I)
847+
# Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
848+
is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0
849+
false # Inner dimension - don't preserve type
850+
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
851+
true # Inner dimension range converted to column range - DO preserve type
852+
else
853+
true # Column selection (dim == 0 or not ragged)
854+
end
855+
775856
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
776857
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
777858
lastindex(A.u) + raw_cols.offset
778859
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
860+
# Convert inner-dimension range to column range by resolving bounds
861+
start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start
779862
stop_val = lastindex(A.u) + raw_cols.offset
780-
Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
863+
Base.range(start_val; step = raw_cols.step, stop = stop_val)
781864
else
782865
_column_indices(A, raw_cols)
783866
end
@@ -800,37 +883,41 @@ end
800883
end
801884
return A.u[cols][padded...]
802885
else
803-
return VectorOfArray(
804-
[
805-
begin
806-
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
807-
inner_nd = ndims(A.u[col])
808-
n_missing = inner_nd - length(resolved_prefix)
809-
padded = if n_missing > 0
810-
if all(idx -> idx === Colon(), resolved_prefix)
811-
(
812-
resolved_prefix...,
813-
ntuple(_ -> Colon(), n_missing)...,
814-
)
815-
else
816-
(
817-
resolved_prefix...,
818-
(
819-
lastindex(
820-
A.u[col],
821-
length(resolved_prefix) + i
822-
) for i in 1:n_missing
823-
)...,
824-
)
825-
end
886+
u_slice = [
887+
begin
888+
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
889+
inner_nd = ndims(A.u[col])
890+
n_missing = inner_nd - length(resolved_prefix)
891+
padded = if n_missing > 0
892+
if all(idx -> idx === Colon(), resolved_prefix)
893+
(
894+
resolved_prefix...,
895+
ntuple(_ -> Colon(), n_missing)...,
896+
)
826897
else
827-
resolved_prefix
828-
end
829-
A.u[col][padded...]
898+
(
899+
resolved_prefix...,
900+
(
901+
lastindex(
902+
A.u[col],
903+
length(resolved_prefix) + i
904+
) for i in 1:n_missing
905+
)...,
906+
)
830907
end
831-
for col in cols
832-
]
833-
)
908+
else
909+
resolved_prefix
910+
end
911+
A.u[col][padded...]
912+
end
913+
for col in cols
914+
]
915+
# Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
916+
if is_column_selection
917+
return _preserve_array_type(A, u_slice, cols)
918+
else
919+
return VectorOfArray(u_slice)
920+
end
834921
end
835922
end
836923

@@ -864,7 +951,7 @@ end
864951
if col_idxs isa Int
865952
return A.u[col_idxs]
866953
else
867-
return VectorOfArray(A.u[col_idxs])
954+
return _preserve_array_type(A, A.u[col_idxs], col_idxs)
868955
end
869956
end
870957
# If col_idxs resolved to a single Int, handle it directly

test/basic_indexing.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ diffeq = DiffEqArray(recs, t)
7777
@test diffeq[:, 1] == testa[:, 1]
7878
@test diffeq.u == recs
7979
@test diffeq[:, end] == testa[:, end]
80-
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t)
80+
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t[2:end])
81+
@test diffeq[:, 2:end].t == t[2:end]
82+
@test diffeq[:, (end - 1):end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
83+
@test diffeq[:, (end - 1):end].t == t[(length(t) - 1):length(t)]
84+
@test diffeq[:, (end - 5):8] == DiffEqArray([recs[i] for i in (length(t) - 5):8], t[(length(t) - 5):8])
85+
@test diffeq[:, (end - 5):8].t == t[(length(t) - 5):8]
8186

8287
# ## (Int, Int)
8388
@test testa[5, 4] == testva[5, 4]
@@ -148,6 +153,12 @@ diffeq = DiffEqArray(recs, t)
148153
@test testva[1:2, 1:2] == [1 3; 2 5]
149154
@test diffeq[:, 1] == recs[1]
150155
@test diffeq[1:2, 1:2] == [1 3; 2 5]
156+
@test diffeq[:, 1:2] == DiffEqArray([recs[i] for i in 1:2], t[1:2])
157+
@test diffeq[:, 1:2].t == t[1:2]
158+
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:3], t[2:end])
159+
@test diffeq[:, 2:end].t == t[2:end]
160+
@test diffeq[:, (end - 1):end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
161+
@test diffeq[:, (end - 1):end].t == t[(length(t) - 1):length(t)]
151162

152163
# Test views of heterogeneous arrays (issue #453)
153164
f = VectorOfArray([[1.0], [2.0, 3.0]])
@@ -179,6 +190,7 @@ ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
179190
@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0]
180191
@test ragged[:, end] == [6.0, 7.0, 8.0, 9.0]
181192
@test ragged[:, 2:end] == VectorOfArray(ragged.u[2:end])
193+
@test ragged[:, (end - 1):end] == VectorOfArray(ragged.u[(end - 1):end])
182194

183195
ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
184196
@test ragged2[end, 1] == 4.0
@@ -199,6 +211,7 @@ ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
199211
@test ragged2[1:(end - 1), 1] == [1.0, 2.0, 3.0]
200212
@test ragged2[1:(end - 1), 2] == [5.0]
201213
@test ragged2[1:(end - 1), 3] == [7.0, 8.0]
214+
@test ragged2[:, (end - 1):end] == VectorOfArray(ragged2.u[(end - 1):end])
202215

203216
# Test that RaggedEnd and RaggedRange broadcast as scalars
204217
# (fixes issue with SymbolicIndexingInterface where broadcasting over RaggedEnd would fail)
@@ -222,6 +235,31 @@ u = VectorOfArray([[1.0], [2.0, 3.0]])
222235
u[:, 2] .= [10.0, 11.0]
223236
@test u.u[2] == [10.0, 11.0]
224237

238+
# Test DiffEqArray with 2D inner arrays (matrices)
239+
t = 1:2
240+
recs_2d = [rand(2, 3), rand(2, 4)]
241+
diffeq_2d = DiffEqArray(recs_2d, t)
242+
@test diffeq_2d[:, 1] == recs_2d[1]
243+
@test diffeq_2d[:, 2] == recs_2d[2]
244+
@test diffeq_2d[:, 1:2] == DiffEqArray(recs_2d[1:2], t[1:2])
245+
@test diffeq_2d[:, 1:2].t == t[1:2]
246+
@test diffeq_2d[:, 2:end] == DiffEqArray(recs_2d[2:end], t[2:end])
247+
@test diffeq_2d[:, 2:end].t == t[2:end]
248+
@test diffeq_2d[:, (end - 1):end] == DiffEqArray(recs_2d[(end - 1):end], t[(end - 1):end])
249+
@test diffeq_2d[:, (end - 1):end].t == t[(end - 1):end]
250+
251+
# Test DiffEqArray with 3D inner arrays (tensors)
252+
recs_3d = [rand(2, 3, 4), rand(2, 3, 5)]
253+
diffeq_3d = DiffEqArray(recs_3d, t)
254+
@test diffeq_3d[:, :, :, 1] == recs_3d[1]
255+
@test diffeq_3d[:, :, :, 2] == recs_3d[2]
256+
@test diffeq_3d[:, :, :, 1:2] == DiffEqArray(recs_3d[1:2], t[1:2])
257+
@test diffeq_3d[:, :, :, 1:2].t == t[1:2]
258+
@test diffeq_3d[:, :, :, 2:end] == DiffEqArray(recs_3d[2:end], t[2:end])
259+
@test diffeq_3d[:, :, :, 2:end].t == t[2:end]
260+
@test diffeq_3d[:, :, :, (end - 1):end] == DiffEqArray(recs_3d[(end - 1):end], t[(end - 1):end])
261+
@test diffeq_3d[:, :, :, (end - 1):end].t == t[(end - 1):end]
262+
225263
# 2D inner arrays (matrices) with ragged second dimension
226264
u = VectorOfArray([zeros(1, n) for n in (2, 3)])
227265
@test length(view(u, 1, :, 1)) == 2

0 commit comments

Comments
 (0)