@@ -76,7 +76,8 @@ struct QuantizationResult {
7676 rotation : RotationMatrix ,
7777 centroids : Vec < f32 > ,
7878 all_indices : BufferMut < u8 > ,
79- /// Native-precision norms (matching the Vector element type).
79+ /// Native-precision norms (matching the Vector element type). Carries validity: null vectors
80+ /// have null norms.
8081 norms_array : ArrayRef ,
8182 padded_dim : usize ,
8283}
@@ -85,19 +86,22 @@ struct QuantizationResult {
8586/// normalize/rotate/quantize all rows.
8687///
8788/// Norms are computed in the native element precision via the [`L2Norm`] scalar function.
88- /// The rotation and centroid lookup happen in f32.
89+ /// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce
90+ /// all-zero codes.
8991#[ allow( clippy:: cast_possible_truncation) ]
9092fn turboquant_quantize_core (
9193 ext : & ExtensionArray ,
9294 fsl : & FixedSizeListArray ,
9395 seed : u64 ,
9496 bit_width : u8 ,
97+ validity : & Validity ,
9598 ctx : & mut ExecutionCtx ,
9699) -> VortexResult < QuantizationResult > {
97100 let dimension = fsl. list_size ( ) as usize ;
98101 let num_rows = fsl. len ( ) ;
99102
100- // Compute native-precision norms via the L2Norm scalar fn.
103+ // Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from
104+ // the input, so null vectors get null norms automatically.
101105 let norms_sfn = L2Norm :: try_new_array ( & ApproxOptions :: Exact , ext. as_ref ( ) . clone ( ) , num_rows) ?;
102106 let norms_array: ArrayRef = norms_sfn. into_array ( ) . execute ( ctx) ?;
103107 let norms_prim: PrimitiveArray = norms_array. to_canonical ( ) ?. into_primitive ( ) ;
@@ -125,6 +129,12 @@ fn turboquant_quantize_core(
125129
126130 let f32_slice = f32_elements. as_slice :: < f32 > ( ) ;
127131 for row in 0 ..num_rows {
132+ // Null vectors get all-zero codes.
133+ if !validity. is_valid ( row) ? {
134+ all_indices. extend ( std:: iter:: repeat_n ( 0u8 , padded_dim) ) ;
135+ continue ;
136+ }
137+
128138 let x = & f32_slice[ row * dimension..( row + 1 ) * dimension] ;
129139 let norm = f32_norms[ row] ;
130140
@@ -189,12 +199,10 @@ fn build_turboquant(
189199 )
190200}
191201
192- /// Encode a [`Vector`] extension array into a `TurboQuantArray`.
193- ///
194- /// The input must be a non-nullable [`Vector`] extension array. TurboQuant is a lossy encoding
195- /// that does not preserve null positions; callers must handle validity externally.
202+ /// Encode a [`Vector`](crate::vector::Vector) extension array into a `TurboQuantArray`.
196203///
197- /// [`Vector`]: crate::vector::Vector
204+ /// Nullable inputs are supported: null vectors get all-zero codes and null norms. The validity
205+ /// of the resulting TurboQuant array is carried by the norms child.
198206pub fn turboquant_encode (
199207 ext : & ExtensionArray ,
200208 config : & TurboQuantConfig ,
@@ -204,10 +212,6 @@ pub fn turboquant_encode(
204212 let storage = ext. storage_array ( ) ;
205213 let fsl = storage. to_canonical ( ) ?. into_fixed_size_list ( ) ;
206214
207- vortex_ensure ! (
208- fsl. dtype( ) . nullability( ) == Nullability :: NonNullable ,
209- "TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
210- ) ;
211215 vortex_ensure ! (
212216 config. bit_width >= 1 && config. bit_width <= 8 ,
213217 "bit_width must be 1-8, got {}" ,
@@ -228,10 +232,11 @@ pub fn turboquant_encode(
228232 0 ,
229233 ) ?;
230234
231- // Norms dtype matches the element type.
235+ // Norms dtype matches the element type and carries the parent's nullability.
232236 let element_ptype = fsl. elements ( ) . dtype ( ) . as_ptype ( ) ;
237+ let norms_nullability = ext_dtype. nullability ( ) ;
233238 let empty_norms: ArrayRef = match_each_float_ptype ! ( element_ptype, |T | {
234- PrimitiveArray :: empty:: <T >( Nullability :: NonNullable ) . into_array( )
239+ PrimitiveArray :: empty:: <T >( norms_nullability ) . into_array( )
235240 } ) ;
236241
237242 let empty_centroids = PrimitiveArray :: empty :: < f32 > ( Nullability :: NonNullable ) ;
@@ -246,8 +251,9 @@ pub fn turboquant_encode(
246251 . into_array ( ) ) ;
247252 }
248253
254+ let validity = ext. as_ref ( ) . validity ( ) ?;
249255 let seed = config. seed . unwrap_or ( 42 ) ;
250- let core = turboquant_quantize_core ( ext, & fsl, seed, config. bit_width , ctx) ?;
256+ let core = turboquant_quantize_core ( ext, & fsl, seed, config. bit_width , & validity , ctx) ?;
251257
252258 Ok ( build_turboquant ( & fsl, core, ext_dtype) ?. into_array ( ) )
253259}
0 commit comments