22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use num_traits:: Float ;
5+ use num_traits:: Zero ;
56use vortex:: array:: ArrayRef ;
67use vortex:: array:: ExecutionCtx ;
78use vortex:: array:: IntoArray ;
@@ -57,7 +58,8 @@ pub struct NormVectorArray {
5758}
5859
5960impl NormVectorArray {
60- /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
61+ /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and associated L2
62+ /// norms for each vector.
6163 ///
6264 /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
6365 /// `norms` must be a primitive array of the same float type with the same length. The
@@ -113,12 +115,15 @@ impl NormVectorArray {
113115 /// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs
114116 /// are supported; the validity mask is preserved and the normalized data for null rows is
115117 /// unspecified.
118+ ///
119+ /// Note that compression is lossy due to floating-point rounding during normalization.
116120 pub fn compress ( vector_array : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < Self > {
117121 let ext = Self :: validate ( & vector_array) ?;
118122
119123 let list_size = extension_list_size ( & ext) ?;
120124 let row_count = vector_array. len ( ) ;
121125 let nullability = Nullability :: from ( vector_array. dtype ( ) . is_nullable ( ) ) ;
126+ let validity = vector_array. validity ( ) ?;
122127
123128 // Compute L2 norms using the scalar function. If the input is nullable, the norms will
124129 // also be nullable (null vectors produce null norms).
@@ -135,10 +140,17 @@ impl NormVectorArray {
135140 let norms_slice = norms_prim. as_slice:: <T >( ) ;
136141
137142 let normalized_elems: PrimitiveArray = ( 0 ..row_count)
138- . flat_map( |i| {
143+ . map( |i| -> VortexResult <Vec <T >> {
144+ if !validity. is_valid( i) ? {
145+ return Ok ( vec![ T :: zero( ) ; list_size] ) ;
146+ }
147+
139148 let inv_norm = safe_inv_norm( norms_slice[ i] ) ;
140- flat. row:: <T >( i) . iter( ) . map( move |& v| v * inv_norm)
149+ Ok ( flat. row:: <T >( i) . iter( ) . map( |& v| v * inv_norm) . collect ( ) )
141150 } )
151+ . collect:: <VortexResult <Vec <Vec <T >>>>( ) ?
152+ . into_iter( )
153+ . flatten( )
142154 . collect( ) ;
143155
144156 // Reconstruct the vector array with the same nullability as the input.
0 commit comments