Skip to content

Commit a872ba9

Browse files
committed
fix review comments
Signed-off-by: Will Manning <will@willmanning.io>
1 parent 5582f1e commit a872ba9

5 files changed

Lines changed: 84 additions & 12 deletions

File tree

encodings/turboquant/src/centroids.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,16 @@ fn conditional_mean(lo: f64, hi: f64, exponent: f64) -> f64 {
103103
return (lo + hi) / 2.0;
104104
}
105105

106-
let num_points = INTEGRATION_POINTS;
107-
let dx = (hi - lo) / num_points as f64;
106+
let dx = (hi - lo) / INTEGRATION_POINTS as f64;
108107

109108
let mut numerator = 0.0;
110109
let mut denominator = 0.0;
111110

112-
for step in 0..=num_points {
111+
for step in 0..=INTEGRATION_POINTS {
113112
let x_val = lo + (step as f64) * dx;
114113
let weight = pdf_unnormalized(x_val, exponent);
115114

116-
let trap_weight = if step == 0 || step == num_points {
115+
let trap_weight = if step == 0 || step == INTEGRATION_POINTS {
117116
0.5
118117
} else {
119118
1.0

encodings/turboquant/src/compress.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ fn turboquant_quantize_core(
138138

139139
/// Build a `TurboQuantArray` (MSE-only) from quantization results.
140140
fn build_turboquant_mse(
141-
dtype: &FixedSizeListArray,
141+
fsl: &FixedSizeListArray,
142142
core: MseQuantizationResult,
143143
bit_width: u8,
144144
) -> VortexResult<TurboQuantArray> {
145-
let dimension = dtype.list_size();
145+
let dimension = fsl.list_size();
146146

147147
let codes =
148148
PrimitiveArray::new::<u8>(core.all_indices.freeze(), Validity::NonNullable).into_array();
@@ -159,7 +159,7 @@ fn build_turboquant_mse(
159159
let rotation_signs = bitpack_rotation_signs(&core.rotation)?;
160160

161161
TurboQuantArray::try_new_mse(
162-
dtype.dtype().clone(),
162+
fsl.dtype().clone(),
163163
codes,
164164
norms_array,
165165
centroids_array,

encodings/turboquant/src/lib.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,4 +772,77 @@ mod tests {
772772
);
773773
Ok(())
774774
}
775+
776+
/// Verify serde roundtrip for QJL: serialize metadata + children, then rebuild.
777+
#[test]
778+
fn qjl_serde_roundtrip() -> VortexResult<()> {
779+
use vortex_array::DynArray;
780+
use vortex_array::vtable::VTable;
781+
782+
let fsl = make_fsl(10, 128, 42);
783+
let config = TurboQuantConfig {
784+
bit_width: 4,
785+
seed: Some(456),
786+
};
787+
let encoded = turboquant_encode_qjl(&fsl, &config)?;
788+
let encoded = TurboQuant::try_match(&*encoded).unwrap();
789+
790+
// Serialize metadata.
791+
let metadata = <TurboQuant as VTable>::metadata(encoded)?;
792+
let serialized =
793+
<TurboQuant as VTable>::serialize(metadata)?.expect("metadata should serialize");
794+
795+
// Collect children — QJL has 7 (4 MSE + 3 QJL).
796+
let nchildren = <TurboQuant as VTable>::nchildren(encoded);
797+
assert_eq!(nchildren, 7);
798+
let children: Vec<ArrayRef> = (0..nchildren)
799+
.map(|i| <TurboQuant as VTable>::child(encoded, i))
800+
.collect();
801+
802+
// Deserialize metadata.
803+
let deserialized = <TurboQuant as VTable>::deserialize(
804+
&serialized,
805+
encoded.dtype(),
806+
encoded.len(),
807+
&[],
808+
&SESSION,
809+
)?;
810+
811+
assert!(deserialized.has_qjl);
812+
assert_eq!(deserialized.dimension, encoded.dimension());
813+
814+
// Verify decode: original vs rebuilt from children.
815+
let mut ctx = SESSION.create_execution_ctx();
816+
let decoded_original = encoded
817+
.clone()
818+
.into_array()
819+
.execute::<FixedSizeListArray>(&mut ctx)?;
820+
let original_elements = decoded_original.elements().to_canonical()?.into_primitive();
821+
822+
// Rebuild with QJL children.
823+
let rebuilt = crate::array::TurboQuantArray::try_new_qjl(
824+
encoded.dtype().clone(),
825+
children[0].clone(),
826+
children[1].clone(),
827+
children[2].clone(),
828+
children[3].clone(),
829+
crate::array::QjlCorrection {
830+
signs: children[4].clone(),
831+
residual_norms: children[5].clone(),
832+
rotation_signs: children[6].clone(),
833+
},
834+
deserialized.dimension,
835+
deserialized.bit_width as u8,
836+
)?;
837+
let decoded_rebuilt = rebuilt
838+
.into_array()
839+
.execute::<FixedSizeListArray>(&mut ctx)?;
840+
let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive();
841+
842+
assert_eq!(
843+
original_elements.as_slice::<f32>(),
844+
rebuilt_elements.as_slice::<f32>()
845+
);
846+
Ok(())
847+
}
775848
}

encodings/turboquant/src/rotation.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl RotationMatrix {
129129
let mut out = Vec::with_capacity(total);
130130

131131
// Store in inverse order: sign_masks[2] (D₃), sign_masks[1] (D₂), sign_masks[0] (D₁)
132-
for &sign_idx in &[2, 1, 0] {
132+
for sign_idx in [2, 1, 0] {
133133
for &mask in &self.sign_masks[sign_idx] {
134134
out.push(if mask == 0 { 1u8 } else { 0u8 });
135135
}
@@ -157,9 +157,9 @@ impl RotationMatrix {
157157
// Reconstruct in storage order (inverse): [D₃, D₂, D₁] → sign_masks[2], [1], [0]
158158
let mut sign_masks: [Vec<u32>; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim));
159159

160-
for (round, sign_idx) in [2, 1, 0].iter().enumerate() {
160+
for (round, sign_idx) in [2, 1, 0].into_iter().enumerate() {
161161
let offset = round * padded_dim;
162-
sign_masks[*sign_idx] = signs_u8[offset..offset + padded_dim]
162+
sign_masks[sign_idx] = signs_u8[offset..offset + padded_dim]
163163
.iter()
164164
.map(|&v| if v != 0 { 0u32 } else { F32_SIGN_BIT })
165165
.collect();

vortex/benches/single_encoding_throughput.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ macro_rules! turboquant_bench {
509509

510510
turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4);
511511
turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4);
512-
turboquant_bench!(compress, 768, 4, bench_tq_compress_768_2);
513-
turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_2);
512+
turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4);
513+
turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4);
514514
turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2);
515515
turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2);
516516
turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4);

0 commit comments

Comments
 (0)