Skip to content

Commit 1bc7c27

Browse files
committed
fix returning DiffEqArray
1 parent 7632937 commit 1bc7c27

File tree

2 files changed

+145
-36
lines changed

2 files changed

+145
-36
lines changed

src/vector_of_array.jl

Lines changed: 110 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,12 @@ Base.@propagate_inbounds function _getindex(
585585
return A.u[I]
586586
end
587587

588+
Base.@propagate_inbounds function _getindex(
589+
A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Int
590+
)
591+
return A.u[I]
592+
end
593+
588594
Base.@propagate_inbounds function _getindex(
589595
A::AbstractVectorOfArray, ::NotSymbolic,
590596
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
@@ -595,6 +601,33 @@ Base.@propagate_inbounds function _getindex(
595601
stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...))
596602
end
597603
end
604+
605+
Base.@propagate_inbounds function _getindex(
606+
A::AbstractDiffEqArray, ::NotSymbolic,
607+
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
608+
)
609+
return if last(I) isa Int
610+
A.u[last(I)][Base.front(I)...]
611+
else
612+
col_idxs = last(I)
613+
# Only preserve DiffEqArray type if all prefix indices are Colons (selecting whole inner arrays)
614+
if all(idx -> idx isa Colon, Base.front(I))
615+
# For Colon, select all columns
616+
if col_idxs isa Colon
617+
col_idxs = eachindex(A.u)
618+
end
619+
# For DiffEqArray, we need to preserve the time values and type
620+
# Create a vector of sliced arrays instead of stacking into higher-dim array
621+
u_slice = [A.u[col][Base.front(I)...] for col in col_idxs]
622+
# Return as DiffEqArray with sliced time values
623+
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
624+
else
625+
# Prefix indices are not all Colons - do the same as VectorOfArray
626+
# (stack the results into a higher-dimensional array)
627+
return stack(getindex.(A.u[col_idxs], tuple.(Base.front(I))...))
628+
end
629+
end
630+
end
598631
Base.@propagate_inbounds function _getindex(
599632
VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex
600633
)
@@ -680,6 +713,17 @@ end
680713
return idx.dim == 0 ? idx.offset : idx
681714
end
682715

716+
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedRange)
717+
# RaggedRange with dim=0 means it's a column range with pre-resolved indices
718+
if idx.dim == 0
719+
# Create a range with the offset as the stop value
720+
return Base.range(idx.start; step = idx.step, stop = idx.offset)
721+
else
722+
# dim != 0 means it's an inner-dimension range that needs column expansion
723+
return idx
724+
end
725+
end
726+
683727
@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
684728
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
685729
if idx.dim == 0
@@ -763,27 +807,54 @@ end
763807
return (Base.front(args)..., resolved_last)
764808
end
765809
elseif args[end] isa RaggedRange
766-
resolved_last = _resolve_ragged_index(args[end], A, 1)
767-
if length(args) == 1
768-
return (resolved_last,)
810+
# Only pre-resolve if it's an inner-dimension range (dim != 0)
811+
# Column ranges (dim == 0) are handled later by _column_indices
812+
if args[end].dim == 0
813+
# Column range - let _column_indices handle it
814+
return args
769815
else
770-
return (Base.front(args)..., resolved_last)
816+
resolved_last = _resolve_ragged_index(args[end], A, 1)
817+
if length(args) == 1
818+
return (resolved_last,)
819+
else
820+
return (Base.front(args)..., resolved_last)
821+
end
771822
end
772823
end
773824
return args
774825
end
775826

827+
# Helper function to preserve DiffEqArray type when slicing
828+
@inline function _preserve_array_type(A::AbstractVectorOfArray, u_slice, col_idxs)
829+
return VectorOfArray(u_slice)
830+
end
831+
832+
@inline function _preserve_array_type(A::AbstractDiffEqArray, u_slice, col_idxs)
833+
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
834+
end
835+
776836
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
777837
n = ndims(A)
778838
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
779839
if length(I) == n - 1
780840
raw_cols = last(I)
841+
# Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
842+
is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0
843+
false # Inner dimension - don't preserve type
844+
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
845+
true # Inner dimension range converted to column range - DO preserve type
846+
else
847+
true # Column selection (dim == 0 or not ragged)
848+
end
849+
781850
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
782851
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
783852
lastindex(A.u) + raw_cols.offset
784853
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
854+
# Convert inner-dimension range to column range by resolving bounds
855+
start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start
785856
stop_val = lastindex(A.u) + raw_cols.offset
786-
Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
857+
Base.range(start_val; step = raw_cols.step, stop = stop_val)
787858
else
788859
_column_indices(A, raw_cols)
789860
end
@@ -806,37 +877,41 @@ end
806877
end
807878
return A.u[cols][padded...]
808879
else
809-
return VectorOfArray(
810-
[
811-
begin
812-
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
813-
inner_nd = ndims(A.u[col])
814-
n_missing = inner_nd - length(resolved_prefix)
815-
padded = if n_missing > 0
816-
if all(idx -> idx === Colon(), resolved_prefix)
817-
(
818-
resolved_prefix...,
819-
ntuple(_ -> Colon(), n_missing)...,
820-
)
821-
else
822-
(
823-
resolved_prefix...,
824-
(
825-
lastindex(
826-
A.u[col],
827-
length(resolved_prefix) + i
828-
) for i in 1:n_missing
829-
)...,
830-
)
831-
end
880+
u_slice = [
881+
begin
882+
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
883+
inner_nd = ndims(A.u[col])
884+
n_missing = inner_nd - length(resolved_prefix)
885+
padded = if n_missing > 0
886+
if all(idx -> idx === Colon(), resolved_prefix)
887+
(
888+
resolved_prefix...,
889+
ntuple(_ -> Colon(), n_missing)...,
890+
)
832891
else
833-
resolved_prefix
834-
end
835-
A.u[col][padded...]
892+
(
893+
resolved_prefix...,
894+
(
895+
lastindex(
896+
A.u[col],
897+
length(resolved_prefix) + i
898+
) for i in 1:n_missing
899+
)...,
900+
)
836901
end
837-
for col in cols
838-
]
839-
)
902+
else
903+
resolved_prefix
904+
end
905+
A.u[col][padded...]
906+
end
907+
for col in cols
908+
]
909+
# Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
910+
if is_column_selection
911+
return _preserve_array_type(A, u_slice, cols)
912+
else
913+
return VectorOfArray(u_slice)
914+
end
840915
end
841916
end
842917

