Skip to content

Commit a0c6252

Browse files
committed
more fixups
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent a3048ea commit a0c6252

13 files changed

Lines changed: 36 additions & 75 deletions

File tree

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
use std::sync::LazyLock;
1313

14+
use vortex_buffer::Buffer;
1415
use vortex_error::VortexResult;
1516
use vortex_error::vortex_ensure;
1617
use vortex_utils::aliases::dash_map::DashMap;
@@ -29,14 +30,14 @@ const INTEGRATION_POINTS: usize = 1000;
2930

3031
// TODO(connor): Maybe we should just store an `ArrayRef` here?
3132
/// Global centroid cache keyed by (dimension, bit_width).
32-
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);
33+
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new(DashMap::default);
3334

3435
/// Get or compute cached centroids for the given dimension and bit width.
3536
///
3637
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3738
/// quantization levels for the coordinate distribution after random rotation in
3839
/// `dimension`-dimensional space.
39-
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
40+
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
4041
vortex_ensure!(
4142
(1..=MAX_BIT_WIDTH).contains(&bit_width),
4243
"TurboQuant bit_width must be 1-{}, got {bit_width}",
@@ -92,7 +93,7 @@ impl HalfIntExponent {
9293
/// The probability distribution function is:
9394
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
9495
/// where `C_d` is the normalizing constant.
95-
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
96+
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer<f32> {
9697
debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width));
9798
let num_centroids = 1usize << bit_width;
9899

@@ -288,7 +289,7 @@ mod tests {
288289
#[case(128, 4)]
289290
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
290291
let centroids = get_centroids(dim, bits)?;
291-
for &val in &centroids {
292+
for &val in centroids.iter() {
292293
assert!(
293294
(-1.0..=1.0).contains(&val),
294295
"centroid out of [-1, 1]: {val}",

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ pub unsafe fn turboquant_encode_unchecked(
171171

172172
let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
173173
let quantized_fsl =
174-
build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?;
174+
build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?;
175175
let padded_vector = Vector::try_new_vector_array(quantized_fsl)?;
176176

177177
let sorf_options = SorfOptions {
@@ -185,8 +185,8 @@ pub unsafe fn turboquant_encode_unchecked(
185185

186186
/// Shared intermediate results from the quantization loop.
187187
struct QuantizationResult {
188-
centroids: Vec<f32>,
189-
all_indices: BufferMut<u8>,
188+
centroids: Buffer<f32>,
189+
all_indices: Buffer<u8>,
190190
padded_dim: usize,
191191
}
192192

@@ -202,8 +202,7 @@ fn turboquant_quantize_core(
202202
num_rounds: u8,
203203
ctx: &mut ExecutionCtx,
204204
) -> VortexResult<QuantizationResult> {
205-
let dimension =
206-
usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize");
205+
let dimension = fsl.list_size() as usize;
207206
let num_rows = fsl.len();
208207

209208
let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?;
@@ -238,7 +237,7 @@ fn turboquant_quantize_core(
238237

239238
Ok(QuantizationResult {
240239
centroids,
241-
all_indices,
240+
all_indices: all_indices.freeze(),
242241
padded_dim,
243242
})
244243
}
@@ -250,13 +249,12 @@ fn turboquant_quantize_core(
250249
/// without knowledge of the rotation.
251250
fn build_quantized_fsl(
252251
num_rows: usize,
253-
all_indices: BufferMut<u8>,
254-
centroids: &[f32],
252+
all_indices: Buffer<u8>,
253+
centroids: Buffer<f32>,
255254
padded_dim: usize,
256255
) -> VortexResult<ArrayRef> {
257-
let codes = PrimitiveArray::new::<u8>(all_indices.freeze(), Validity::NonNullable);
258-
let centroids_array =
259-
PrimitiveArray::new::<f32>(Buffer::copy_from(centroids), Validity::NonNullable);
256+
let codes = PrimitiveArray::new::<u8>(all_indices, Validity::NonNullable);
257+
let centroids_array = PrimitiveArray::new::<f32>(centroids, Validity::NonNullable);
260258

261259
let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?;
262260

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult<VectorMatcherMeta
169169
vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
170170
})?;
171171

172-
let dimensions = vector_metadata.list_size();
172+
let dimensions = vector_metadata.dimensions();
173173
vortex_ensure!(
174174
dimensions >= MIN_DIMENSION,
175175
"TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}",

vortex-tensor/src/encodings/turboquant/scheme.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl Scheme for TurboQuantScheme {
8383
.bit_width()
8484
.try_into()
8585
.vortex_expect("invalid bit width for TurboQuant");
86-
let dimension = vector_metadata.list_size();
86+
let dimension = vector_metadata.dimensions();
8787

8888
CompressionEstimate::Verdict(EstimateVerdict::Ratio(estimate_compression_ratio(
8989
element_bit_width,

vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -311,37 +311,3 @@ fn f16_input_encodes_successfully() -> VortexResult<()> {
311311
assert_eq!(decoded_fsl.len(), num_rows);
312312
Ok(())
313313
}
314-
315-
/// Verify that the checked encode accepts normalized f16 input.
316-
#[test]
317-
fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> {
318-
let num_rows = 10;
319-
let dim = 128;
320-
let mut rng = StdRng::seed_from_u64(99);
321-
let normal = Normal::new(0.0f32, 1.0).unwrap();
322-
323-
let mut buf = BufferMut::<half::f16>::with_capacity(num_rows * dim);
324-
for _ in 0..(num_rows * dim) {
325-
buf.push(half::f16::from_f32(normal.sample(&mut rng)));
326-
}
327-
let elements = PrimitiveArray::new::<half::f16>(buf.freeze(), Validity::NonNullable);
328-
let fsl = FixedSizeListArray::try_new(
329-
elements.into_array(),
330-
dim.try_into().unwrap(),
331-
Validity::NonNullable,
332-
num_rows,
333-
)?;
334-
335-
let ext = make_vector_ext(&fsl);
336-
let config = TurboQuantConfig {
337-
bit_width: 3,
338-
seed: Some(42),
339-
num_rounds: 3,
340-
};
341-
342-
let mut ctx = SESSION.create_execution_ctx();
343-
let normalized = normalize_as_l2_denorm(ext, &mut ctx)?.child_at(0).clone();
344-
let encoded = turboquant_encode(normalized, &config, &mut ctx)?;
345-
assert_eq!(encoded.len(), num_rows);
346-
Ok(())
347-
}

vortex-tensor/src/fixed_shape/matcher.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub struct FixedShapeTensorMatcherMetadata<'a> {
3131
///
3232
/// This matches the `FixedSizeList` list size in the storage dtype, which is the product of
3333
/// the logical shape dimensions.
34-
flat_list_size: usize,
34+
flat_list_size: u32,
3535
}
3636

3737
impl Matcher for AnyFixedShapeTensor {
@@ -64,7 +64,7 @@ impl Matcher for AnyFixedShapeTensor {
6464
Some(FixedShapeTensorMatcherMetadata {
6565
metadata,
6666
element_ptype: element_dtype.as_ptype(),
67-
flat_list_size: *list_size as usize,
67+
flat_list_size: *list_size,
6868
})
6969
}
7070
}
@@ -81,7 +81,7 @@ impl FixedShapeTensorMatcherMetadata<'_> {
8181
}
8282

8383
/// Returns the flattened element count for each tensor row.
84-
pub fn list_size(&self) -> usize {
84+
pub fn flat_list_size(&self) -> u32 {
8585
self.flat_list_size
8686
}
8787
}
@@ -118,7 +118,7 @@ mod tests {
118118

119119
let metadata = ext_dtype.metadata::<AnyFixedShapeTensor>();
120120
assert_eq!(metadata.element_ptype(), PType::F32);
121-
assert_eq!(metadata.list_size(), 24);
121+
assert_eq!(metadata.flat_list_size(), 24);
122122
assert_eq!(metadata.metadata().logical_shape(), &[2, 3, 4]);
123123
Ok(())
124124
}

vortex-tensor/src/matcher.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ impl TensorMatch<'_> {
4242
}
4343

4444
/// Returns the flattened element count for each logical tensor row.
45-
pub fn list_size(self) -> usize {
45+
pub fn list_size(self) -> u32 {
4646
match self {
47-
Self::FixedShapeTensor(metadata) => metadata.list_size(),
47+
Self::FixedShapeTensor(metadata) => metadata.flat_list_size(),
4848
Self::Vector(metadata) => metadata.dimensions(),
4949
}
5050
}

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ impl ScalarFnVTable for InnerProduct {
182182
let tensor_match = ext
183183
.metadata_opt::<AnyTensor>()
184184
.vortex_expect("we already validated this in `return_dtype`");
185-
let dimensions = tensor_match.list_size();
185+
let dimensions = tensor_match.list_size() as usize;
186186

187187
// Extract the storage array from each extension input. We pass the storage (FSL) rather
188188
// than the extension array to avoid canonicalizing the extension wrapper.

vortex-tensor/src/scalar_fns/l2_denorm.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ impl ScalarFnVTable for L2Denorm {
240240
.as_extension()
241241
.metadata_opt::<AnyTensor>()
242242
.vortex_expect("we already validated this in `return_dtype`");
243-
let tensor_flat_size = tensor_match.list_size();
243+
let tensor_flat_size = tensor_match.list_size() as usize;
244244

245245
let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?;
246246

@@ -423,7 +423,7 @@ pub fn normalize_as_l2_denorm(
423423
) -> VortexResult<ScalarFnArray> {
424424
let row_count = input.len();
425425
let tensor_match = validate_tensor_float_input(input.dtype())?;
426-
let tensor_flat_size = tensor_match.list_size();
426+
let tensor_flat_size = tensor_match.list_size() as usize;
427427

428428
// Constant fast path: if the input is a constant-backed extension, normalize the single
429429
// stored row once and return an `L2Denorm` whose children are both `ConstantArray`s.
@@ -520,7 +520,7 @@ pub(crate) fn try_build_constant_l2_denorm(
520520
.as_extension()
521521
.metadata_opt::<AnyTensor>()
522522
.vortex_expect("caller validated input has AnyTensor metadata");
523-
let list_size = tensor_match.list_size();
523+
let list_size = tensor_match.list_size() as usize;
524524
let original_nullability = input.dtype().nullability();
525525
let ext_dtype = input.dtype().as_extension().clone();
526526
let storage_fsl_nullability = storage.dtype().nullability();
@@ -630,7 +630,7 @@ fn validate_l2_normalized_rows_against_norms(
630630
let tensor_match = validate_tensor_float_input(normalized.dtype())?;
631631
let element_ptype = tensor_match.element_ptype();
632632
let tolerance = unit_norm_tolerance(element_ptype);
633-
let tensor_flat_size = tensor_match.list_size();
633+
let tensor_flat_size = tensor_match.list_size() as usize;
634634

635635
if let Some(norms) = norms {
636636
vortex_ensure_eq!(

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ impl ScalarFnVTable for L2Norm {
132132
let tensor_match = ext
133133
.metadata_opt::<AnyTensor>()
134134
.vortex_expect("we already validated this in `return_dtype`");
135-
let tensor_flat_size = tensor_match.list_size();
135+
let tensor_flat_size = tensor_match.list_size() as usize;
136136
let element_ptype = tensor_match.element_ptype();
137137

138138
let norm_dtype = DType::Primitive(element_ptype, ext.nullability());

0 commit comments

Comments (0)