88//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm),
99//! which the [`TurboQuantScheme`](super::TurboQuantScheme) calls before invoking this function.
1010
11+ use num_traits:: ToPrimitive ;
1112use vortex_array:: ArrayRef ;
1213use vortex_array:: ArrayView ;
1314use vortex_array:: ExecutionCtx ;
@@ -19,6 +20,7 @@ use vortex_array::arrays::extension::ExtensionArrayExt;
1920use vortex_array:: arrays:: fixed_size_list:: FixedSizeListArrayExt ;
2021use vortex_array:: dtype:: Nullability ;
2122use vortex_array:: dtype:: PType ;
23+ use vortex_array:: match_each_float_ptype;
2224use vortex_array:: validity:: Validity ;
2325use vortex_buffer:: BufferMut ;
2426use vortex_error:: VortexExpect ;
@@ -33,6 +35,13 @@ use crate::encodings::turboquant::array::centroids::find_nearest_centroid;
3335use crate :: encodings:: turboquant:: array:: centroids:: get_centroids;
3436use crate :: encodings:: turboquant:: array:: rotation:: RotationMatrix ;
3537use crate :: encodings:: turboquant:: vtable:: TurboQuantArray ;
38+ use crate :: scalar_fns:: ApproxOptions ;
39+ use crate :: scalar_fns:: l2_norm:: L2Norm ;
40+ use crate :: vector:: AnyVector ;
41+
42+ /// Tolerance for the unit-norm check in [`turboquant_encode`]. Each row's L2 norm must be within
43+ /// this distance of 1.0 (or be exactly 0.0 for zero vectors).
//
// NOTE(review): the check in `turboquant_encode` appears to read each norm in the vector's
// element float type and widen it to `f64` before comparing. Adjacent `f32` values around 1.0
// are roughly 6e-8 to 1.2e-7 apart (coarser still for `f16`), so with a 1e-10 tolerance the only
// non-zero norm of those types that can pass is exactly 1.0. Confirm that the upstream
// normalization always yields norms that round to exactly 1.0, or scale the tolerance to the
// element type.
44+ const UNIT_NORM_TOLERANCE : f64 = 1e-10 ;
3645
3746/// Configuration for TurboQuant encoding.
3847#[ derive( Clone , Debug ) ]
@@ -99,8 +108,9 @@ struct QuantizationResult {
99108
100109/// Core quantization: rotate and quantize already-normalized rows.
101110///
102- /// The input `fsl` must contain unit-norm vectors (already L2-normalized). The rotation and
103- /// centroid lookup happen in f32.
111+ /// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null
112+ /// vectors are not supported and must be zeroed out before reaching this function. The rotation
113+ /// and centroid lookup happen in f32.
104114fn turboquant_quantize_core (
105115 fsl : & FixedSizeListArray ,
106116 seed : u64 ,
@@ -186,7 +196,12 @@ fn build_turboquant(
186196/// [`TurboQuantArray`].
187197///
188198/// The input must be a non-nullable Vector extension array whose rows are already unit-norm.
189- /// Normalization is handled externally (e.g. by [`normalize_as_l2_denorm`]).
199+ /// **Null vectors are not supported.** The caller must normalize and strip nullability before
200+ /// calling this function, for example via [`normalize_as_l2_denorm`].
201+ ///
202+ /// This function validates that every row has L2 norm within `UNIT_NORM_TOLERANCE` of 1.0 (or is
203+ /// exactly 0.0). Use [`turboquant_encode_unchecked`] to skip this check when the caller has just
204+ /// performed normalization.
190205///
191206/// The returned array is a plain [`TurboQuantArray`] that decompresses to unit-norm vectors.
192207/// The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray if the original
@@ -200,13 +215,61 @@ pub fn turboquant_encode(
200215 ctx : & mut ExecutionCtx ,
201216) -> VortexResult < ArrayRef > {
202217 let ext_dtype = ext. dtype ( ) . clone ( ) ;
203- let storage = ext. storage_array ( ) ;
204- let fsl = storage. clone ( ) . execute :: < FixedSizeListArray > ( ctx) ?;
205218
206219 vortex_ensure ! (
207220 !ext_dtype. is_nullable( ) ,
208221 "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}" ,
209222 ) ;
223+
224+ // Validate that all rows are unit-norm (or zero).
225+ let num_rows = ext. as_ref ( ) . len ( ) ;
226+ if num_rows > 0 {
227+ let norms_sfn =
228+ L2Norm :: try_new_array ( & ApproxOptions :: Exact , ext. as_ref ( ) . clone ( ) , num_rows) ?;
229+ let norms: PrimitiveArray = norms_sfn. into_array ( ) . execute ( ctx) ?;
230+
231+ let element_ptype = ext_dtype
232+ . as_extension ( )
233+ . metadata :: < AnyVector > ( )
234+ . element_ptype ( ) ;
235+
236+ match_each_float_ptype ! ( element_ptype, |T | {
237+ for ( i, & norm) in norms. as_slice:: <T >( ) . iter( ) . enumerate( ) {
238+ let norm_f64: f64 = ToPrimitive :: to_f64( & norm) . unwrap_or( f64 :: NAN ) ;
239+ vortex_ensure!(
240+ norm_f64 == 0.0 || ( norm_f64 - 1.0 ) . abs( ) < UNIT_NORM_TOLERANCE ,
241+ "TurboQuant requires unit-norm input, but row {i} has L2 norm {norm_f64:.6} \
242+ (expected 1.0 or 0.0)",
243+ ) ;
244+ }
245+ } ) ;
246+ }
247+
248+ // SAFETY: We just validated that the input is non-nullable and every row is unit-norm or zero.
249+ unsafe { turboquant_encode_unchecked ( ext, config, ctx) }
250+ }
251+
252+ /// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a
253+ /// [`TurboQuantArray`], without validating the unit-norm precondition.
254+ ///
255+ /// # Safety
256+ ///
257+ /// The caller must ensure:
258+ ///
259+ /// - The input dtype is non-nullable.
260+ /// - Every row is L2-normalized (unit norm) or is a zero vector.
261+ ///
262+ /// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently
263+ /// incorrect quantization results.
264+ pub unsafe fn turboquant_encode_unchecked (
265+ ext : ArrayView < Extension > ,
266+ config : & TurboQuantConfig ,
267+ ctx : & mut ExecutionCtx ,
268+ ) -> VortexResult < ArrayRef > {
269+ let ext_dtype = ext. dtype ( ) . clone ( ) ;
270+ let storage = ext. storage_array ( ) ;
271+ let fsl = storage. clone ( ) . execute :: < FixedSizeListArray > ( ctx) ?;
272+
210273 vortex_ensure ! (
211274 config. bit_width >= 1 && config. bit_width <= TurboQuant :: MAX_BIT_WIDTH ,
212275 "bit_width must be 1-{}, got {}" ,
0 commit comments