@@ -585,6 +585,12 @@ Base.@propagate_inbounds function _getindex(
585585 return A. u[I]
586586end
587587
588+ Base. @propagate_inbounds function _getindex (
589+ A:: AbstractDiffEqArray , :: NotSymbolic , :: Colon , I:: Int
590+ )
591+ return A. u[I]
592+ end
593+
588594Base. @propagate_inbounds function _getindex (
589595 A:: AbstractVectorOfArray , :: NotSymbolic ,
590596 I:: Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon} ...
@@ -595,6 +601,33 @@ Base.@propagate_inbounds function _getindex(
595601 stack (getindex .(A. u[last (I)], tuple .(Base. front (I))... ))
596602 end
597603end
604+
605+ Base. @propagate_inbounds function _getindex (
606+ A:: AbstractDiffEqArray , :: NotSymbolic ,
607+ I:: Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon} ...
608+ )
609+ return if last (I) isa Int
610+ A. u[last (I)][Base. front (I)... ]
611+ else
612+ col_idxs = last (I)
613+ # Only preserve DiffEqArray type if all prefix indices are Colons (selecting whole inner arrays)
614+ if all (idx -> idx isa Colon, Base. front (I))
615+ # For Colon, select all columns
616+ if col_idxs isa Colon
617+ col_idxs = eachindex (A. u)
618+ end
619+ # For DiffEqArray, we need to preserve the time values and type
620+ # Create a vector of sliced arrays instead of stacking into higher-dim array
621+ u_slice = [A. u[col][Base. front (I)... ] for col in col_idxs]
622+ # Return as DiffEqArray with sliced time values
623+ return DiffEqArray (u_slice, A. t[col_idxs], parameter_values (A), symbolic_container (A))
624+ else
625+ # Prefix indices are not all Colons - do the same as VectorOfArray
626+ # (stack the results into a higher-dimensional array)
627+ return stack (getindex .(A. u[col_idxs], tuple .(Base. front (I))... ))
628+ end
629+ end
630+ end
598631Base. @propagate_inbounds function _getindex (
599632 VA:: AbstractVectorOfArray , :: NotSymbolic , ii:: CartesianIndex
600633 )
680713 return idx. dim == 0 ? idx. offset : idx
681714end
682715
716+ @inline function _column_indices (VA:: AbstractVectorOfArray , idx:: RaggedRange )
717+ # RaggedRange with dim=0 means it's a column range with pre-resolved indices
718+ if idx. dim == 0
719+ # Create a range with the offset as the stop value
720+ return Base. range (idx. start; step = idx. step, stop = idx. offset)
721+ else
722+ # dim != 0 means it's an inner-dimension range that needs column expansion
723+ return idx
724+ end
725+ end
726+
683727@inline _resolve_ragged_index (idx, :: AbstractVectorOfArray , :: Any ) = idx
684728@inline function _resolve_ragged_index (idx:: RaggedEnd , VA:: AbstractVectorOfArray , col)
685729 if idx. dim == 0
@@ -763,27 +807,54 @@ end
763807 return (Base. front (args)... , resolved_last)
764808 end
765809 elseif args[end ] isa RaggedRange
766- resolved_last = _resolve_ragged_index (args[end ], A, 1 )
767- if length (args) == 1
768- return (resolved_last,)
810+ # Only pre-resolve if it's an inner-dimension range (dim != 0)
811+ # Column ranges (dim == 0) are handled later by _column_indices
812+ if args[end ]. dim == 0
813+ # Column range - let _column_indices handle it
814+ return args
769815 else
770- return (Base. front (args)... , resolved_last)
816+ resolved_last = _resolve_ragged_index (args[end ], A, 1 )
817+ if length (args) == 1
818+ return (resolved_last,)
819+ else
820+ return (Base. front (args)... , resolved_last)
821+ end
771822 end
772823 end
773824 return args
774825end
775826
827+ # Helper function to preserve DiffEqArray type when slicing
828+ @inline function _preserve_array_type (A:: AbstractVectorOfArray , u_slice, col_idxs)
829+ return VectorOfArray (u_slice)
830+ end
831+
832+ @inline function _preserve_array_type (A:: AbstractDiffEqArray , u_slice, col_idxs)
833+ return DiffEqArray (u_slice, A. t[col_idxs], parameter_values (A), symbolic_container (A))
834+ end
835+
776836@inline function _ragged_getindex (A:: AbstractVectorOfArray , I... )
777837 n = ndims (A)
778838 # Special-case when user provided one fewer index than ndims(A): last index is column selector.
779839 if length (I) == n - 1
780840 raw_cols = last (I)
841+ # Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
842+ is_column_selection = if raw_cols isa RaggedEnd && raw_cols. dim != 0
843+ false # Inner dimension - don't preserve type
844+ elseif raw_cols isa RaggedRange && raw_cols. dim != 0
845+ true # Inner dimension range converted to column range - DO preserve type
846+ else
847+ true # Column selection (dim == 0 or not ragged)
848+ end
849+
781850 # If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
782851 cols = if raw_cols isa RaggedEnd && raw_cols. dim != 0
783852 lastindex (A. u) + raw_cols. offset
784853 elseif raw_cols isa RaggedRange && raw_cols. dim != 0
854+ # Convert inner-dimension range to column range by resolving bounds
855+ start_val = raw_cols. start < 0 ? lastindex (A. u) + raw_cols. start : raw_cols. start
785856 stop_val = lastindex (A. u) + raw_cols. offset
786- Base. range (raw_cols . start ; step = raw_cols. step, stop = stop_val)
857+ Base. range (start_val ; step = raw_cols. step, stop = stop_val)
787858 else
788859 _column_indices (A, raw_cols)
789860 end
@@ -806,37 +877,41 @@ end
806877 end
807878 return A. u[cols][padded... ]
808879 else
809- return VectorOfArray (
810- [
811- begin
812- resolved_prefix = _resolve_ragged_indices (prefix, A, col)
813- inner_nd = ndims (A. u[col])
814- n_missing = inner_nd - length (resolved_prefix)
815- padded = if n_missing > 0
816- if all (idx -> idx === Colon (), resolved_prefix)
817- (
818- resolved_prefix... ,
819- ntuple (_ -> Colon (), n_missing)... ,
820- )
821- else
822- (
823- resolved_prefix... ,
824- (
825- lastindex (
826- A. u[col],
827- length (resolved_prefix) + i
828- ) for i in 1 : n_missing
829- ). .. ,
830- )
831- end
880+ u_slice = [
881+ begin
882+ resolved_prefix = _resolve_ragged_indices (prefix, A, col)
883+ inner_nd = ndims (A. u[col])
884+ n_missing = inner_nd - length (resolved_prefix)
885+ padded = if n_missing > 0
886+ if all (idx -> idx === Colon (), resolved_prefix)
887+ (
888+ resolved_prefix... ,
889+ ntuple (_ -> Colon (), n_missing)... ,
890+ )
832891 else
833- resolved_prefix
834- end
835- A. u[col][padded... ]
892+ (
893+ resolved_prefix... ,
894+ (
895+ lastindex (
896+ A. u[col],
897+ length (resolved_prefix) + i
898+ ) for i in 1 : n_missing
899+ ). .. ,
900+ )
836901 end
837- for col in cols
838- ]
839- )
902+ else
903+ resolved_prefix
904+ end
905+ A. u[col][padded... ]
906+ end
907+ for col in cols
908+ ]
909+ # Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
910+ if is_column_selection
911+ return _preserve_array_type (A, u_slice, cols)
912+ else
913+ return VectorOfArray (u_slice)
914+ end
840915 end
841916 end
842917
870945 if col_idxs isa Int
871946 return A. u[col_idxs]
872947 else
873- return VectorOfArray (A . u[col_idxs])
948+ return _preserve_array_type (A, A . u[col_idxs], col_idxs )
874949 end
875950 end
876951 # If col_idxs resolved to a single Int, handle it directly
0 commit comments