@@ -69,7 +69,8 @@ impl L2Denorm {
6969 /// This is the correct constructor for [`L2Denorm`] arrays. In addition to the structural
7070 /// checks performed by [`ScalarFnArray::try_new`], it validates that every valid row of the
7171 /// `normalized` child has L2 norm `1.0` (or `0.0` for zero rows), within the tolerance implied
72- /// by the child element precision.
72+ /// by the child element precision. It also validates that stored norms are non-negative, and
73+ /// that any row with stored norm `0.0` has an all-zero normalized row.
7374 ///
7475 /// # Errors
7576 ///
@@ -82,10 +83,15 @@ impl L2Denorm {
8283 len : usize ,
8384 ctx : & mut ExecutionCtx ,
8485 ) -> VortexResult < ScalarFnArray > {
85- validate_l2_normalized_rows ( normalized. clone ( ) , ctx) ?;
86+ let result = ScalarFnArray :: try_new (
87+ L2Denorm :: new ( options) . erased ( ) ,
88+ vec ! [ normalized. clone( ) , norms. clone( ) ] ,
89+ len,
90+ ) ?;
8691
87- // SAFETY: We just validated that it is normalized.
88- unsafe { Self :: new_array_unchecked ( options, normalized, norms, len) }
92+ validate_l2_denorm_children ( normalized, norms, ctx) ?;
93+
94+ Ok ( result)
8995 }
9096
9197 /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually
@@ -114,49 +120,6 @@ impl L2Denorm {
114120 }
115121}
116122
117- /// Returns the acceptable unit-norm drift for the given element precision.
118- fn unit_norm_tolerance ( element_ptype : PType ) -> f64 {
119- match element_ptype {
120- PType :: F16 => 2e-3 ,
121- PType :: F32 => 2e-6 ,
122- PType :: F64 => 1e-10 ,
123- _ => unreachable ! ( "L2Denorm requires float elements, got {element_ptype:?}" ) ,
124- }
125- }
126-
127- /// Validates that every valid row of `input` is already L2-normalized.
128- pub fn validate_l2_normalized_rows ( input : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > {
129- let row_count = input. len ( ) ;
130- if row_count == 0 {
131- return Ok ( ( ) ) ;
132- }
133-
134- let tensor_match = validate_tensor_float_input ( input. dtype ( ) ) ?;
135- let element_ptype = tensor_match. element_ptype ( ) ;
136- let tolerance = unit_norm_tolerance ( element_ptype) ;
137-
138- let norms_sfn = L2Norm :: try_new_array ( & ApproxOptions :: Exact , input, row_count) ?;
139- let norms: PrimitiveArray = norms_sfn. into_array ( ) . execute ( ctx) ?;
140- let norms_validity = norms. validity ( ) ?;
141-
142- match_each_float_ptype ! ( element_ptype, |T | {
143- for ( i, & norm) in norms. as_slice:: <T >( ) . iter( ) . enumerate( ) {
144- if !norms_validity. is_valid( i) ? {
145- continue ;
146- }
147-
148- let norm_f64 = ToPrimitive :: to_f64( & norm) . unwrap_or( f64 :: NAN ) ;
149- vortex_ensure!(
150- norm_f64 == 0.0 || ( norm_f64 - 1.0 ) . abs( ) <= tolerance,
151- "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
152- {norm_f64:.6}",
153- ) ;
154- }
155- } ) ;
156-
157- Ok ( ( ) )
158- }
159-
160123impl ScalarFnVTable for L2Denorm {
161124 type Options = ApproxOptions ;
162125
@@ -373,6 +336,104 @@ fn build_tensor_array<T: NativePType>(
373336 Ok ( ExtensionArray :: new ( dtype. as_extension ( ) . clone ( ) , storage. into_array ( ) ) . into_array ( ) )
374337}
375338
339+ /// Returns the acceptable unit-norm drift for the given element precision.
340+ fn unit_norm_tolerance ( element_ptype : PType ) -> f64 {
341+ match element_ptype {
342+ PType :: F16 => 2e-3 ,
343+ PType :: F32 => 2e-6 ,
344+ PType :: F64 => 1e-10 ,
345+ _ => unreachable ! ( "L2Denorm requires float elements, got {element_ptype:?}" ) ,
346+ }
347+ }
348+
349+ /// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0).
350+ pub fn validate_l2_normalized_rows ( input : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > {
351+ validate_l2_normalized_rows_impl ( input, None , ctx)
352+ }
353+
354+ /// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`]
355+ /// invariants, which are:
356+ ///
357+ /// - All vectors in the normalized array have length 1 or 0.
358+ /// - If the vector has a norm of 0, then the vector must be all 0s.
359+ fn validate_l2_denorm_children (
360+ normalized : ArrayRef ,
361+ norms : ArrayRef ,
362+ ctx : & mut ExecutionCtx ,
363+ ) -> VortexResult < ( ) > {
364+ validate_l2_normalized_rows_impl ( normalized, Some ( norms) , ctx)
365+ }
366+
367+ fn validate_l2_normalized_rows_impl (
368+ normalized : ArrayRef ,
369+ norms : Option < ArrayRef > ,
370+ ctx : & mut ExecutionCtx ,
371+ ) -> VortexResult < ( ) > {
372+ let row_count = normalized. len ( ) ;
373+ if row_count == 0 {
374+ return Ok ( ( ) ) ;
375+ }
376+
377+ let tensor_match = validate_tensor_float_input ( normalized. dtype ( ) ) ?;
378+ let element_ptype = tensor_match. element_ptype ( ) ;
379+ let tolerance = unit_norm_tolerance ( element_ptype) ;
380+ let tensor_flat_size = tensor_match. list_size ( ) ;
381+
382+ let normalized: ExtensionArray = normalized. execute ( ctx) ?;
383+ let normalized_validity = normalized. as_ref ( ) . validity ( ) ?;
384+ let flat = extract_flat_elements ( normalized. storage_array ( ) , tensor_flat_size, ctx) ?;
385+ let norms = norms
386+ . map ( |norms| norms. execute :: < PrimitiveArray > ( ctx) )
387+ . transpose ( ) ?;
388+
389+ let combined_validity = match & norms {
390+ Some ( norms) => normalized_validity. and ( norms. validity ( ) ?) ?,
391+ None => normalized_validity,
392+ } ;
393+
394+ match_each_float_ptype ! ( element_ptype, |T | {
395+ let stored_norms = norms. as_ref( ) . map( |norms| norms. as_slice:: <T >( ) ) ;
396+
397+ for i in 0 ..row_count {
398+ if !combined_validity. is_valid( i) ? {
399+ continue ;
400+ }
401+
402+ let ( row_norm_sq, is_zero_row) =
403+ flat. row:: <T >( i)
404+ . iter( )
405+ . fold( ( 0.0f64 , true ) , |( sum_sq, is_zero) , x| {
406+ let value = ToPrimitive :: to_f64( x) . unwrap_or( f64 :: NAN ) ;
407+ ( sum_sq + value * value, is_zero && value. abs( ) <= tolerance)
408+ } ) ;
409+ let row_norm = row_norm_sq. sqrt( ) ;
410+
411+ vortex_ensure!(
412+ row_norm == 0.0 || ( row_norm - 1.0 ) . abs( ) <= tolerance,
413+ "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
414+ {row_norm:.6}",
415+ ) ;
416+
417+ if let Some ( stored_norms) = stored_norms {
418+ let stored_norm_f64 = ToPrimitive :: to_f64( & stored_norms[ i] ) . unwrap_or( f64 :: NAN ) ;
419+ vortex_ensure!(
420+ stored_norm_f64 >= 0.0 ,
421+ "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}" ,
422+ ) ;
423+
424+ if stored_norm_f64 == 0.0 {
425+ vortex_ensure!(
426+ is_zero_row,
427+ "L2Denorm normalized child must be all zeros when norms row {i} is 0.0" ,
428+ ) ;
429+ }
430+ }
431+ }
432+ } ) ;
433+
434+ Ok ( ( ) )
435+ }
436+
376437#[ cfg( test) ]
377438mod tests {
378439 use std:: sync:: LazyLock ;
@@ -590,6 +651,28 @@ mod tests {
590651 Ok ( ( ) )
591652 }
592653
/// A stored norm of `0.0` paired with a non-zero normalized row must be rejected.
#[test]
fn l2_denorm_try_new_array_rejects_nonzero_row_with_zero_norm() -> VortexResult<()> {
    // The first values `1.0, 0.0` are non-zero, yet every stored norm is `0.0`.
    // NOTE(review): presumably `vector_array(2, ..)` builds two 2-element rows — confirm.
    let rows = vector_array(2, &[1.0, 0.0, 0.0, 0.0])?;
    let zero_norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array();
    let mut ctx = SESSION.create_execution_ctx();

    assert!(
        L2Denorm::try_new_array(&ApproxOptions::Exact, rows, zero_norms, 2, &mut ctx).is_err()
    );
    Ok(())
}
664+
/// A negative stored norm must be rejected even when the normalized values look unit-length.
#[test]
fn l2_denorm_try_new_array_rejects_negative_norms() -> VortexResult<()> {
    // The second stored norm is `-1.0`, which violates the non-negativity requirement.
    // NOTE(review): presumably `vector_array(2, ..)` builds two 2-element rows — confirm.
    let rows = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?;
    let signed_norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array();
    let mut ctx = SESSION.create_execution_ctx();

    assert!(
        L2Denorm::try_new_array(&ApproxOptions::Exact, rows, signed_norms, 2, &mut ctx).is_err()
    );
    Ok(())
}
675+
593676 #[ test]
594677 fn l2_denorm_new_array_unchecked_accepts_unnormalized_child ( ) -> VortexResult < ( ) > {
595678 let normalized = vector_array ( 2 , & [ 3.0 , 4.0 , 1.0 , 0.0 ] ) ?;
0 commit comments