@@ -6,25 +6,33 @@ use num_traits::Zero;
66use vortex:: array:: ArrayRef ;
77use vortex:: array:: ExecutionCtx ;
88use vortex:: array:: IntoArray ;
9+ use vortex:: array:: LEGACY_SESSION ;
10+ use vortex:: array:: VortexSessionExecute ;
911use vortex:: array:: arrays:: ExtensionArray ;
1012use vortex:: array:: arrays:: FixedSizeListArray ;
1113use vortex:: array:: arrays:: PrimitiveArray ;
14+ use vortex:: array:: builtins:: ArrayBuiltins ;
1215use vortex:: array:: match_each_float_ptype;
1316use vortex:: array:: stats:: ArrayStats ;
1417use vortex:: array:: validity:: Validity ;
1518use vortex:: dtype:: DType ;
1619use vortex:: dtype:: Nullability ;
1720use vortex:: dtype:: extension:: ExtDType ;
1821use vortex:: dtype:: extension:: ExtDTypeRef ;
22+ use vortex:: encodings:: runend:: RunEndArray ;
23+ use vortex:: encodings:: sequence:: SequenceArray ;
24+ use vortex:: error:: VortexExpect ;
1925use vortex:: error:: VortexResult ;
2026use vortex:: error:: vortex_ensure;
2127use vortex:: error:: vortex_ensure_eq;
2228use vortex:: error:: vortex_err;
2329use vortex:: expr:: Expression ;
2430use vortex:: expr:: root;
2531use vortex:: extension:: EmptyMetadata ;
32+ use vortex:: scalar:: PValue ;
2633use vortex:: scalar_fn:: EmptyOptions ;
2734use vortex:: scalar_fn:: ScalarFn ;
35+ use vortex:: scalar_fn:: fns:: operators:: Operator ;
2836
2937use crate :: scalar_fns:: l2_norm:: L2Norm ;
3038use crate :: utils:: extension_element_ptype;
@@ -45,12 +53,13 @@ pub struct NormVectorArray {
4553 /// The backing vector array that has been unit normalized.
4654 ///
4755 /// The underlying elements of the vector array must be floating-point. This child may be
48- /// nullable; its validity determines the validity of the `NormVectorArray`.
56+ /// nullable, and its validity determines the validity of the `NormVectorArray`.
4957 pub ( crate ) vector_array : ArrayRef ,
5058
5159 /// The L2 norms of each vector.
5260 ///
53- /// This must have the same dtype as the elements of the vector array.
61+ /// This must have the same validity as the vector array, and the same dtype as the elements of
62+ /// the vector array.
5463 pub ( crate ) norms : ArrayRef ,
5564
5665 /// Stats set owned by this array.
@@ -65,7 +74,7 @@ impl NormVectorArray {
6574 /// `norms` must be a primitive array of the same float type with the same length. The
6675 /// `vector_array` may be nullable.
6776 pub fn try_new ( vector_array : ArrayRef , norms : ArrayRef ) -> VortexResult < Self > {
68- let ext = Self :: validate ( & vector_array) ?;
77+ let ext = Self :: validate ( & vector_array, & norms ) ?;
6978
7079 let element_ptype = extension_element_ptype ( & ext) ?;
7180
@@ -90,9 +99,9 @@ impl NormVectorArray {
9099 } )
91100 }
92101
93- /// Validates that the given array has the [`Vector`] extension type and returns the extension
94- /// dtype .
95- fn validate ( vector_array : & ArrayRef ) -> VortexResult < ExtDTypeRef > {
102+ /// Validates that the given array has the [`Vector`] extension type and returns the
103+ /// [`ExtDTypeRef`] of the vector array on success .
104+ fn validate_vector_array ( vector_array : & ArrayRef ) -> VortexResult < ExtDTypeRef > {
96105 let ext = vector_array. dtype ( ) . as_extension_opt ( ) . ok_or_else ( || {
97106 vortex_err ! (
98107 "vector_array dtype must be an extension type, got {}" ,
@@ -109,6 +118,54 @@ impl NormVectorArray {
109118 Ok ( ext. clone ( ) )
110119 }
111120
121+ /// Validates that the given `vector_array` and `norms` are compatible.
122+ ///
123+ /// Checks that:
124+ /// - The `vector_array` has the [`Vector`] extension type.
125+ /// - Both arrays have the same length.
126+ /// - The element primitive type of the vectors matches the primitive type of the norms.
127+ /// - Both arrays share the same validity mask.
128+ ///
129+ /// Returns the [`ExtDTypeRef`] of the vector array on success.
130+ fn validate ( vector_array : & ArrayRef , norms : & ArrayRef ) -> VortexResult < ExtDTypeRef > {
131+ let ext = Self :: validate_vector_array ( vector_array) ?;
132+
133+ vortex_ensure_eq ! (
134+ vector_array. len( ) ,
135+ norms. len( ) ,
136+ "vector_array and norms must have the same length"
137+ ) ;
138+
139+ let element_ptype = extension_element_ptype ( & ext) ?;
140+ vortex_ensure_eq ! (
141+ element_ptype,
142+ norms. dtype( ) . as_ptype( ) ,
143+ "vector elements ptype must be the same as the norms ptype"
144+ ) ;
145+
146+ // TODO(connor): Is there a better way to do this?
147+ let mut ctx = LEGACY_SESSION . create_execution_ctx ( ) ;
148+ let mask_eq = vector_array
149+ . validity ( ) ?
150+ . mask_eq ( & norms. validity ( ) ?, & mut ctx) ?;
151+ vortex_ensure ! (
152+ mask_eq,
153+ "vector_array and norms must have the same validity"
154+ ) ;
155+
156+ Ok ( ext)
157+ }
158+
159+ /// Returns a reference to the backing vector array that has been unit normalized.
160+ pub fn vector_array ( & self ) -> & ArrayRef {
161+ & self . vector_array
162+ }
163+
164+ /// Returns a reference to the L2 norms of each vector.
165+ pub fn norms ( & self ) -> & ArrayRef {
166+ & self . norms
167+ }
168+
112169 /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
113170 /// dividing each vector by its norm.
114171 ///
@@ -118,9 +175,9 @@ impl NormVectorArray {
118175 ///
119176 /// Note that compression is lossy per floating-point operations.
120177 pub fn compress ( vector_array : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < Self > {
121- let ext = Self :: validate ( & vector_array) ?;
178+ let ext = Self :: validate_vector_array ( & vector_array) ?;
122179
123- let list_size = extension_list_size ( & ext) ?;
180+ let list_size = extension_list_size ( & ext) ? as usize ;
124181 let row_count = vector_array. len ( ) ;
125182 let nullability = Nullability :: from ( vector_array. dtype ( ) . is_nullable ( ) ) ;
126183 let validity = vector_array. validity ( ) ?;
@@ -170,57 +227,62 @@ impl NormVectorArray {
170227 } )
171228 }
172229
173- /// Returns a reference to the backing vector array that has been unit normalized.
174- pub fn vector_array ( & self ) -> & ArrayRef {
175- & self . vector_array
176- }
177-
178- /// Returns a reference to the L2 norms of each vector.
179- pub fn norms ( & self ) -> & ArrayRef {
180- & self . norms
181- }
182-
183230 /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
184231 ///
185232 /// The returned array has the same dtype (including nullability) as the original
186233 /// `vector_array` child.
187234 pub fn decompress ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
188- let ext = Self :: validate ( & self . vector_array ) ?;
189- let nullability = Nullability :: from ( self . vector_array . dtype ( ) . is_nullable ( ) ) ;
190-
191- let list_size = extension_list_size ( & ext) ?;
192- let row_count = self . vector_array . len ( ) ;
235+ let ext = self
236+ . dtype ( )
237+ . as_extension_opt ( )
238+ . vortex_expect ( "somehow had a non-extension dtype" ) ;
193239
194240 let storage = extension_storage ( & self . vector_array ) ?;
195- let flat = extract_flat_elements ( & storage, list_size , ctx) ?;
241+ let fsl : FixedSizeListArray = storage. execute ( ctx) ?;
196242
197- let norms_prim: PrimitiveArray = self . norms . clone ( ) . execute ( ctx) ?;
243+ let denormalized_fsl =
244+ broadcast_binary_to_elements ( fsl, self . norms . clone ( ) , Operator :: Mul , ctx) ?;
198245
199- match_each_float_ptype ! ( flat. ptype( ) , |T | {
200- let norms_slice = norms_prim. as_slice:: <T >( ) ;
201-
202- let result_elems: PrimitiveArray = ( 0 ..row_count)
203- . flat_map( |i| {
204- let norm = norms_slice[ i] ;
205- flat. row:: <T >( i) . iter( ) . map( move |& v| v * norm)
206- } )
207- . collect( ) ;
208-
209- let validity = Validity :: from( nullability) ;
210- let fsl = FixedSizeListArray :: new(
211- result_elems. into_array( ) ,
212- u32 :: try_from( list_size) ?,
213- validity,
214- row_count,
215- ) ;
216-
217- let ext_dtype =
218- ExtDType :: <Vector >:: try_new( EmptyMetadata , fsl. dtype( ) . clone( ) ) ?. erased( ) ;
219- Ok ( ExtensionArray :: new( ext_dtype, fsl. into_array( ) ) . into_array( ) )
220- } )
246+ Ok ( ExtensionArray :: new ( ext. clone ( ) , denormalized_fsl. into_array ( ) ) . into_array ( ) )
221247 }
222248}
223249
250+ /// We do not have any kind of "broadcast" expression where we evaluate a binary expression between
251+ /// every `FixedSizeList` element and another value. We can mimic this by creating a
252+ /// `RunEnd(Sequence)` array that we evaluate with the elements of the [`FixedSizeListArray`].
253+ fn broadcast_binary_to_elements (
254+ fsl : FixedSizeListArray ,
255+ values : ArrayRef ,
256+ op : Operator ,
257+ ctx : & mut ExecutionCtx ,
258+ ) -> VortexResult < FixedSizeListArray > {
259+ let num_lists = fsl. len ( ) ;
260+ let list_size = fsl. list_size ( ) ;
261+ let validity = fsl. validity ( ) ?;
262+ let elements = fsl. elements ( ) ;
263+ debug_assert ! ( elements. dtype( ) . is_primitive( ) ) ;
264+
265+ // Create the broadcasting array via a runend array with a sequence of ends.
266+ let base: PValue = list_size. into ( ) ;
267+ let multiplier: PValue = base;
268+ let ends_ptype = base. ptype ( ) ;
269+ let ends_nullability = Nullability :: NonNullable ;
270+
271+ let ends = SequenceArray :: try_new ( base, multiplier, ends_ptype, ends_nullability, num_lists) ?;
272+ let runend = RunEndArray :: try_new ( ends. into_array ( ) , values) ?;
273+
274+ let binary_eval = elements. binary ( runend. into_array ( ) , op) ?;
275+ let executed: PrimitiveArray = binary_eval. execute ( ctx) ?;
276+
277+ // SAFETY: We simply evaluated a scalar function on all of the elements, so none of the length
278+ // properties have changed.
279+ let fsl = unsafe {
280+ FixedSizeListArray :: new_unchecked ( executed. into_array ( ) , list_size, validity, num_lists)
281+ } ;
282+
283+ Ok ( fsl)
284+ }
285+
224286/// Returns `1 / norm` if the norm is non-zero, or zero otherwise.
225287///
226288/// This avoids division by zero for zero-length or all-zero vectors.
0 commit comments