Skip to content

Commit cca6e9e

Browse files
authored
Fix TurboQuant metadata to be protobuf (#7301)
## Summary Right now, the `TurboQuant` metadata is just a single byte; this changes it to a prost (protobuf) message. ## Testing Small roundtrip tests. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 11d607e commit cca6e9e

File tree

3 files changed

+74
-14
lines changed

3 files changed

+74
-14
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Protobuf-backed metadata for TurboQuant encoding.
5+
6+
use prost::Message;
7+
use vortex_error::VortexResult;
8+
use vortex_error::vortex_ensure;
9+
use vortex_error::vortex_err;
10+
11+
/// Serialized metadata for TurboQuant arrays.
12+
#[derive(Clone, PartialEq, Message)]
13+
pub(super) struct TurboQuantMetadata {
14+
/// The number of bits per coordinate.
15+
#[prost(uint32, required, tag = "1")]
16+
bit_width: u32,
17+
}
18+
19+
impl TurboQuantMetadata {
20+
/// Creates metadata for the given bit width.
21+
pub(super) fn new(bit_width: u8) -> Self {
22+
Self {
23+
bit_width: u32::from(bit_width),
24+
}
25+
}
26+
27+
/// Returns the validated TurboQuant bit width.
28+
pub(super) fn bit_width(&self) -> VortexResult<u8> {
29+
let bit_width = u8::try_from(self.bit_width).map_err(|_| {
30+
vortex_err!(
31+
"TurboQuant bit_width must fit into u8, got {}",
32+
self.bit_width
33+
)
34+
})?;
35+
vortex_ensure!(
36+
bit_width <= 8,
37+
"bit_width is expected to be between 0 and 8, got {bit_width}"
38+
);
39+
40+
Ok(bit_width)
41+
}
42+
}
43+
44+
#[cfg(test)]
mod tests {
    use prost::Message;
    use rstest::rstest;
    use vortex_error::VortexResult;

    use super::TurboQuantMetadata;

    /// Encoding a valid bit width to protobuf bytes and decoding it again yields the
    /// same value.
    #[rstest]
    #[case(0)]
    #[case(3)]
    #[case(8)]
    fn protobuf_metadata_roundtrip(#[case] bit_width: u8) -> VortexResult<()> {
        let encoded = TurboQuantMetadata::new(bit_width).encode_to_vec();
        let decoded = TurboQuantMetadata::decode(encoded.as_slice())?;

        assert_eq!(decoded.bit_width()?, bit_width);

        Ok(())
    }

    /// A value that fits into `u8` but exceeds 8 still decodes as a message, yet is
    /// rejected by the validating accessor.
    #[rstest]
    #[case(9)]
    #[case(u8::MAX)]
    fn bit_width_above_eight_is_rejected(#[case] bit_width: u8) -> VortexResult<()> {
        let encoded = TurboQuantMetadata::new(bit_width).encode_to_vec();
        let decoded = TurboQuantMetadata::decode(encoded.as_slice())?;

        assert!(decoded.bit_width().is_err());

        Ok(())
    }

    /// A stored value that does not fit into `u8` is rejected by the validating
    /// accessor. The private field is set directly because `new` only accepts `u8`.
    #[rstest]
    #[case(256)]
    #[case(u32::MAX)]
    fn bit_width_overflowing_u8_is_rejected(#[case] raw: u32) -> VortexResult<()> {
        let encoded = TurboQuantMetadata { bit_width: raw }.encode_to_vec();
        let decoded = TurboQuantMetadata::decode(encoded.as_slice())?;

        assert!(decoded.bit_width().is_err());

        Ok(())
    }
}

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ pub use array::scheme::TurboQuantScheme;
9797

9898
pub(crate) mod compute;
9999

100+
mod metadata;
101+
100102
mod vtable;
101103
pub use vtable::TurboQuant;
102104
pub use vtable::TurboQuantArray;

vortex-tensor/src/encodings/turboquant/vtable.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::hash::Hash;
77
use std::hash::Hasher;
88
use std::sync::Arc;
99

10+
use prost::Message;
1011
use vortex_array::Array;
1112
use vortex_array::ArrayEq;
1213
use vortex_array::ArrayHash;
@@ -40,6 +41,7 @@ use crate::encodings::turboquant::array::slots::Slot;
4041
use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
4142
use crate::encodings::turboquant::compute::rules::RULES;
4243
use crate::encodings::turboquant::decompress::execute_decompress;
44+
use crate::encodings::turboquant::metadata::TurboQuantMetadata;
4345
use crate::utils::tensor_element_ptype;
4446
use crate::utils::tensor_list_size;
4547
use crate::vector::Vector;
@@ -197,7 +199,9 @@ impl VTable for TurboQuant {
197199
}
198200

199201
fn serialize(array: ArrayView<'_, Self>) -> VortexResult<Option<Vec<u8>>> {
200-
Ok(Some(vec![array.bit_width]))
202+
Ok(Some(
203+
TurboQuantMetadata::new(array.bit_width).encode_to_vec(),
204+
))
201205
}
202206

203207
fn deserialize(
@@ -209,19 +213,8 @@ impl VTable for TurboQuant {
209213
children: &dyn ArrayChildren,
210214
_session: &VortexSession,
211215
) -> VortexResult<ArrayParts<Self>> {
212-
vortex_ensure_eq!(
213-
metadata.len(),
214-
1,
215-
"TurboQuant metadata must be exactly 1 byte, got {}",
216-
metadata.len()
217-
);
218-
vortex_ensure!(
219-
metadata[0] <= 8,
220-
"bit_width is expected to be between 0 and 8, got {}",
221-
metadata[0]
222-
);
223-
224-
let bit_width = metadata[0];
216+
let metadata = TurboQuantMetadata::decode(metadata)?;
217+
let bit_width = metadata.bit_width()?;
225218

226219
// bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with
227220
// bit_width == 0 would have zero centroids while codes reference centroid indices.

0 commit comments

Comments
 (0)