Skip to content

Commit f978d7b

Browse files
committed
Revert "share rotation matrix between MSE and QJL"
This reverts commit 0c5e8e73af9afc001e20405c91d11d59a8129796.

Signed-off-by: Will Manning <will@willmanning.io>
1 parent 683467d commit f978d7b

File tree

4 files changed

+89
-56
lines changed

4 files changed

+89
-56
lines changed

encodings/turboquant/src/array.rs

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,17 +37,15 @@ pub struct TurboQuantMetadata {
3737
}
3838

3939
/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased
40-
/// inner product estimation. When present, adds 2 additional children.
41-
///
42-
/// The QJL correction reuses the MSE rotation matrix (stored in `rotation_signs`)
43-
/// rather than maintaining a separate rotation. This halves the rotation sign
44-
/// storage and avoids reconstructing a second `RotationMatrix` at decode time.
40+
/// inner product estimation. When present, adds 3 additional children.
4541
#[derive(Clone, Debug)]
4642
pub struct QjlCorrection {
47-
/// Sign bits: `BitPackedArray` (1-bit), length `num_rows * padded_dim`.
43+
/// Sign bits: `BoolArray`, length `num_rows * padded_dim`.
4844
pub(crate) signs: ArrayRef,
4945
/// Residual norms: `PrimitiveArray<f32>`, length `num_rows`.
5046
pub(crate) residual_norms: ArrayRef,
47+
/// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order).
48+
pub(crate) rotation_signs: ArrayRef,
5149
}
5250

