@@ -28,9 +28,8 @@ fn qjl_correction_scale(padded_dim: usize) -> f32 {
2828/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats.
2929///
3030/// Reads stored centroids and rotation signs from the array's children,
31- /// avoiding any recomputation. If QJL correction is present, the MSE decode
32- /// and QJL correction are fused into a single pass over rows to avoid an
33- /// intermediate buffer allocation and extra memory traffic.
31+ /// avoiding any recomputation. If QJL correction is present, applies
32+ /// the residual correction after MSE decoding.
3433pub fn execute_decompress (
3534 array : TurboQuantArray ,
3635 ctx : & mut ExecutionCtx ,
@@ -55,7 +54,8 @@ pub fn execute_decompress(
5554 let centroids = centroids_prim. as_slice :: < f32 > ( ) ;
5655
5756 // FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values,
58- // then we expand to u32 XOR masks once (amortized over all rows).
57+ // then we expand to u32 XOR masks once (amortized over all rows). This enables
58+ // branchless XOR-based sign application in the per-row SRHT hot loop.
5959 let signs_prim = array
6060 . rotation_signs
6161 . clone ( )
@@ -69,57 +69,73 @@ pub fn execute_decompress(
6969 let norms_prim = array. norms . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
7070 let norms = norms_prim. as_slice :: < f32 > ( ) ;
7171
72- // Prepare QJL data (if present) before entering the row loop.
73- // QJL reuses the MSE rotation matrix — no separate rotation to reconstruct.
74- let qjl_scale = qjl_correction_scale ( padded_dim) ;
75- let qjl_data = if let Some ( qjl) = & array. qjl {
76- let qjl_signs_prim = qjl. signs . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
77- let residual_norms_prim = qjl. residual_norms . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
78- Some ( ( qjl_signs_prim, residual_norms_prim) )
79- } else {
80- None
81- } ;
82-
83- // Single fused loop: MSE decode + optional QJL correction per row.
84- let mut output = BufferMut :: < f32 > :: with_capacity ( num_rows * dim) ;
72+ // MSE decode: dequantize → inverse rotate → scale by norm.
73+ let mut mse_output = BufferMut :: < f32 > :: with_capacity ( num_rows * dim) ;
8574 let mut dequantized = vec ! [ 0.0f32 ; padded_dim] ;
8675 let mut unrotated = vec ! [ 0.0f32 ; padded_dim] ;
87- // QJL scratch buffers (only used when qjl_data is Some).
88- let mut qjl_signs_vec = vec ! [ 0.0f32 ; padded_dim] ;
89- let mut qjl_projected = vec ! [ 0.0f32 ; padded_dim] ;
9076
9177 for row in 0 ..num_rows {
9278 let row_indices = & indices[ row * padded_dim..( row + 1 ) * padded_dim] ;
9379 let norm = norms[ row] ;
9480
95- // MSE: dequantize → inverse rotate → scale by norm.
9681 for idx in 0 ..padded_dim {
9782 dequantized[ idx] = centroids[ row_indices[ idx] as usize ] ;
9883 }
84+
9985 rotation. inverse_rotate ( & dequantized, & mut unrotated) ;
86+
10087 for idx in 0 ..dim {
10188 unrotated[ idx] *= norm;
10289 }
10390
104- if let Some ( ( ref qjl_signs_prim, ref residual_norms_prim) ) = qjl_data {
105- // QJL: apply residual correction inline, reusing the MSE rotation.
106- let qjl_signs_u8 = qjl_signs_prim. as_slice :: < u8 > ( ) ;
107- let residual_norms = residual_norms_prim. as_slice :: < f32 > ( ) ;
108- let residual_norm = residual_norms[ row] ;
109-
110- let row_signs = & qjl_signs_u8[ row * padded_dim..( row + 1 ) * padded_dim] ;
111- for idx in 0 ..padded_dim {
112- qjl_signs_vec[ idx] = if row_signs[ idx] != 0 { 1.0 } else { -1.0 } ;
113- }
114-
115- rotation. inverse_rotate ( & qjl_signs_vec, & mut qjl_projected) ;
116- let scale = qjl_scale * residual_norm;
117-
118- for idx in 0 ..dim {
119- output. push ( unrotated[ idx] + scale * qjl_projected[ idx] ) ;
120- }
121- } else {
122- output. extend_from_slice ( & unrotated[ ..dim] ) ;
91+ mse_output. extend_from_slice ( & unrotated[ ..dim] ) ;
92+ }
93+
94+ // If no QJL correction, we're done.
95+ let Some ( qjl) = & array. qjl else {
96+ let elements = PrimitiveArray :: new :: < f32 > ( mse_output. freeze ( ) , Validity :: NonNullable ) ;
97+ return Ok ( FixedSizeListArray :: try_new (
98+ elements. into_array ( ) ,
99+ array. dimension ( ) ,
100+ Validity :: NonNullable ,
101+ num_rows,
102+ ) ?
103+ . into_array ( ) ) ;
104+ } ;
105+
106+ // Apply QJL residual correction.
107+ // FastLanes SIMD-unpacks the 1-bit bitpacked QJL signs into u8 0/1 values.
108+ let qjl_signs_prim = qjl. signs . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
109+ let qjl_signs_u8 = qjl_signs_prim. as_slice :: < u8 > ( ) ;
110+
111+ let residual_norms_prim = qjl. residual_norms . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
112+ let residual_norms = residual_norms_prim. as_slice :: < f32 > ( ) ;
113+
114+ let qjl_rot_signs_prim = qjl. rotation_signs . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
115+ let qjl_rot = RotationMatrix :: from_u8_slice ( qjl_rot_signs_prim. as_slice :: < u8 > ( ) , dim) ?;
116+
117+ let qjl_scale = qjl_correction_scale ( padded_dim) ;
118+ let mse_elements = mse_output. as_ref ( ) ;
119+
120+ let mut output = BufferMut :: < f32 > :: with_capacity ( num_rows * dim) ;
121+ let mut qjl_signs_vec = vec ! [ 0.0f32 ; padded_dim] ;
122+ let mut qjl_projected = vec ! [ 0.0f32 ; padded_dim] ;
123+
124+ for row in 0 ..num_rows {
125+ let mse_row = & mse_elements[ row * dim..( row + 1 ) * dim] ;
126+ let residual_norm = residual_norms[ row] ;
127+
128+ // Convert u8 0/1 → f32 ±1.0 for this row's signs.
129+ let row_signs = & qjl_signs_u8[ row * padded_dim..( row + 1 ) * padded_dim] ;
130+ for idx in 0 ..padded_dim {
131+ qjl_signs_vec[ idx] = if row_signs[ idx] != 0 { 1.0 } else { -1.0 } ;
132+ }
133+
134+ qjl_rot. inverse_rotate ( & qjl_signs_vec, & mut qjl_projected) ;
135+ let scale = qjl_scale * residual_norm;
136+
137+ for idx in 0 ..dim {
138+ output. push ( mse_row[ idx] + scale * qjl_projected[ idx] ) ;
123139 }
124140 }
125141
0 commit comments