@@ -456,9 +456,46 @@ pub fn train_hybrid(
456456
457457// === Internal utilities ===
458458
459- /// Squared L2 distance between two slices.
459+ /// Squared L2 distance between two slices via `crate::simd`.
460+ ///
461+ /// For 16D subvectors (CAM-PQ subspace dimension), this is one F32x16
462+ /// load-subtract-multiply-reduce. Consumer never sees hardware details.
460463#[ inline( always) ]
461464fn squared_l2 ( a : & [ f32 ] , b : & [ f32 ] ) -> f32 {
465+ debug_assert_eq ! ( a. len( ) , b. len( ) ) ;
466+ let n = a. len ( ) ;
467+
468+ // Fast path: exactly 16 elements = one F32x16 lane (most common in CAM-PQ).
469+ if n == 16 {
470+ use crate :: simd:: F32x16 ;
471+ let va = F32x16 :: from_slice ( a) ;
472+ let vb = F32x16 :: from_slice ( b) ;
473+ let diff = va - vb;
474+ return ( diff * diff) . reduce_sum ( ) ;
475+ }
476+
477+ // Medium path: process 16 elements at a time, accumulate remainder scalar.
478+ if n >= 16 {
479+ use crate :: simd:: F32x16 ;
480+ let mut acc = F32x16 :: splat ( 0.0 ) ;
481+ let chunks = n / 16 ;
482+ for i in 0 ..chunks {
483+ let off = i * 16 ;
484+ let va = F32x16 :: from_slice ( & a[ off..off + 16 ] ) ;
485+ let vb = F32x16 :: from_slice ( & b[ off..off + 16 ] ) ;
486+ let diff = va - vb;
487+ acc = diff. mul_add ( diff, acc) ;
488+ }
489+ let mut sum = acc. reduce_sum ( ) ;
490+ // Scalar remainder
491+ for i in ( chunks * 16 ) ..n {
492+ let d = a[ i] - b[ i] ;
493+ sum += d * d;
494+ }
495+ return sum;
496+ }
497+
498+ // Scalar fallback for tiny slices.
462499 a. iter ( ) . zip ( b. iter ( ) ) . map ( |( x, y) | ( x - y) * ( x - y) ) . sum ( )
463500}
464501
0 commit comments