Skip to content

Commit b6ee27b

Browse files
authored
Clean up vortex-tensor even more (#7610)
## Summary Tracking issue: #7297 This is mostly just cosmetic changes that will help with some future changes incoming. - Extract validate_vector_storage_dtype from Vector::validate_dtype into types/vector/mod.rs. - Move unit_norm_tolerance and BinaryTensorOpMetadata into utils.rs; switch unit_norm_tolerance to the c*sqrt(d)*epsilon bound with a dimensions parameter. - Drop op_name from validate_binary_tensor_float_inputs. - TurboQuantConfig::seed Option<u64> -> u64. - SorfOptions::dimension -> SorfOptions::dimensions and matching SorfMatrix::try_new parameter. - centroids::get_centroids -> compute_or_get_centroids. - Crate-root #![cfg_attr(test, allow(clippy::unwrap_used, clippy::expect_used, clippy::unwrap_in_result))]. - Extract validate_vector_storage_dtype from Vector::validate_dtype into types/vector/mod.rs. - Move unit_norm_tolerance and BinaryTensorOpMetadata into utils.rs; switch unit_norm_tolerance to the c*sqrt(d)*epsilon bound with a dimensions parameter. - Drop op_name from validate_binary_tensor_float_inputs. - TurboQuantConfig::seed Option<u64> -> u64. - SorfOptions::dimension -> SorfOptions::dimensions and matching SorfMatrix::try_new parameter. - centroids::get_centroids -> compute_or_get_centroids. - Crate-root #![cfg_attr(test, allow(clippy::unwrap_used, clippy::expect_used, clippy::unwrap_in_result))]. ## API Changes Some renames ## Testing N/A Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 56b0731 commit b6ee27b

21 files changed

Lines changed: 289 additions & 245 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8
2828

2929
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::num_rounds: u8
3030

31-
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: core::option::Option<u64>
31+
pub vortex_tensor::encodings::turboquant::TurboQuantConfig::seed: u64
3232

3333
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantConfig
3434

@@ -440,11 +440,11 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::padded_dim(&self)
440440

441441
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::rotate(&self, input: &[f32], output: &mut [f32])
442442

443-
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::try_new(seed: u64, dimension: usize, num_rounds: usize) -> vortex_error::VortexResult<Self>
443+
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::try_new(seed: u64, dimensions: usize, num_rounds: usize) -> vortex_error::VortexResult<Self>
444444

445445
pub struct vortex_tensor::scalar_fns::sorf_transform::SorfOptions
446446

447-
pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::dimension: u32
447+
pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::dimensions: u32
448448

449449
pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::element_ptype: vortex_array::dtype::ptype::PType
450450

@@ -490,7 +490,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) ->
490490

491491
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform
492492

493-
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
493+
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
494494