5351
impl QjlCorrection {
@@ -60,6 +58,11 @@ impl QjlCorrection {
6058
pub fn residual_norms(&self) -> &ArrayRef {
6159
&self.residual_norms
6260
}
61+
62+
/// The QJL rotation signs (BoolArray, inverse application order).
63+
pub fn rotation_signs(&self) -> &ArrayRef {
64+
&self.rotation_signs
65+
}
6366
}
6467

6568
/// TurboQuant array.
@@ -68,11 +71,12 @@ impl QjlCorrection {
6871
/// - 0: `codes` — `BitPackedArray` or `PrimitiveArray<u8>` (quantized indices)
6972
/// - 1: `norms` — `PrimitiveArray<f32>` (one per vector row)
7073
/// - 2: `centroids` — `PrimitiveArray<f32>` (codebook, length 2^bit_width)
71-
/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order)
74+
/// - 3: `rotation_signs` — `BoolArray` (3 * padded_dim bits, inverse application order)
7275
///
7376
/// Optional QJL children (when `has_qjl` is true):
74-
/// - 4: `qjl_signs` — `BitPackedArray` (num_rows * padded_dim, 1-bit u8 0/1)
77+
/// - 4: `qjl_signs` — `BoolArray` (num_rows * padded_dim bits)
7578
/// - 5: `qjl_residual_norms` — `PrimitiveArray<f32>` (one per row)
79+
/// - 6: `qjl_rotation_signs` — `BoolArray` (3 * padded_dim bits, QJL rotation, inverse order)
7680
#[derive(Clone, Debug)]
7781
pub struct TurboQuantArray {
7882
pub(crate) dtype: DType,

encodings/turboquant/src/compress.rs

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -237,9 +237,9 @@ pub fn turboquant_encode_qjl(
237237
let core = turboquant_quantize_core(fsl, seed, mse_bit_width)?;
238238
let padded_dim = core.padded_dim;
239239

240-
// QJL reuses the MSE rotation matrix. This saves one stored rotation child
241-
// and one RotationMatrix reconstruction at decode time. Empirically verified
242-
// via the qjl_inner_product_bias test suite to not introduce significant bias.
240+
// QJL uses a different rotation than the MSE stage to ensure statistical
241+
// independence between the quantization noise and the sign projection.
242+
let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?;
243243

244244
let num_rows = fsl.len();
245245
let mut residual_norms_buf = BufferMut::<f32>::with_capacity(num_rows);
@@ -281,9 +281,9 @@ pub fn turboquant_encode_qjl(
281281
let residual_norm = l2_norm(&residual[..dim]);
282282
residual_norms_buf.push(residual_norm);
283283

284-
// QJL: sign(S · r), reusing the MSE rotation S.
284+
// QJL: sign(S · r).
285285
if residual_norm > 0.0 {
286-
core.rotation.rotate(&residual, &mut projected);
286+
qjl_rotation.rotate(&residual, &mut projected);
287287
} else {
288288
projected.fill(0.0);
289289
}
@@ -297,16 +297,17 @@ pub fn turboquant_encode_qjl(
297297
// Build the MSE part.
298298
let mut array = build_turboquant_mse(fsl, core, mse_bit_width)?;
299299

300-
// Attach QJL correction. The QJL reuses the MSE rotation matrix (already
301-
// stored as rotation_signs), so we only need to store signs and residual norms.
300+
// Attach QJL correction.
302301
let residual_norms_array =
303302
PrimitiveArray::new::<f32>(residual_norms_buf.freeze(), Validity::NonNullable);
304303
let qjl_signs_prim = PrimitiveArray::new::<u8>(qjl_sign_u8.freeze(), Validity::NonNullable);
305304
let qjl_signs_packed = bitpack_encode(&qjl_signs_prim, 1, None)?.into_array();
305+
let qjl_rotation_signs = bitpack_rotation_signs(&qjl_rotation)?;
306306

307307
array.qjl = Some(QjlCorrection {
308308
signs: qjl_signs_packed,
309309
residual_norms: residual_norms_array.into_array(),
310+
rotation_signs: qjl_rotation_signs,
310311
});
311312

312313
Ok(array.into_array())

encodings/turboquant/src/decompress.rs

Lines changed: 56 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -28,9 +28,8 @@ fn qjl_correction_scale(padded_dim: usize) -> f32 {
2828
/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats.
2929
///
3030
/// Reads stored centroids and rotation signs from the array's children,
31-
/// avoiding any recomputation. If QJL correction is present, the MSE decode
32-
/// and QJL correction are fused into a single pass over rows to avoid an
33-
/// intermediate buffer allocation and extra memory traffic.
31+
/// avoiding any recomputation. If QJL correction is present, applies
32+
/// the residual correction after MSE decoding.
3433
pub fn execute_decompress(
3534
array: TurboQuantArray,
3635
ctx: &mut ExecutionCtx,
@@ -55,7 +54,8 @@ pub fn execute_decompress(
5554
let centroids = centroids_prim.as_slice::<f32>();
5655

5756
// FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values,
58-
// then we expand to u32 XOR masks once (amortized over all rows).
57+
// then we expand to u32 XOR masks once (amortized over all rows). This enables
58+
// branchless XOR-based sign application in the per-row SRHT hot loop.
5959
let signs_prim = array
6060
.rotation_signs
6161
.clone()
@@ -69,57 +69,73 @@ pub fn execute_decompress(
6969
let norms_prim = array.norms.clone().execute::<PrimitiveArray>(ctx)?;
7070
let norms = norms_prim.as_slice::<f32>();
7171

72-
// Prepare QJL data (if present) before entering the row loop.
73-
// QJL reuses the MSE rotation matrix — no separate rotation to reconstruct.
74-
let qjl_scale = qjl_correction_scale(padded_dim);
75-
let qjl_data = if let Some(qjl) = &array.qjl {
76-
let qjl_signs_prim = qjl.signs.clone().execute::<PrimitiveArray>(ctx)?;
77-
let residual_norms_prim = qjl.residual_norms.clone().execute::<PrimitiveArray>(ctx)?;
78-
Some((qjl_signs_prim, residual_norms_prim))
79-
} else {
80-
None
81-
};
82-
83-
// Single fused loop: MSE decode + optional QJL correction per row.
84-
let mut output = BufferMut::<f32>::with_capacity(num_rows * dim);
72+
// MSE decode: dequantize → inverse rotate → scale by norm.
73+
let mut mse_output = BufferMut::<f32>::with_capacity(num_rows * dim);
8574
let mut dequantized = vec![0.0f32; padded_dim];
8675
let mut unrotated = vec![0.0f32; padded_dim];
87-
// QJL scratch buffers (only used when qjl_data is Some).
88-
let mut qjl_signs_vec = vec![0.0f32; padded_dim];
89-
let mut qjl_projected = vec![0.0f32; padded_dim];
9076

9177
for row in 0..num_rows {
9278
let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim];
9379
let norm = norms[row];
9480

95-
// MSE: dequantize → inverse rotate → scale by norm.
9681
for idx in 0..padded_dim {
9782
dequantized[idx] = centroids[row_indices[idx] as usize];
9883
}
84+
9985
rotation.inverse_rotate(&dequantized, &mut unrotated);
86+
10087
for idx in 0..dim {
10188
unrotated[idx] *= norm;
10289
}
10390

104-
if let Some((ref qjl_signs_prim, ref residual_norms_prim)) = qjl_data {
105-
// QJL: apply residual correction inline, reusing the MSE rotation.
106-
let qjl_signs_u8 = qjl_signs_prim.as_slice::<u8>();
107-
let residual_norms = residual_norms_prim.as_slice::<f32>();
108-
let residual_norm = residual_norms[row];
109-
110-
let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim];
111-
for idx in 0..padded_dim {
112-
qjl_signs_vec[idx] = if row_signs[idx] != 0 { 1.0 } else { -1.0 };
113-
}
114-
115-
rotation.inverse_rotate(&qjl_signs_vec, &mut qjl_projected);
116-
let scale = qjl_scale * residual_norm;
117-
118-
for idx in 0..dim {
119-
output.push(unrotated[idx] + scale * qjl_projected[idx]);
120-
}
121-
} else {
122-
output.extend_from_slice(&unrotated[..dim]);
91+
mse_output.extend_from_slice(&unrotated[..dim]);
92+
}
93+
94+
// If no QJL correction, we're done.
95+
let Some(qjl) = &array.qjl else {
96+
let elements = PrimitiveArray::new::<f32>(mse_output.freeze(), Validity::NonNullable);
97+
return Ok(FixedSizeListArray::try_new(
98+
elements.into_array(),
99+
array.dimension(),
100+
Validity::NonNullable,
101+
num_rows,
102+
)?
103+
.into_array());
104+
};
105+
106+
// Apply QJL residual correction.
107+
// FastLanes SIMD-unpacks the 1-bit bitpacked QJL signs into u8 0/1 values.
108+
let qjl_signs_prim = qjl.signs.clone().execute::<PrimitiveArray>(ctx)?;
109+
let qjl_signs_u8 = qjl_signs_prim.as_slice::<u8>();
110+
111+
let residual_norms_prim = qjl.residual_norms.clone().execute::<PrimitiveArray>(ctx)?;
112+
let residual_norms = residual_norms_prim.as_slice::<f32>();
113+
114+
let qjl_rot_signs_prim = qjl.rotation_signs.clone().execute::<PrimitiveArray>(ctx)?;
115+
let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::<u8>(), dim)?;
116+
117+
let qjl_scale = qjl_correction_scale(padded_dim);
118+
let mse_elements = mse_output.as_ref();
119+
120+
let mut output = BufferMut::<f32>::with_capacity(num_rows * dim);
121+
let mut qjl_signs_vec = vec![0.0f32; padded_dim];
122+
let mut qjl_projected = vec![0.0f32; padded_dim];
123+
124+
for row in 0..num_rows {
125+
let mse_row = &mse_elements[row * dim..(row + 1) * dim];
126+
let residual_norm = residual_norms[row];
127+
128+
// Convert u8 0/1 → f32 ±1.0 for this row's signs.
129+
let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim];
130+
for idx in 0..padded_dim {
131+
qjl_signs_vec[idx] = if row_signs[idx] != 0 { 1.0 } else { -1.0 };
132+
}
133+
134+
qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected);
135+
let scale = qjl_scale * residual_norm;
136+
137+
for idx in 0..dim {
138+
output.push(mse_row[idx] + scale * qjl_projected[idx]);
123139
}
124140
}
125141

encodings/turboquant/src/vtable.rs

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ use crate::array::TurboQuantMetadata;
4242
use crate::decompress::execute_decompress;
4343

4444
const MSE_CHILDREN: usize = 4;
45-
const QJL_CHILDREN: usize = 2;
45+
const QJL_CHILDREN: usize = 3;
4646

4747
impl VTable for TurboQuant {
4848
type Array = TurboQuantArray;
@@ -86,6 +86,7 @@ impl VTable for TurboQuant {
8686
if let Some(qjl) = &array.qjl {
8787
qjl.signs.array_hash(state, precision);
8888
qjl.residual_norms.array_hash(state, precision);
89+
qjl.rotation_signs.array_hash(state, precision);
8990
}
9091
}
9192

@@ -104,6 +105,7 @@ impl VTable for TurboQuant {
104105
(Some(a), Some(b)) => {
105106
a.signs.array_eq(&b.signs, precision)
106107
&& a.residual_norms.array_eq(&b.residual_norms, precision)
108+
&& a.rotation_signs.array_eq(&b.rotation_signs, precision)
107109
}
108110
(None, None) => true,
109111
_ => false,
@@ -148,6 +150,12 @@ impl VTable for TurboQuant {
148150
.vortex_expect("QJL child requested but has_qjl is false")
149151
.residual_norms
150152
.clone(),
153+
6 => array
154+
.qjl
155+
.as_ref()
156+
.vortex_expect("QJL child requested but has_qjl is false")
157+
.rotation_signs
158+
.clone(),
151159
_ => vortex_panic!("TurboQuantArray child index {idx} out of bounds"),
152160
}
153161
}
@@ -160,6 +168,7 @@ impl VTable for TurboQuant {
160168
3 => "rotation_signs".to_string(),
161169
4 => "qjl_signs".to_string(),
162170
5 => "qjl_residual_norms".to_string(),
171+
6 => "qjl_rotation_signs".to_string(),
163172
_ => vortex_panic!("TurboQuantArray child_name index {idx} out of bounds"),
164173
}
165174
}
@@ -213,9 +222,11 @@ impl VTable for TurboQuant {
213222
let qjl = if metadata.has_qjl {
214223
let qjl_signs = children.get(4, &signs_dtype, len * padded_dim)?;
215224
let qjl_residual_norms = children.get(5, &norms_dtype, len)?;
225+
let qjl_rotation_signs = children.get(6, &signs_dtype, 3 * padded_dim)?;
216226
Some(QjlCorrection {
217227
signs: qjl_signs,
218228
residual_norms: qjl_residual_norms,
229+
rotation_signs: qjl_rotation_signs,
219230
})
220231
} else {
221232
None
@@ -253,6 +264,7 @@ impl VTable for TurboQuant {
253264
if let Some(qjl) = &mut array.qjl {
254265
qjl.signs = iter.next().vortex_expect("qjl_signs child");
255266
qjl.residual_norms = iter.next().vortex_expect("qjl_residual_norms child");
267+
qjl.rotation_signs = iter.next().vortex_expect("qjl_rotation_signs child");
256268
}
257269
Ok(())
258270
}

0 commit comments

Comments
 (0)