@@ -16,10 +16,12 @@ use vortex_array::dtype::PType;
1616use vortex_array:: validity:: Validity ;
1717use vortex_buffer:: Buffer ;
1818use vortex_error:: VortexResult ;
19+ use vortex_tensor:: scalar_fns:: l2_norm:: L2Norm ;
1920
2021use super :: execute_tq_decode;
2122use super :: execute_tq_encode;
2223use super :: f32_vector_array;
24+ use super :: tensor_test_session;
2325use super :: test_session;
2426use super :: turboquant_storage;
2527use super :: vector_array;
@@ -29,6 +31,7 @@ use super::vector_values_f32;
2931use crate :: TurboQuantConfig ;
3032use crate :: centroids:: compute_or_get_centroids;
3133use crate :: vector:: normalize:: tq_normalize_as_l2_denorm;
34+ use crate :: vector:: storage:: parse_storage;
3235
3336#[ rstest]
3437#[ case:: zero_bits( 0 , 42 , 3 ) ]
@@ -105,6 +108,10 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
105108 . unmasked_field_by_name ( "norms" ) ?
106109 . clone ( )
107110 . execute ( & mut ctx) ?;
111+ let inv_direction_norms: PrimitiveArray = storage
112+ . unmasked_field_by_name ( "inv_direction_norms" ) ?
113+ . clone ( )
114+ . execute ( & mut ctx) ?;
108115 let codes: FixedSizeListArray = storage
109116 . unmasked_field_by_name ( "codes" ) ?
110117 . clone ( )
@@ -114,13 +121,21 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
114121 assert ! ( !mask. value( 1 ) ) ;
115122 assert ! ( mask. value( 2 ) ) ;
116123 assert_eq ! ( norms. validity( ) ?. nullability( ) , Nullability :: Nullable ) ;
124+ assert_eq ! (
125+ inv_direction_norms. validity( ) ?. nullability( ) ,
126+ Nullability :: Nullable
127+ ) ;
117128 assert_eq ! ( codes. validity( ) ?. nullability( ) , Nullability :: Nullable ) ;
118129
119130 let norms_validity = norms. validity ( ) ?. execute_mask ( 3 , & mut ctx) ?;
131+ let inv_direction_norms_validity = inv_direction_norms. validity ( ) ?. execute_mask ( 3 , & mut ctx) ?;
120132 let codes_validity = codes. validity ( ) ?. execute_mask ( 3 , & mut ctx) ?;
121133 assert ! ( norms_validity. value( 0 ) ) ;
122134 assert ! ( !norms_validity. value( 1 ) ) ;
123135 assert ! ( norms_validity. value( 2 ) ) ;
136+ assert ! ( inv_direction_norms_validity. value( 0 ) ) ;
137+ assert ! ( !inv_direction_norms_validity. value( 1 ) ) ;
138+ assert ! ( inv_direction_norms_validity. value( 2 ) ) ;
124139 assert ! ( codes_validity. value( 0 ) ) ;
125140 assert ! ( !codes_validity. value( 1 ) ) ;
126141 assert ! ( codes_validity. value( 2 ) ) ;
@@ -134,6 +149,57 @@ fn encode_stores_norms_and_struct_validity() -> VortexResult<()> {
134149 Ok ( ( ) )
135150}
136151
152+ #[ test]
153+ fn encode_stores_zero_inv_direction_norm_for_zero_rows ( ) -> VortexResult < ( ) > {
154+ let session = test_session ( ) ;
155+ let mut ctx = session. create_execution_ctx ( ) ;
156+ let mut values = vec ! [ 0.0f32 ; 3 * 128 ] ;
157+ values[ 0 ] = 3.0 ;
158+ values[ 1 ] = 4.0 ;
159+ values[ 256 ] = 1.0 ;
160+ let input = vector_array ( 128 , & values, Validity :: NonNullable ) ?;
161+
162+ let encoded = execute_tq_encode ( input, & TurboQuantConfig :: default ( ) , & mut ctx) ?;
163+ let storage = turboquant_storage ( encoded, & mut ctx) ?;
164+ let inv_direction_norms: PrimitiveArray = storage
165+ . unmasked_field_by_name ( "inv_direction_norms" ) ?
166+ . clone ( )
167+ . execute ( & mut ctx) ?;
168+
169+ let values = inv_direction_norms. as_slice :: < f32 > ( ) ;
170+ assert ! ( values[ 0 ] . is_finite( ) && values[ 0 ] > 0.0 ) ;
171+ assert_eq ! ( values[ 1 ] , 0.0 ) ;
172+ assert ! ( values[ 2 ] . is_finite( ) && values[ 2 ] > 0.0 ) ;
173+ Ok ( ( ) )
174+ }
175+
176+ #[ test]
177+ fn decode_preserves_original_l2_norms_for_non_power_of_two_dimensions ( ) -> VortexResult < ( ) > {
178+ let session = tensor_test_session ( ) ;
179+ let mut ctx = session. create_execution_ctx ( ) ;
180+ let input = f32_vector_array ( 129 , 3 , 0.25 , Validity :: NonNullable ) ?;
181+ let config = TurboQuantConfig :: try_new ( 3 , 42 , 3 ) ?;
182+
183+ let encoded = execute_tq_encode ( input, & config, & mut ctx) ?;
184+ let expected_norms = parse_storage ( encoded. clone ( ) , & mut ctx) ?. norms ;
185+ let decoded = execute_tq_decode ( encoded, & mut ctx) ?;
186+ let decoded_norms: PrimitiveArray = L2Norm :: try_new_array ( decoded, 3 ) ?
187+ . into_array ( )
188+ . execute ( & mut ctx) ?;
189+
190+ for ( actual, expected) in decoded_norms
191+ . as_slice :: < f32 > ( )
192+ . iter ( )
193+ . zip ( expected_norms. as_slice :: < f32 > ( ) )
194+ {
195+ assert ! (
196+ ( * actual - * expected) . abs( ) <= 1e-4 * expected. max( 1.0 ) ,
197+ "decoded norm {actual} did not match stored norm {expected}"
198+ ) ;
199+ }
200+ Ok ( ( ) )
201+ }
202+
137203#[ test]
138204fn normalize_as_l2_denorm_preserves_child_validity ( ) -> VortexResult < ( ) > {
139205 let session = test_session ( ) ;
0 commit comments