495495
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
496496

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new
3636
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3737
/// quantization levels for the coordinate distribution after random rotation in
3838
/// `dimension`-dimensional space.
39-
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
39+
pub fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
4040
vortex_ensure!(
4141
(1..=MAX_BIT_WIDTH).contains(&bit_width),
4242
"TurboQuant bit_width must be 1-{}, got {bit_width}",
@@ -239,7 +239,7 @@ mod tests {
239239
#[case] bits: u8,
240240
#[case] expected: usize,
241241
) -> VortexResult<()> {
242-
let centroids = get_centroids(dim, bits)?;
242+
let centroids = compute_or_get_centroids(dim, bits)?;
243243
assert_eq!(centroids.len(), expected);
244244
Ok(())
245245
}
@@ -251,7 +251,7 @@ mod tests {
251251
#[case(128, 4)]
252252
#[case(768, 2)]
253253
fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
254-
let centroids = get_centroids(dim, bits)?;
254+
let centroids = compute_or_get_centroids(dim, bits)?;
255255
for window in centroids.windows(2) {
256256
assert!(
257257
window[0] < window[1],
@@ -268,7 +268,7 @@ mod tests {
268268
#[case(256, 2)]
269269
#[case(768, 2)]
270270
fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
271-
let centroids = get_centroids(dim, bits)?;
271+
let centroids = compute_or_get_centroids(dim, bits)?;
272272
let count = centroids.len();
273273
for idx in 0..count / 2 {
274274
let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
@@ -287,7 +287,7 @@ mod tests {
287287
#[case(128, 1)]
288288
#[case(128, 4)]
289289
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
290-
let centroids = get_centroids(dim, bits)?;
290+
let centroids = compute_or_get_centroids(dim, bits)?;
291291
for &val in centroids.iter() {
292292
assert!(
293293
(-1.0..=1.0).contains(&val),
@@ -299,15 +299,15 @@ mod tests {
299299

300300
#[test]
301301
fn centroids_cached() -> VortexResult<()> {
302-
let c1 = get_centroids(128, 2)?;
303-
let c2 = get_centroids(128, 2)?;
302+
let c1 = compute_or_get_centroids(128, 2)?;
303+
let c2 = compute_or_get_centroids(128, 2)?;
304304
assert_eq!(c1, c2);
305305
Ok(())
306306
}
307307

308308
#[test]
309309
fn find_nearest_basic() -> VortexResult<()> {
310-
let centroids = get_centroids(128, 2)?;
310+
let centroids = compute_or_get_centroids(128, 2)?;
311311
let boundaries = compute_centroid_boundaries(&centroids);
312312
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
313313

@@ -324,9 +324,9 @@ mod tests {
324324

325325
#[test]
326326
fn rejects_invalid_params() {
327-
assert!(get_centroids(128, 0).is_err());
328-
assert!(get_centroids(128, 9).is_err());
329-
assert!(get_centroids(1, 2).is_err());
330-
assert!(get_centroids(127, 2).is_err());
327+
assert!(compute_or_get_centroids(128, 0).is_err());
328+
assert!(compute_or_get_centroids(128, 9).is_err());
329+
assert!(compute_or_get_centroids(1, 2).is_err());
330+
assert!(compute_or_get_centroids(127, 2).is_err());
331331
}
332332
}

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use vortex_error::vortex_ensure;
3232
use crate::encodings::turboquant::MAX_BIT_WIDTH;
3333
use crate::encodings::turboquant::MIN_DIMENSION;
3434
use crate::encodings::turboquant::centroids::compute_centroid_boundaries;
35+
use crate::encodings::turboquant::centroids::compute_or_get_centroids;
3536
use crate::encodings::turboquant::centroids::find_nearest_centroid;
36-
use crate::encodings::turboquant::centroids::get_centroids;
3737
use crate::scalar_fns::l2_denorm::L2Denorm;
3838
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
3939
use crate::scalar_fns::sorf_transform::SorfMatrix;
@@ -48,8 +48,8 @@ use crate::utils::cast_to_f32;
4848
pub struct TurboQuantConfig {
4949
/// Bits per coordinate (1-8).
5050
pub bit_width: u8,
51-
/// Optional seed for the rotation matrix. If None, the default seed is used.
52-
pub seed: Option<u64>,
51+
/// Seed for the rotation matrix.
52+
pub seed: u64,
5353
/// Number of sign-diagonal + WHT rounds in the structured rotation (default 3).
5454
pub num_rounds: u8,
5555
}
@@ -58,7 +58,7 @@ impl Default for TurboQuantConfig {
5858
fn default() -> Self {
5959
Self {
6060
bit_width: MAX_BIT_WIDTH,
61-
seed: Some(42),
61+
seed: 42,
6262
num_rounds: 3,
6363
}
6464
}
@@ -141,7 +141,7 @@ pub unsafe fn turboquant_encode_unchecked(
141141
let vector_metadata = ext_dtype.as_extension().metadata::<AnyVector>();
142142
let element_ptype = vector_metadata.element_ptype();
143143

144-
let seed = config.seed.unwrap_or(42);
144+
let seed = config.seed;
145145
let num_rows = fsl.len();
146146

147147
if fsl.is_empty() {
@@ -161,7 +161,7 @@ pub unsafe fn turboquant_encode_unchecked(
161161
let sorf_options = SorfOptions {
162162
seed,
163163
num_rounds: config.num_rounds,
164-
dimension,
164+
dimensions: dimension,
165165
element_ptype,
166166
};
167167
return Ok(
@@ -177,7 +177,7 @@ pub unsafe fn turboquant_encode_unchecked(
177177
let sorf_options = SorfOptions {
178178
seed,
179179
num_rounds: config.num_rounds,
180-
dimension,
180+
dimensions: dimension,
181181
element_ptype,
182182
};
183183
Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
@@ -213,7 +213,7 @@ fn turboquant_quantize_core(
213213
let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
214214
let f32_elements = cast_to_f32(elements_prim)?;
215215

216-
let centroids = get_centroids(padded_dim_u32, bit_width)?;
216+
let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?;
217217
let boundaries = compute_centroid_boundaries(&centroids);
218218

219219
let mut all_indices = BufferMut::<u8>::with_capacity(num_rows * padded_dim);

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
//! // Normalize and quantize at 2 bits per coordinate in one pass.
122122
//! let session = VortexSession::empty().with::<ArraySession>();
123123
//! let mut ctx = session.create_execution_ctx();
124-
//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 };
124+
//! let config = TurboQuantConfig { bit_width: 2, seed: 42, num_rounds: 3 };
125125
//! let tq = turboquant_encode(vector, &config, &mut ctx).unwrap();
126126
//!
127127
//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input.

vortex-tensor/src/encodings/turboquant/scheme.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,25 @@ impl Scheme for TurboQuantScheme {
111111
fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vectors: usize) -> f64 {
112112
let config = TurboQuantConfig::default();
113113
let padded_dim = dimensions.next_power_of_two() as usize;
114+
let element_bits = usize::from(element_bit_width);
114115

115-
// Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element
116-
// float width.
117-
let compressed_bits_per_vector =
118-
usize::from(element_bit_width) + usize::from(config.bit_width) * padded_dim;
116+
// Get the size of the fully uncompressed vector data.
117+
let uncompressed_size_bits = element_bits * dimensions as usize * num_vectors;
118+
119+
// Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element float
120+
// width.
121+
let norm_bits = element_bits;
122+
let compressed_bits_per_vector = usize::from(config.bit_width) * padded_dim;
123+
let total_bits_per_vector = norm_bits + compressed_bits_per_vector;
119124

120125
// Shared overhead: codebook centroids (2^bit_width f32 values).
121-
// Note: rotation signs are no longer stored — rotation is deterministic from seed.
122126
let num_centroids = 1usize << config.bit_width;
123127
debug_assert!(num_centroids <= MAX_CENTROIDS);
124128
let overhead_bits = num_centroids * 32; // centroids are always f32
125129

126-
let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits;
130+
// This includes the quantized vectors, norms, and centroid codebook.
131+
let compressed_size_bits = total_bits_per_vector * num_vectors + overhead_bits;
127132

128-
let uncompressed_size_bits = usize::from(element_bit_width) * dimensions as usize * num_vectors;
129133
uncompressed_size_bits as f64 / compressed_size_bits as f64
130134
}
131135

vortex-tensor/src/encodings/turboquant/tests/compute.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ fn slice_preserves_data() -> VortexResult<()> {
4040
let ext = make_vector_ext(&fsl);
4141
let config = TurboQuantConfig {
4242
bit_width: 3,
43-
seed: Some(123),
43+
seed: 123,
4444
num_rounds: 4,
4545
};
4646
let mut ctx = SESSION.create_execution_ctx();
@@ -85,7 +85,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> {
8585
let ext = make_vector_ext(&fsl);
8686
let config = TurboQuantConfig {
8787
bit_width: 3,
88-
seed: Some(123),
88+
seed: 123,
8989
num_rounds: 2,
9090
};
9191
let mut ctx = SESSION.create_execution_ctx();
@@ -108,7 +108,7 @@ fn l2_norm_readthrough() -> VortexResult<()> {
108108
let ext = make_vector_ext(&fsl);
109109
let config = TurboQuantConfig {
110110
bit_width: 3,
111-
seed: Some(123),
111+
seed: 123,
112112
num_rounds: 5,
113113
};
114114
let mut ctx = SESSION.create_execution_ctx();
@@ -146,7 +146,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()>
146146
let ext = make_vector_ext(&fsl);
147147
let config = TurboQuantConfig {
148148
bit_width: 1,
149-
seed: Some(123),
149+
seed: 123,
150150
num_rounds: 3,
151151
};
152152
let mut ctx = SESSION.create_execution_ctx();
@@ -183,7 +183,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR
183183
let ext = make_vector_ext(&fsl);
184184
let config = TurboQuantConfig {
185185
bit_width: 1,
186-
seed: Some(123),
186+
seed: 123,
187187
num_rounds: 3,
188188
};
189189
let mut ctx = SESSION.create_execution_ctx();

vortex-tensor/src/encodings/turboquant/tests/nullable.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> {
2323

2424
let config = TurboQuantConfig {
2525
bit_width: 3,
26-
seed: Some(123),
26+
seed: 123,
2727
num_rounds: 4,
2828
};
2929
let mut ctx = SESSION.create_execution_ctx();
@@ -84,7 +84,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> {
8484

8585
let config = TurboQuantConfig {
8686
bit_width: 2,
87-
seed: Some(123),
87+
seed: 123,
8888
num_rounds: 3,
8989
};
9090
let mut ctx = SESSION.create_execution_ctx();
@@ -114,7 +114,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> {
114114

115115
let config = TurboQuantConfig {
116116
bit_width: 3,
117-
seed: Some(123),
117+
seed: 123,
118118
num_rounds: 3,
119119
};
120120
let mut ctx = SESSION.create_execution_ctx();
@@ -156,7 +156,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
156156

157157
let config = TurboQuantConfig {
158158
bit_width: 3,
159-
seed: Some(123),
159+
seed: 123,
160160
num_rounds: 2,
161161
};
162162
let mut ctx = SESSION.create_execution_ctx();

0 commit comments

Comments
 (0)