@@ -870,7 +945,7 @@ end
870945
if col_idxs isa Int
871946
return A.u[col_idxs]
872947
else
873-
return VectorOfArray(A.u[col_idxs])
948+
return _preserve_array_type(A, A.u[col_idxs], col_idxs)
874949
end
875950
end
876951
# If col_idxs resolved to a single Int, handle it directly

test/basic_indexing.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ diffeq = DiffEqArray(recs, t)
7777
@test diffeq[:, 1] == testa[:, 1]
7878
@test diffeq.u == recs
7979
@test diffeq[:, end] == testa[:, end]
80-
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t)
80+
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t[2:end])
81+
@test diffeq[:, 2:end].t == t[2:end]
82+
@test diffeq[:, end - 1:end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
83+
@test diffeq[:, end - 1:end].t == t[(length(t) - 1):length(t)]
8184

8285
# ## (Int, Int)
8386
@test testa[5, 4] == testva[5, 4]
@@ -148,6 +151,12 @@ diffeq = DiffEqArray(recs, t)
148151
@test testva[1:2, 1:2] == [1 3; 2 5]
149152
@test diffeq[:, 1] == recs[1]
150153
@test diffeq[1:2, 1:2] == [1 3; 2 5]
154+
@test diffeq[:, 1:2] == DiffEqArray([recs[i] for i in 1:2], t[1:2])
155+
@test diffeq[:, 1:2].t == t[1:2]
156+
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:3], t[2:end])
157+
@test diffeq[:, 2:end].t == t[2:end]
158+
@test diffeq[:, end - 1:end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
159+
@test diffeq[:, end - 1:end].t == t[(length(t) - 1):length(t)]
151160

152161
# Test views of heterogeneous arrays (issue #453)
153162
f = VectorOfArray([[1.0], [2.0, 3.0]])
@@ -224,6 +233,31 @@ u = VectorOfArray([[1.0], [2.0, 3.0]])
224233
u[:, 2] .= [10.0, 11.0]
225234
@test u.u[2] == [10.0, 11.0]
226235

236+
# Test DiffEqArray with 2D inner arrays (matrices)
237+
t = 1:2
238+
recs_2d = [rand(2, 3), rand(2, 4)]
239+
diffeq_2d = DiffEqArray(recs_2d, t)
240+
@test diffeq_2d[:, 1] == recs_2d[1]
241+
@test diffeq_2d[:, 2] == recs_2d[2]
242+
@test diffeq_2d[:, 1:2] == DiffEqArray(recs_2d[1:2], t[1:2])
243+
@test diffeq_2d[:, 1:2].t == t[1:2]
244+
@test diffeq_2d[:, 2:end] == DiffEqArray(recs_2d[2:end], t[2:end])
245+
@test diffeq_2d[:, 2:end].t == t[2:end]
246+
@test diffeq_2d[:, end-1:end] == DiffEqArray(recs_2d[end-1:end], t[end-1:end])
247+
@test diffeq_2d[:, end-1:end].t == t[end-1:end]
248+
249+
# Test DiffEqArray with 3D inner arrays (tensors)
250+
recs_3d = [rand(2, 3, 4), rand(2, 3, 5)]
251+
diffeq_3d = DiffEqArray(recs_3d, t)
252+
@test diffeq_3d[:, :, :, 1] == recs_3d[1]
253+
@test diffeq_3d[:, :, :, 2] == recs_3d[2]
254+
@test diffeq_3d[:, :, :, 1:2] == DiffEqArray(recs_3d[1:2], t[1:2])
255+
@test diffeq_3d[:, :, :, 1:2].t == t[1:2]
256+
@test diffeq_3d[:, :, :, 2:end] == DiffEqArray(recs_3d[2:end], t[2:end])
257+
@test diffeq_3d[:, :, :, 2:end].t == t[2:end]
258+
@test diffeq_3d[:, :, :, end-1:end] == DiffEqArray(recs_3d[end-1:end], t[end-1:end])
259+
@test diffeq_3d[:, :, :, end-1:end].t == t[end-1:end]
260+
227261
# 2D inner arrays (matrices) with ragged second dimension
228262
u = VectorOfArray([zeros(1, n) for n in (2, 3)])
229263
@test length(view(u, 1, :, 1)) == 2

0 commit comments

Comments
 (0)