Skip to content

Commit 109ef8a

Browse files
lwwmanning and claude committed
chore[turboquant]: address review — hot loop opts, tests, perf TODOs
Hot loop optimizations in compress.rs:
- Remove unnecessary `residual.fill(0.0)` — [..dim] is overwritten every row, [dim..] stays zero from initialization
- Move `projected.fill(0.0)` into the else branch (only needed when residual_norm == 0, since rotate() overwrites when called)

New tests (88 total, +3):
- all_zero_vectors_roundtrip: exercises the norm==0 branch, verifies zero-in → zero-out
- f64_input_encodes_successfully: exercises the f64→f32 conversion path in extract_f32_elements
- mse_serde_roundtrip: serializes metadata via VTable::serialize, deserializes, rebuilds from children, and verifies identical decode

Performance TODOs documented:
- Double extract_f32_elements materialization in encode_qjl (existing)
- Double RotationMatrix::try_new in encode_qjl (new)
- Centroids Vec→BufferMut copy (new)
- Per-element QJL sign bit extraction in decompress (new)

Signed-off-by: Will Manning <will@spiraldb.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Will Manning <will@willmanning.io>
1 parent ee172d6 commit 109ef8a

3 files changed

Lines changed: 149 additions & 4 deletions

File tree

encodings/turboquant/src/compress.rs

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,8 @@ pub fn turboquant_encode_mse(
137137
let norms_array = PrimitiveArray::new::<f32>(norms_buf.freeze(), Validity::NonNullable);
138138

139139
// Store centroids as a child array.
140+
// TODO(perf): `get_centroids` returns Vec<f32>; could avoid the copy by
141+
// supporting Buffer::from(Vec<T>) or caching as Buffer directly.
140142
let mut centroids_buf = BufferMut::<f32>::with_capacity(centroids.len());
141143
centroids_buf.extend_from_slice(&centroids);
142144
let centroids_array = PrimitiveArray::new::<f32>(centroids_buf.freeze(), Validity::NonNullable);
@@ -194,6 +196,8 @@ pub fn turboquant_encode_qjl(
194196
};
195197
let mse_inner = turboquant_encode_mse(fsl, &mse_config)?;
196198

199+
// TODO(perf): `turboquant_encode_mse` above already constructs the same
200+
// RotationMatrix from the same seed. Refactor to share it.
197201
let rotation = RotationMatrix::try_new(seed, dim)?;
198202
let padded_dim = rotation.padded_dim();
199203

@@ -250,18 +254,20 @@ pub fn turboquant_encode_qjl(
250254
}
251255
}
252256

253-
// Compute residual.
254-
residual.fill(0.0);
257+
// Compute residual: r = x - x̂. Only [..dim] is written; tail stays zero
258+
// from initialization and is never modified.
255259
for j in 0..dim {
256260
residual[j] = x[j] - dequantized[j];
257261
}
258262
let residual_norm = l2_norm(&residual[..dim]);
259263
residual_norms_buf.push(residual_norm);
260264

261-
// QJL: sign(S * r).
262-
projected.fill(0.0);
265+
// QJL: sign(S · r). rotate() writes all of `projected` when called;
266+
// when residual_norm == 0 we must zero it since it has stale data.
263267
if residual_norm > 0.0 {
264268
qjl_rotation.rotate(&residual, &mut projected);
269+
} else {
270+
projected.fill(0.0);
265271
}
266272

267273
let bit_offset = row * padded_dim;

