Skip to content

Commit b59e8b8

Browse files
committed
even more cleanup
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 6111f27 commit b59e8b8

2 files changed

Lines changed: 57 additions & 68 deletions

File tree

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 55 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_array::stats::ArrayStats;
99
use vortex_error::VortexExpect;
1010
use vortex_error::VortexResult;
1111
use vortex_error::vortex_ensure;
12+
use vortex_error::vortex_ensure_eq;
1213

1314
use crate::encodings::turboquant::array::slots::Slot;
1415
use crate::encodings::turboquant::vtable::TurboQuant;
@@ -17,22 +18,22 @@ use crate::utils::extension_list_size;
1718

1819
/// TurboQuant array data.
1920
///
20-
/// TurboQuant is a lossy vector quantization encoding for [`Vector`] extension arrays.
21-
/// It stores quantized coordinate codes and per-vector norms, along with shared codebook
22-
/// centroids and SRHT rotation signs. See the [module docs](super) for algorithmic details.
21+
/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector)
22+
/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
23+
/// codebook centroids and SRHT rotation signs.
2324
///
24-
/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty.
25+
/// See the [module docs](super) for algorithmic details.
2526
///
26-
/// [`Vector`]: crate::vector::Vector
27+
/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty.
2728
#[derive(Clone, Debug)]
2829
pub struct TurboQuantData {
29-
/// The [`Vector`] extension dtype that this array encodes. The storage dtype within the
30-
/// extension determines the element type (f16, f32, or f64) and the list size (dimension).
30+
/// The [`Vector`](crate::vector::Vector) extension dtype that this array encodes.
3131
///
32-
/// [`Vector`]: crate::vector::Vector
32+
/// The storage dtype within the extension determines the element type (f16, f32, or f64) and
33+
/// the list size (dimension).
3334
pub(crate) dtype: DType,
3435

35-
/// Child arrays stored as optional slots. See [`Slot`] for positions:
36+
/// Child arrays stored as slots. See [`Slot`] for positions:
3637
///
3738
/// - [`Codes`](Slot::Codes): `FixedSizeListArray<u8>` with `list_size == padded_dim`. Each row
3839
/// holds one u8 centroid index per padded coordinate. The cascade compressor handles packing
@@ -53,13 +54,13 @@ pub struct TurboQuantData {
5354
pub(crate) slots: Vec<Option<ArrayRef>>,
5455

5556
/// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size.
57+
///
5658
/// Stored as a convenience field to avoid repeatedly extracting it from `dtype`.
57-
/// Non-power-of-2 dimensions are zero-padded to [`padded_dim`](Self::padded_dim) for the
58-
/// Walsh-Hadamard transform.
5959
pub(crate) dimension: u32,
6060

6161
/// The number of bits per coordinate (1-8), derived from `log2(centroids.len())`.
62-
/// Zero for degenerate empty arrays.
62+
///
63+
/// This is 0 for degenerate empty arrays.
6364
pub(crate) bit_width: u8,
6465

6566
/// The stats for this array.
@@ -100,8 +101,8 @@ impl TurboQuantData {
100101
/// is >= 3.
101102
/// - `codes` is a `FixedSizeListArray<u8>` with `list_size == padded_dim` and
102103
/// `codes.len() == norms.len()`.
103-
/// - `norms` is a non-nullable primitive array whose ptype matches the element type of the
104-
/// Vector's storage dtype.
104+
/// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage
105+
/// dtype. The validity of `norms` must match the validity of the `codes` array.
105106
/// - `centroids` is a non-nullable `PrimitiveArray<f32>` whose length is a power of 2 in
106107
/// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays.
107108
/// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays.
@@ -124,13 +125,22 @@ impl TurboQuantData {
124125
.and_then(|ext| extension_list_size(ext).ok())
125126
.vortex_expect("dtype must be a Vector extension type with FixedSizeList storage");
126127

127-
let bit_width = derive_bit_width(&centroids);
128+
let bit_width = if centroids.is_empty() {
129+
0
130+
} else {
131+
// Guaranteed to be 1-8 by validate(), since centroids is non-empty here.
132+
#[expect(clippy::cast_possible_truncation)]
133+
{
134+
centroids.len().trailing_zeros() as u8
135+
}
136+
};
128137

129138
let mut slots = vec![None; Slot::COUNT];
130139
slots[Slot::Codes as usize] = Some(codes);
131140
slots[Slot::Norms as usize] = Some(norms);
132141
slots[Slot::Centroids as usize] = Some(centroids);
133142
slots[Slot::RotationSigns as usize] = Some(rotation_signs);
143+
134144
Self {
135145
dtype,
136146
slots,
@@ -153,15 +163,19 @@ impl TurboQuantData {
153163
let ext = TurboQuant::validate_dtype(dtype)?;
154164
let dimension = extension_list_size(ext)?;
155165

156-
let num_rows = norms.len();
166+
let num_rows = codes.len();
167+
vortex_ensure_eq!(
168+
norms.len(),
169+
num_rows,
170+
"norms length must match codes length",
171+
);
172+
173+
// TODO(connor): Should we check that the codes and norms have the same validity? We could
174+
// also make it so that norms holds the validity and the codes of any null vector are
175+
// just 0...
157176

158-
// Degenerate (empty) case: all children must be empty, bit_width is 0.
177+
// Degenerate (empty) case: all children must be empty, and bit_width is 0.
159178
if num_rows == 0 {
160-
vortex_ensure!(
161-
codes.is_empty(),
162-
"degenerate TurboQuant must have empty codes, got length {}",
163-
codes.len()
164-
);
165179
vortex_ensure!(
166180
centroids.is_empty(),
167181
"degenerate TurboQuant must have empty centroids, got length {}",
@@ -183,7 +197,7 @@ impl TurboQuantData {
183197
);
184198

185199
// Guaranteed to be 1-8 by the preceding power-of-2 and range checks.
186-
#[allow(clippy::cast_possible_truncation)]
200+
#[expect(clippy::cast_possible_truncation)]
187201
let bit_width = num_centroids.trailing_zeros() as u8;
188202
vortex_ensure!(
189203
(1..=8).contains(&bit_width),
@@ -193,44 +207,34 @@ impl TurboQuantData {
193207
// Norms dtype must match the element ptype of the Vector.
194208
let element_ptype = extension_element_ptype(ext)?;
195209
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
196-
vortex_ensure!(
197-
*norms.dtype() == expected_norms_dtype,
198-
"norms dtype {} does not match expected {expected_norms_dtype} \
210+
vortex_ensure_eq!(
211+
*norms.dtype(),
212+
expected_norms_dtype,
213+
"norms dtype does not match expected {expected_norms_dtype} \
199214
(must match Vector element type)",
200-
norms.dtype()
201215
);
202216

203217
// Centroids are always f32 regardless of element type.
204-
let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable);
205-
vortex_ensure!(
206-
*centroids.dtype() == f32_nn,
207-
"centroids dtype {} must be non-nullable f32",
208-
centroids.dtype()
209-
);
210-
211-
// Row count consistency.
212-
vortex_ensure!(
213-
codes.len() == num_rows,
214-
"codes length {} does not match norms length {num_rows}",
215-
codes.len()
218+
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
219+
vortex_ensure_eq!(
220+
*centroids.dtype(),
221+
centroids_dtype,
222+
"centroids dtype must be non-nullable f32",
216223
);
217224

218225
// Rotation signs count must be 3 * padded_dim.
219226
let padded_dim = dimension.next_power_of_two() as usize;
220-
vortex_ensure!(
221-
rotation_signs.len() == 3 * padded_dim,
222-
"rotation_signs length {} does not match expected 3 * {padded_dim} = {}",
227+
vortex_ensure_eq!(
223228
rotation_signs.len(),
224-
3 * padded_dim
229+
3 * padded_dim,
230+
"rotation_signs length does not match expected 3 * {padded_dim}",
225231
);
226232

227233
Ok(())
228234
}
229235

230-
/// The vector dimension `d`, as stored in the [`Vector`] extension dtype's
231-
/// `FixedSizeList` storage.
232-
///
233-
/// [`Vector`]: crate::vector::Vector
236+
/// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension
237+
/// dtype's `FixedSizeList` storage.
234238
pub fn dimension(&self) -> u32 {
235239
self.dimension
236240
}
@@ -248,12 +252,6 @@ impl TurboQuantData {
248252
self.dimension.next_power_of_two()
249253
}
250254

251-
fn slot(&self, idx: usize) -> &ArrayRef {
252-
self.slots[idx]
253-
.as_ref()
254-
.vortex_expect("required slot is None")
255-
}
256-
257255
/// The quantized codes child (`FixedSizeListArray<u8>`, one row per vector).
258256
pub fn codes(&self) -> &ArrayRef {
259257
self.slot(Slot::Codes as usize)
@@ -278,19 +276,10 @@ impl TurboQuantData {
278276
pub fn rotation_signs(&self) -> &ArrayRef {
279277
self.slot(Slot::RotationSigns as usize)
280278
}
281-
}
282279

283-
/// Derive `bit_width` from the centroids array length.
284-
///
285-
/// Returns 0 for empty centroids (degenerate array), otherwise `log2(centroids.len())`.
286-
fn derive_bit_width(centroids: &ArrayRef) -> u8 {
287-
if centroids.is_empty() {
288-
0
289-
} else {
290-
// Guaranteed to be 0-8 by validate().
291-
#[allow(clippy::cast_possible_truncation)]
292-
{
293-
centroids.len().trailing_zeros() as u8
294-
}
280+
fn slot(&self, idx: usize) -> &ArrayRef {
281+
self.slots[idx]
282+
.as_ref()
283+
.vortex_expect("required slot is None")
295284
}
296285
}

vortex-tensor/src/encodings/turboquant/array/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub(crate) mod data;
88
pub(crate) mod metadata;
99
pub(crate) mod slots;
1010

11-
pub(crate) mod scheme;
12-
1311
pub(crate) mod centroids;
1412
pub(crate) mod rotation;
13+
14+
pub(crate) mod scheme;

0 commit comments

Comments
 (0)