@@ -8,6 +8,7 @@ use vortex::array::arrays::ConstantArray;
88use vortex:: array:: arrays:: Extension ;
99use vortex:: array:: arrays:: PrimitiveArray ;
1010use vortex:: dtype:: DType ;
11+ use vortex:: dtype:: NativePType ;
1112use vortex:: dtype:: PType ;
1213use vortex:: dtype:: extension:: ExtDTypeRef ;
1314use vortex:: error:: VortexResult ;
@@ -60,28 +61,57 @@ pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> {
6061 Ok ( ext. storage_array ( ) . clone ( ) )
6162}
6263
63- // TODO(connor): it would be nicer if this took a generic parameter and a FnMut arg that we run
64- // directly on the values without having to return this ugly stride.
64+ /// The flat primitive elements of a tensor storage array, with typed row access.
65+ ///
66+ /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a
67+ /// constant input materializes only a single row (stride=0), while a full array uses
68+ /// stride=list_size.
69+ pub ( crate ) struct FlatElements {
70+ elems : PrimitiveArray ,
71+ stride : usize ,
72+ list_size : usize ,
73+ }
74+
75+ impl FlatElements {
76+ /// Returns the [`PType`] of the underlying elements.
77+ pub fn ptype ( & self ) -> PType {
78+ self . elems . ptype ( )
79+ }
80+
81+ /// Returns the `i`-th row as a typed slice of length `list_size`.
82+ pub fn row < T : NativePType > ( & self , i : usize ) -> & [ T ] {
83+ let slice = self . elems . as_slice :: < T > ( ) ;
84+ & slice[ i * self . stride ..i * self . stride + self . list_size ]
85+ }
86+ }
87+
6588/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
6689///
6790/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
68- /// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
69- /// where `stride` is `list_size` for a full array and `0` for a constant.
91+ /// materialized to avoid expanding it to the full column length.
7092pub ( crate ) fn extract_flat_elements (
7193 storage : & ArrayRef ,
7294 list_size : usize ,
73- ) -> VortexResult < ( PrimitiveArray , usize ) > {
95+ ) -> VortexResult < FlatElements > {
7496 if let Some ( constant) = storage. as_opt :: < Constant > ( ) {
7597 // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge
7698 // amount of data.
7799 let single = ConstantArray :: new ( constant. scalar ( ) . clone ( ) , 1 ) . into_array ( ) ;
78100 let fsl = single. to_canonical ( ) ?. into_fixed_size_list ( ) ;
79101 let elems = fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
80- return Ok ( ( elems, 0 ) ) ;
102+ return Ok ( FlatElements {
103+ elems,
104+ stride : 0 ,
105+ list_size,
106+ } ) ;
81107 }
82108
83109 // Otherwise we have to fully expand all of the data.
84110 let fsl = storage. to_canonical ( ) ?. into_fixed_size_list ( ) ;
85111 let elems = fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
86- Ok ( ( elems, list_size) )
112+ Ok ( FlatElements {
113+ elems,
114+ stride : list_size,
115+ list_size,
116+ } )
87117}
0 commit comments