@@ -152,11 +152,9 @@ fn encode_decode(
152152// -----------------------------------------------------------------------
153153
154154#[ rstest]
155- #[ case( 32 , 1 ) ]
156- #[ case( 32 , 2 ) ]
157- #[ case( 32 , 3 ) ]
158- #[ case( 32 , 4 ) ]
155+ #[ case( 128 , 1 ) ]
159156#[ case( 128 , 2 ) ]
157+ #[ case( 128 , 3 ) ]
160158#[ case( 128 , 4 ) ]
161159#[ case( 128 , 6 ) ]
162160#[ case( 128 , 8 ) ]
@@ -280,8 +278,9 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> {
280278
281279#[ rstest]
282280#[ case( 1 ) ]
283- #[ case( 2 ) ]
284- fn rejects_dimension_below_3 ( #[ case] dim : usize ) {
281+ #[ case( 64 ) ]
282+ #[ case( 127 ) ]
283+ fn rejects_dimension_below_128 ( #[ case] dim : usize ) {
285284 let fsl = make_fsl_small ( dim) ;
286285 let ext = make_vector_ext ( & fsl) ;
287286 let config = TurboQuantConfig {
@@ -340,7 +339,7 @@ fn all_zero_vectors_roundtrip() -> VortexResult<()> {
340339#[ test]
341340fn f64_input_encodes_successfully ( ) -> VortexResult < ( ) > {
342341 let num_rows = 10 ;
343- let dim = 64 ;
342+ let dim = 128 ;
344343 let mut rng = StdRng :: seed_from_u64 ( 99 ) ;
345344 let normal = Normal :: new ( 0.0f64 , 1.0 ) . unwrap ( ) ;
346345
@@ -371,6 +370,48 @@ fn f64_input_encodes_successfully() -> VortexResult<()> {
371370 Ok ( ( ) )
372371}
373372
373+ /// Verify that f16 input is accepted and encoded (upcast to f32 internally).
374+ #[ test]
375+ fn f16_input_encodes_successfully ( ) -> VortexResult < ( ) > {
376+ let num_rows = 10 ;
377+ let dim = 128 ;
378+ let mut rng = StdRng :: seed_from_u64 ( 99 ) ;
379+ let normal = Normal :: new ( 0.0f32 , 1.0 ) . unwrap ( ) ;
380+
381+ let mut buf = BufferMut :: < half:: f16 > :: with_capacity ( num_rows * dim) ;
382+ for _ in 0 ..( num_rows * dim) {
383+ buf. push ( half:: f16:: from_f32 ( normal. sample ( & mut rng) ) ) ;
384+ }
385+ let elements = PrimitiveArray :: new :: < half:: f16 > ( buf. freeze ( ) , Validity :: NonNullable ) ;
386+ let fsl = FixedSizeListArray :: try_new (
387+ elements. into_array ( ) ,
388+ dim. try_into ( )
389+ . expect ( "somehow got dimension greater than u32::MAX" ) ,
390+ Validity :: NonNullable ,
391+ num_rows,
392+ ) ?;
393+
394+ let ext = make_vector_ext ( & fsl) ;
395+ let config = TurboQuantConfig {
396+ bit_width : 3 ,
397+ seed : Some ( 42 ) ,
398+ } ;
399+ let mut ctx = SESSION . create_execution_ctx ( ) ;
400+ let encoded = turboquant_encode ( & ext, & config, & mut ctx) ?;
401+ let tq = encoded. as_opt :: < TurboQuant > ( ) . unwrap ( ) ;
402+ assert_eq ! ( tq. norms( ) . len( ) , num_rows) ;
403+ assert_eq ! ( tq. dimension( ) as usize , dim) ;
404+
405+ // Verify roundtrip: decode and check reconstruction is reasonable.
406+ let decoded_ext = encoded. execute :: < ExtensionArray > ( & mut ctx) ?;
407+ let decoded_fsl = decoded_ext
408+ . storage_array ( )
409+ . to_canonical ( ) ?
410+ . into_fixed_size_list ( ) ;
411+ assert_eq ! ( decoded_fsl. len( ) , num_rows) ;
412+ Ok ( ( ) )
413+ }
414+
374415// -----------------------------------------------------------------------
375416// Verification tests for stored metadata
376417// -----------------------------------------------------------------------
@@ -494,7 +535,7 @@ fn slice_preserves_data() -> VortexResult<()> {
494535
495536#[ test]
496537fn scalar_at_matches_decompress ( ) -> VortexResult < ( ) > {
497- let fsl = make_fsl ( 10 , 64 , 42 ) ;
538+ let fsl = make_fsl ( 10 , 128 , 42 ) ;
498539 let ext = make_vector_ext ( & fsl) ;
499540 let config = TurboQuantConfig {
500541 bit_width : 3 ,
@@ -593,7 +634,9 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
593634 . sum :: < f32 > ( )
594635 } ;
595636
596- // 4-bit quantization: expect reasonable accuracy.
637+ // At 4-bit, the theoretical MSE bound per coordinate is ~0.0106 (Theorem 1). For cosine
638+ // similarity (bounded [-1, 1]), the error is bounded roughly by 2*sqrt(MSE) ~ 0.2. We use
639+ // 0.15 as a tighter empirical bound.
597640 let error = ( exact_cos - approx_cos) . abs ( ) ;
598641 assert ! (
599642 error < 0.15 ,
@@ -604,6 +647,105 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> {
604647 Ok ( ( ) )
605648}
606649
650+ /// Verify approximate dot product in the quantized domain.
651+ ///
652+ /// NOTE: The MSE quantizer (TurboQuant_mse) has inherent **multiplicative bias** for inner
653+ /// products — the quantized dot product systematically over- or under-estimates the true value.
654+ /// This is a fundamental property: the paper's `TurboQuant_prod` variant adds QJL specifically
655+ /// to debias inner products, but we only implement the MSE-only variant.
656+ ///
657+ /// Even at 8-bit (near-lossless reconstruction, MSE ~4e-5), the quantized-domain dot product
658+ /// can have ~10-15% relative error due to this bias. This tolerance is therefore intentionally
659+ /// loose — we're testing that the approximation is in the right ballpark, not that it's precise.
660+ ///
661+ /// TODO(connor): Revisit these tolerances when we have TurboQuant_prod (QJL debiasing).
662+ #[ test]
663+ fn dot_product_quantized_accuracy ( ) -> VortexResult < ( ) > {
664+ let fsl = make_fsl ( 20 , 128 , 42 ) ;
665+ let ext = make_vector_ext ( & fsl) ;
666+ let config = TurboQuantConfig {
667+ bit_width : 8 ,
668+ seed : Some ( 123 ) ,
669+ } ;
670+ let mut ctx = SESSION . create_execution_ctx ( ) ;
671+ let encoded = turboquant_encode ( & ext, & config, & mut ctx) ?;
672+ let tq = encoded. as_opt :: < TurboQuant > ( ) . unwrap ( ) ;
673+
674+ let input_prim = fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
675+ let input_f32 = input_prim. as_slice :: < f32 > ( ) ;
676+
677+ let mut ctx = SESSION . create_execution_ctx ( ) ;
678+ let pd = tq. padded_dim ( ) as usize ;
679+ let norms_prim = tq. norms ( ) . clone ( ) . execute :: < PrimitiveArray > ( & mut ctx) ?;
680+ let norms = norms_prim. as_slice :: < f32 > ( ) ;
681+ let codes_fsl = tq. codes ( ) . clone ( ) . execute :: < FixedSizeListArray > ( & mut ctx) ?;
682+ let codes_prim = codes_fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
683+ let all_codes = codes_prim. as_slice :: < u8 > ( ) ;
684+ let centroids_prim = tq. centroids ( ) . clone ( ) . execute :: < PrimitiveArray > ( & mut ctx) ?;
685+ let centroid_vals = centroids_prim. as_slice :: < f32 > ( ) ;
686+
687+ for ( row_a, row_b) in [ ( 0 , 1 ) , ( 5 , 10 ) , ( 0 , 19 ) ] {
688+ let vec_a = & input_f32[ row_a * 128 ..( row_a + 1 ) * 128 ] ;
689+ let vec_b = & input_f32[ row_b * 128 ..( row_b + 1 ) * 128 ] ;
690+
691+ let exact_dot: f32 = vec_a. iter ( ) . zip ( vec_b. iter ( ) ) . map ( |( & x, & y) | x * y) . sum ( ) ;
692+
693+ let codes_a = & all_codes[ row_a * pd..( row_a + 1 ) * pd] ;
694+ let codes_b = & all_codes[ row_b * pd..( row_b + 1 ) * pd] ;
695+ let unit_dot: f32 = codes_a
696+ . iter ( )
697+ . zip ( codes_b. iter ( ) )
698+ . map ( |( & ca, & cb) | centroid_vals[ ca as usize ] * centroid_vals[ cb as usize ] )
699+ . sum ( ) ;
700+ let approx_dot = norms[ row_a] * norms[ row_b] * unit_dot;
701+
702+ // See doc comment above: 15% relative error is expected due to MSE quantizer bias.
703+ let scale = exact_dot. abs ( ) . max ( 1.0 ) ;
704+ let rel_error = ( exact_dot - approx_dot) . abs ( ) / scale;
705+ assert ! (
706+ rel_error < 0.15 ,
707+ "dot product error too large for ({row_a}, {row_b}): \
708+ exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}"
709+ ) ;
710+ }
711+ Ok ( ( ) )
712+ }
713+
714+ /// Roundtrip at large embedding dimensions to validate padding and SRHT at common sizes.
715+ ///
716+ /// NOTE: The theoretical MSE bound (Theorem 1) is proved for Haar-distributed random orthogonal
717+ /// matrices, not SRHT. The SRHT is a practical O(d log d) approximation that doesn't exactly
718+ /// satisfy the Haar assumption, so empirical MSE can slightly exceed the theoretical bound. We
719+ /// use a 2x multiplier to account for this gap.
720+ ///
721+ /// The 1024-d case uses 5-bit instead of 4-bit because at 4-bit the SRHT approximation error
722+ /// at d=1024 pushes MSE ~20% above the 1x theoretical bound (0.0127 vs bound 0.0106).
723+ ///
724+ /// TODO(connor): Revisit after Stage 2 block decomposition — at d=768 with block_size=256,
725+ /// the per-block SRHT will be lower-dimensional and may have different error characteristics.
726+ #[ rstest]
727+ #[ case( 768 , 4 ) ]
728+ #[ case( 1024 , 5 ) ]
729+ fn large_dimension_roundtrip ( #[ case] dim : usize , #[ case] bit_width : u8 ) -> VortexResult < ( ) > {
730+ let num_rows = 10 ;
731+ let fsl = make_fsl ( num_rows, dim, 42 ) ;
732+ let config = TurboQuantConfig {
733+ bit_width,
734+ seed : Some ( 123 ) ,
735+ } ;
736+ let ( original, decoded) = encode_decode ( & fsl, & config) ?;
737+ assert_eq ! ( decoded. len( ) , original. len( ) ) ;
738+
739+ let normalized_mse = per_vector_normalized_mse ( & original, & decoded, dim, num_rows) ;
740+ // 2x slack for the SRHT-vs-Haar gap (see doc comment above).
741+ let bound = 2.0 * theoretical_mse_bound ( bit_width) ;
742+ assert ! (
743+ normalized_mse < bound,
744+ "Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}" ,
745+ ) ;
746+ Ok ( ( ) )
747+ }
748+
607749/// Verify that the encoded array's dtype is a Vector extension type.
608750#[ test]
609751fn encoded_dtype_is_vector_extension ( ) -> VortexResult < ( ) > {
@@ -702,7 +844,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> {
702844#[ test]
703845fn nullable_norms_match_validity ( ) -> VortexResult < ( ) > {
704846 let validity = Validity :: from_iter ( [ true , false , true , false , true ] ) ;
705- let fsl = make_fsl_with_validity ( 5 , 64 , 42 , validity) ;
847+ let fsl = make_fsl_with_validity ( 5 , 128 , 42 , validity) ;
706848 let ext = make_vector_ext ( & fsl) ;
707849
708850 let config = TurboQuantConfig {
@@ -729,7 +871,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> {
729871#[ test]
730872fn nullable_l2_norm_readthrough ( ) -> VortexResult < ( ) > {
731873 let validity = Validity :: from_iter ( [ true , false , true , false , true ] ) ;
732- let fsl = make_fsl_with_validity ( 5 , 64 , 42 , validity) ;
874+ let fsl = make_fsl_with_validity ( 5 , 128 , 42 , validity) ;
733875 let ext = make_vector_ext ( & fsl) ;
734876
735877 let config = TurboQuantConfig {
@@ -749,7 +891,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> {
749891 for row in 0 ..5 {
750892 if row % 2 == 0 {
751893 assert ! ( norms. is_valid( row) ?, "row {row} should be valid" ) ;
752- let expected: f32 = orig_f32[ row * 64 ..( row + 1 ) * 64 ]
894+ let expected: f32 = orig_f32[ row * 128 ..( row + 1 ) * 128 ]
753895 . iter ( )
754896 . map ( |& v| v * v)
755897 . sum :: < f32 > ( )
@@ -773,7 +915,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
773915 let validity = Validity :: from_iter ( [
774916 true , true , false , true , true , false , true , false , true , true ,
775917 ] ) ;
776- let fsl = make_fsl_with_validity ( 10 , 64 , 42 , validity) ;
918+ let fsl = make_fsl_with_validity ( 10 , 128 , 42 , validity) ;
777919 let ext = make_vector_ext ( & fsl) ;
778920
779921 let config = TurboQuantConfig {
0 commit comments