encodings/turboquant/src/decompress.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,10 @@ pub fn execute_decompress_qjl(
146146
let mse_row = &mse_elements[row * dim..(row + 1) * dim];
147147
let residual_norm = residual_norms[row];
148148

149+
// TODO(perf): Per-element bit extraction + branch is hard to autovectorize.
150+
// Unlike MSE rotation signs (which are amortized once for all rows), QJL
151+
// signs change per row so they can't be pre-expanded. Consider reading raw
152+
// bytes and using bitwise ops to generate ±1.0 f32s in bulk.
149153
let bit_offset = row * padded_dim;
150154
for idx in 0..padded_dim {
151155
qjl_signs_vec[idx] = if qjl_bit_buf.value(bit_offset + idx) {

encodings/turboquant/src/lib.rs

Lines changed: 135 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -645,4 +645,139 @@ mod tests {
645645
);
646646
Ok(())
647647
}
648+
649+
// -----------------------------------------------------------------------
650+
// Edge case and input format tests
651+
// -----------------------------------------------------------------------
652+
653+
/// Verify that all-zero vectors roundtrip correctly (norm == 0 branch).
///
/// Builds a non-nullable fixed-size-list of 10 rows x 128 `f32` zeros, runs it
/// through `encode_decode_mse` with a 3-bit config and a fixed seed, and asserts
/// that every element of both the original and the decoded output is exactly 0.0.
654+
#[test]
655+
fn all_zero_vectors_roundtrip() -> VortexResult<()> {
656+
let num_rows = 10;
657+
let dim = 128;
// Every element is 0.0, so each row is an all-zero vector.
658+
let buf = BufferMut::<f32>::full(0.0f32, num_rows * dim);
659+
let elements = PrimitiveArray::new::<f32>(buf.freeze(), Validity::NonNullable);
660+
let fsl = FixedSizeListArray::try_new(
661+
elements.into_array(),
662+
dim as u32,
663+
Validity::NonNullable,
664+
num_rows,
665+
)?;
666+
667+
let config = TurboQuantConfig {
668+
bit_width: 3,
// Fixed seed keeps the test deterministic.
669+
seed: Some(42),
670+
};
671+
let (original, decoded) = encode_decode_mse(&fsl, &config)?;
672+
// All-zero vectors should decode to all-zero (norm=0 → 0 * anything = 0).
// Exact float equality is intentional here: the expected output is exactly zero.
673+
for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() {
674+
assert_eq!(o, 0.0, "original[{i}] not zero");
675+
assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input");
676+
}
677+
Ok(())
678+
}
679+
680+
/// Verify that f64 input is accepted and encoded (converted to f32 internally).
///
/// Fills a 10 x 64 non-nullable fixed-size-list with `f64` samples drawn from a
/// standard normal distribution (seeded RNG), encodes it with
/// `turboquant_encode_mse`, and checks only the shape of the result: the number
/// of norms and the reported dimension. The encoded array is not decoded here,
/// so reconstruction accuracy for f64 input is not covered by this test.
681+
#[test]
682+
fn f64_input_encodes_successfully() -> VortexResult<()> {
683+
let num_rows = 10;
684+
let dim = 64;
// Seeded RNG so the generated input is reproducible across runs.
685+
let mut rng = StdRng::seed_from_u64(99);
686+
let normal = Normal::new(0.0f64, 1.0).unwrap();
687+
688+
let mut buf = BufferMut::<f64>::with_capacity(num_rows * dim);
689+
for _ in 0..(num_rows * dim) {
690+
buf.push(normal.sample(&mut rng));
691+
}
692+
let elements = PrimitiveArray::new::<f64>(buf.freeze(), Validity::NonNullable);
693+
let fsl = FixedSizeListArray::try_new(
694+
elements.into_array(),
695+
dim as u32,
696+
Validity::NonNullable,
697+
num_rows,
698+
)?;
699+
700+
let config = TurboQuantConfig {
701+
bit_width: 3,
702+
seed: Some(42),
703+
};
704+
// Verify encoding succeeds with f64 input (f64→f32 conversion).
705+
let encoded = turboquant_encode_mse(&fsl, &config)?;
// One norm is expected per input row; the dimension must match the input.
706+
assert_eq!(encoded.norms().len(), num_rows);
707+
assert_eq!(encoded.dimension(), dim as u32);
708+
Ok(())
709+
}
710+
711+
/// Verify serde roundtrip: serialize MSE array metadata + children, then rebuild.
///
/// Steps:
/// 1. Encode a 10 x 128 input (from `make_fsl`) with a 3-bit config.
/// 2. Serialize the array's metadata through the `VTable` and deserialize it again.
/// 3. Assert that `dimension`, `bit_width`, `padded_dim` and `rotation_seed`
///    survive the metadata roundtrip.
/// 4. Rebuild a `TurboQuantMSEArray` from the original array's four children plus
///    the deserialized metadata, and assert it decodes to exactly the same `f32`
///    elements as the original encoded array.
///
/// Note that only the metadata goes through serialization; the children are taken
/// directly from the in-memory encoded array rather than being serialized.
712+
#[test]
713+
fn mse_serde_roundtrip() -> VortexResult<()> {
714+
use vortex_array::DynArray;
715+
use vortex_array::SerializeMetadata;
716+
use vortex_array::vtable::VTable;
717+
718+
let fsl = make_fsl(10, 128, 42);
719+
let config = TurboQuantConfig {
720+
bit_width: 3,
721+
seed: Some(123),
722+
};
723+
let encoded = turboquant_encode_mse(&fsl, &config)?;
724+
725+
// Serialize metadata.
726+
let metadata = <crate::mse::TurboQuantMSE as VTable>::metadata(&encoded)?;
727+
let serialized = <crate::mse::TurboQuantMSE as VTable>::serialize(metadata)?
728+
.expect("metadata should serialize");
729+
730+
// Collect children.
731+
let nchildren = <crate::mse::TurboQuantMSE as VTable>::nchildren(&encoded);
// The rebuild below passes exactly four children to `try_new`.
732+
assert_eq!(nchildren, 4);
733+
let children: Vec<ArrayRef> = (0..nchildren)
734+
.map(|i| <crate::mse::TurboQuantMSE as VTable>::child(&encoded, i))
735+
.collect();
736+
737+
// Deserialize and rebuild.
// NOTE(review): the `&[]` argument is presumably the buffer list (none needed
// for this encoding) — confirm against the `VTable::deserialize` signature.
738+
let deserialized = <crate::mse::TurboQuantMSE as VTable>::deserialize(
739+
&serialized,
740+
encoded.dtype(),
741+
encoded.len(),
742+
&[],
743+
&SESSION,
744+
)?;
745+
746+
// Verify metadata fields survived roundtrip.
747+
assert_eq!(deserialized.dimension, encoded.dimension());
748+
assert_eq!(deserialized.bit_width, encoded.bit_width() as u32);
749+
assert_eq!(deserialized.padded_dim, encoded.padded_dim());
750+
assert_eq!(deserialized.rotation_seed, encoded.rotation_seed());
751+
752+
// Decode the original encoded array first; this is the reference output that
// the rebuilt array is compared against below.
753+
let mut ctx = SESSION.create_execution_ctx();
754+
let decoded_original = encoded
755+
.clone()
756+
.into_array()
757+
.execute::<FixedSizeListArray>(&mut ctx)?;
758+
let original_elements = decoded_original.elements().to_canonical()?.into_primitive();
759+
760+
// Rebuild from children (simulating deserialization).
// `bit_width as u8` narrows the deserialized value; it was compared against
// `encoded.bit_width() as u32` above, so the cast back is expected to be lossless.
761+
let rebuilt = crate::mse::array::TurboQuantMSEArray::try_new(
762+
encoded.dtype().clone(),
763+
children[0].clone(),
764+
children[1].clone(),
765+
children[2].clone(),
766+
children[3].clone(),
767+
deserialized.dimension,
768+
deserialized.bit_width as u8,
769+
deserialized.padded_dim,
770+
deserialized.rotation_seed,
771+
)?;
772+
let decoded_rebuilt = rebuilt
773+
.into_array()
774+
.execute::<FixedSizeListArray>(&mut ctx)?;
775+
let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive();
776+
777+
// Exact equality: both decodes use the same children and metadata, so the
// outputs are expected to match bit-for-bit.
assert_eq!(
778+
original_elements.as_slice::<f32>(),
779+
rebuilt_elements.as_slice::<f32>()
780+
);
781+
Ok(())
782+
}
648783
}

0 commit comments

Comments
 (0)