Skip to content

Commit 76e8004

Browse files
committed
min dimension 3
Signed-off-by: Will Manning <will@willmanning.io>
1 parent a29c252 commit 76e8004

4 files changed

Lines changed: 33 additions & 16 deletions

File tree

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
3636
if !(1..=8).contains(&bit_width) {
3737
vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}");
3838
}
39-
if dimension < 2 {
40-
vortex_bail!("TurboQuant dimension must be >= 2, got {dimension}");
39+
if dimension < 3 {
40+
vortex_bail!("TurboQuant dimension must be >= 3, got {dimension}");
4141
}
4242

4343
if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) {
@@ -306,5 +306,6 @@ mod tests {
306306
assert!(get_centroids(128, 0).is_err());
307307
assert!(get_centroids(128, 9).is_err());
308308
assert!(get_centroids(1, 2).is_err());
309+
assert!(get_centroids(2, 2).is_err());
309310
}
310311
}

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ pub fn turboquant_encode_mse(
198198
);
199199
let dimension = fsl.list_size();
200200
vortex_ensure!(
201-
dimension >= 2,
202-
"TurboQuant requires dimension >= 2, got {dimension}"
201+
dimension >= 3,
202+
"TurboQuant requires dimension >= 3, got {dimension}"
203203
);
204204

205205
if fsl.is_empty() {
@@ -233,8 +233,8 @@ pub fn turboquant_encode_qjl(
233233
);
234234
let dimension = fsl.list_size();
235235
vortex_ensure!(
236-
dimension >= 2,
237-
"TurboQuant requires dimension >= 2, got {dimension}"
236+
dimension >= 3,
237+
"TurboQuant requires dimension >= 3, got {dimension}"
238238
);
239239

240240
if fsl.is_empty() {

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -468,31 +468,38 @@ mod tests {
468468
Ok(())
469469
}
470470

471-
#[test]
472-
fn mse_rejects_dimension_below_2() {
473-
let fsl = make_fsl_dim1();
471+
#[rstest]
472+
#[case(1)]
473+
#[case(2)]
474+
fn mse_rejects_dimension_below_3(#[case] dim: usize) {
475+
let fsl = make_fsl_small(dim);
474476
let config = TurboQuantConfig {
475477
bit_width: 2,
476478
seed: Some(0),
477479
};
478480
assert!(turboquant_encode_mse(&fsl, &config).is_err());
479481
}
480482

481-
#[test]
482-
fn qjl_rejects_dimension_below_2() {
483-
let fsl = make_fsl_dim1();
483+
#[rstest]
484+
#[case(1)]
485+
#[case(2)]
486+
fn qjl_rejects_dimension_below_3(#[case] dim: usize) {
487+
let fsl = make_fsl_small(dim);
484488
let config = TurboQuantConfig {
485489
bit_width: 3,
486490
seed: Some(0),
487491
};
488492
assert!(turboquant_encode_qjl(&fsl, &config).is_err());
489493
}
490494

491-
fn make_fsl_dim1() -> FixedSizeListArray {
492-
let mut buf = BufferMut::<f32>::with_capacity(1);
493-
buf.push(1.0);
495+
fn make_fsl_small(dim: usize) -> FixedSizeListArray {
496+
let mut buf = BufferMut::<f32>::with_capacity(dim);
497+
for i in 0..dim {
498+
buf.push(i as f32 + 1.0);
499+
}
494500
let elements = PrimitiveArray::new::<f32>(buf.freeze(), Validity::NonNullable);
495-
FixedSizeListArray::try_new(elements.into_array(), 1, Validity::NonNullable, 1).unwrap()
501+
FixedSizeListArray::try_new(elements.into_array(), dim as u32, Validity::NonNullable, 1)
502+
.unwrap()
496503
}
497504

498505
// -----------------------------------------------------------------------

vortex-tensor/src/encodings/turboquant/scheme.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ fn get_tensor_element_ptype_and_length(dtype: &DType) -> VortexResult<(PType, u3
130130
),
131131
};
132132

133+
// TurboQuant requires dimension >= 3: the marginal coordinate distribution
134+
// (1 - x^2)^((d-3)/2) has a singularity at d=2 (arcsine distribution) that
135+
// causes NaN in the Max-Lloyd centroid computation.
136+
vortex_ensure!(
137+
*fsl_len >= 3,
138+
"TurboQuant requires dimension >= 3, got {}",
139+
fsl_len
140+
);
141+
133142
if let &DType::Primitive(ptype, Nullability::NonNullable) = element_dtype.as_ref() {
134143
Ok((ptype, *fsl_len))
135144
} else {

0 commit comments

Comments
 (0)