Skip to content

Commit 258d323

Browse files
committed
clean up more code
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 1280166 commit 258d323

14 files changed

Lines changed: 198 additions & 183 deletions

File tree

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new
3636
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3737
/// quantization levels for the coordinate distribution after random rotation in
3838
/// `dimension`-dimensional space.
39-
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
39+
pub fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
4040
vortex_ensure!(
4141
(1..=MAX_BIT_WIDTH).contains(&bit_width),
4242
"TurboQuant bit_width must be 1-{}, got {bit_width}",
@@ -239,7 +239,7 @@ mod tests {
239239
#[case] bits: u8,
240240
#[case] expected: usize,
241241
) -> VortexResult<()> {
242-
let centroids = get_centroids(dim, bits)?;
242+
let centroids = compute_or_get_centroids(dim, bits)?;
243243
assert_eq!(centroids.len(), expected);
244244
Ok(())
245245
}
@@ -251,7 +251,7 @@ mod tests {
251251
#[case(128, 4)]
252252
#[case(768, 2)]
253253
fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
254-
let centroids = get_centroids(dim, bits)?;
254+
let centroids = compute_or_get_centroids(dim, bits)?;
255255
for window in centroids.windows(2) {
256256
assert!(
257257
window[0] < window[1],
@@ -268,7 +268,7 @@ mod tests {
268268
#[case(256, 2)]
269269
#[case(768, 2)]
270270
fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
271-
let centroids = get_centroids(dim, bits)?;
271+
let centroids = compute_or_get_centroids(dim, bits)?;
272272
let count = centroids.len();
273273
for idx in 0..count / 2 {
274274
let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
@@ -287,7 +287,7 @@ mod tests {
287287
#[case(128, 1)]
288288
#[case(128, 4)]
289289
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
290-
let centroids = get_centroids(dim, bits)?;
290+
let centroids = compute_or_get_centroids(dim, bits)?;
291291
for &val in centroids.iter() {
292292
assert!(
293293
(-1.0..=1.0).contains(&val),
@@ -299,15 +299,15 @@ mod tests {
299299

300300
#[test]
301301
fn centroids_cached() -> VortexResult<()> {
302-
let c1 = get_centroids(128, 2)?;
303-
let c2 = get_centroids(128, 2)?;
302+
let c1 = compute_or_get_centroids(128, 2)?;
303+
let c2 = compute_or_get_centroids(128, 2)?;
304304
assert_eq!(c1, c2);
305305
Ok(())
306306
}
307307

308308
#[test]
309309
fn find_nearest_basic() -> VortexResult<()> {
310-
let centroids = get_centroids(128, 2)?;
310+
let centroids = compute_or_get_centroids(128, 2)?;
311311
let boundaries = compute_centroid_boundaries(&centroids);
312312
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
313313

@@ -324,9 +324,9 @@ mod tests {
324324

325325
#[test]
326326
fn rejects_invalid_params() {
327-
assert!(get_centroids(128, 0).is_err());
328-
assert!(get_centroids(128, 9).is_err());
329-
assert!(get_centroids(1, 2).is_err());
330-
assert!(get_centroids(127, 2).is_err());
327+
assert!(compute_or_get_centroids(128, 0).is_err());
328+
assert!(compute_or_get_centroids(128, 9).is_err());
329+
assert!(compute_or_get_centroids(1, 2).is_err());
330+
assert!(compute_or_get_centroids(127, 2).is_err());
331331
}
332332
}

0 commit comments

Comments
 (0)