Skip to content

Commit 4eab406

Browse files
lwwmanning and claude
authored and committed
TurboQuant encoding for Vectors (#7167)
Lossy quantization for vector data (e.g., embeddings) based on TurboQuant (https://arxiv.org/abs/2504.19874). Supports both MSE-optimal and inner-product-optimal (Prod with QJL correction) variants at 1-8 bits per coordinate.

Key components:
- Single TurboQuant array encoding with optional QJL correction fields, storing quantized codes, norms, centroids, and rotation signs as children.
- Structured Random Hadamard Transform (SRHT) for O(d log d) rotation, fully self-contained with no external linear algebra library.
- Max-Lloyd centroid computation on Beta(d/2, d/2) distribution.
- Approximate cosine similarity and dot product compute directly on quantized arrays without full decompression.
- Pluggable TurboQuantScheme for BtrBlocks, exposed via WriteStrategyBuilder::with_vector_quantization().
- Benchmarks covering common embedding dimensions (128, 768, 1024, 1536).

Also refactors CompressingStrategy to a single constructor, and adds vortex_tensor::initialize() for session registration of tensor types, encodings, and scalar functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Will Manning <will@willmanning.io>
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 02b0949 commit 4eab406

File tree

27 files changed

+3634
-14
lines changed

27 files changed

+3634
-14
lines changed

Cargo.lock

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

_typos.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[default]
2-
extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui"]
2+
extend-ignore-identifiers-re = ["ffor", "FFOR", "FoR", "typ", "ratatui", "wht", "WHT"]
33
# We support a few common special comments to tell the checker to ignore sections of code
44
extend-ignore-re = [
55
"(#|//)\\s*spellchecker:ignore-next-line\\n.*", # Ignore the next line

vortex-file/src/strategy.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ use vortex_pco::Pco;
5656
use vortex_runend::RunEnd;
5757
use vortex_sequence::Sequence;
5858
use vortex_sparse::Sparse;
59+
#[cfg(feature = "unstable_encodings")]
60+
use vortex_tensor::encodings::turboquant::TurboQuant;
5961
use vortex_utils::aliases::hash_map::HashMap;
6062
use vortex_zigzag::ZigZag;
6163
#[cfg(feature = "zstd")]
@@ -104,6 +106,8 @@ pub static ALLOWED_ENCODINGS: LazyLock<ArrayRegistry> = LazyLock::new(|| {
104106
session.register(RunEnd);
105107
session.register(Sequence);
106108
session.register(Sparse);
109+
#[cfg(feature = "unstable_encodings")]
110+
session.register(TurboQuant);
107111
session.register(ZigZag);
108112

109113
#[cfg(feature = "zstd")]

vortex-layout/src/layouts/table.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,14 @@ impl TableStrategy {
8686
/// ```ignore
8787
/// # use std::sync::Arc;
8888
/// # use vortex_array::dtype::{field_path, Field, FieldPath};
89+
/// # use vortex_btrblocks::BtrBlocksCompressor;
8990
/// # use vortex_layout::layouts::compressed::CompressingStrategy;
9091
/// # use vortex_layout::layouts::flat::writer::FlatLayoutStrategy;
9192
/// # use vortex_layout::layouts::table::TableStrategy;
9293
///
93-
/// # use vortex_btrblocks::BtrBlocksCompressor;
9494
/// // A strategy for compressing data using the balanced BtrBlocks compressor.
95-
/// let compress = CompressingStrategy::new(FlatLayoutStrategy::default(), BtrBlocksCompressor::default());
95+
/// let compress =
96+
/// CompressingStrategy::new(FlatLayoutStrategy::default(), BtrBlocksCompressor::default());
9697
///
9798
/// // Our combined strategy uses no compression for validity buffers, BtrBlocks compression
9899
/// // for most columns, and stores a nested binary column uncompressed (flat) because it

vortex-tensor/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@ workspace = true
1919
[dependencies]
2020
vortex-array = { workspace = true }
2121
vortex-buffer = { workspace = true }
22+
vortex-compressor = { workspace = true }
2223
vortex-error = { workspace = true }
24+
vortex-fastlanes = { workspace = true }
2325
vortex-session = { workspace = true }
26+
vortex-utils = { workspace = true }
2427

28+
half = { workspace = true }
2529
itertools = { workspace = true }
2630
num-traits = { workspace = true }
2731
prost = { workspace = true }
32+
rand = { workspace = true }
2833

2934
[dev-dependencies]
35+
rand_distr = { workspace = true }
3036
rstest = { workspace = true }
31-
vortex-buffer = { workspace = true }

vortex-tensor/public-api.lock

Lines changed: 205 additions & 1 deletion
Large diffs are not rendered by default.

vortex-tensor/src/encodings/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@
77
// pub mod norm; // Unit-normalized vectors.
88
// pub mod spherical; // Spherical transform on unit-normalized vectors.
99

10-
// TODO(will):
11-
// pub mod turboquant;
10+
pub mod turboquant;
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! TurboQuant array definition: stores quantized coordinate codes, norms,
5+
//! centroids (codebook), rotation signs, and optional QJL correction fields.
6+
7+
use vortex_array::ArrayId;
8+
use vortex_array::ArrayRef;
9+
use vortex_array::dtype::DType;
10+
use vortex_array::stats::ArrayStats;
11+
use vortex_array::vtable;
12+
use vortex_error::VortexExpect;
13+
use vortex_error::VortexResult;
14+
use vortex_error::vortex_ensure;
15+
16+
/// Encoding marker type for TurboQuant.
17+
#[derive(Clone, Debug)]
18+
pub struct TurboQuant;
19+
20+
impl TurboQuant {
21+
pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant");
22+
}
23+
24+
vtable!(TurboQuant, TurboQuant, TurboQuantData);
25+
26+
/// Protobuf metadata for TurboQuant encoding.
27+
#[derive(Clone, prost::Message)]
28+
pub struct TurboQuantMetadata {
29+
/// Vector dimension d.
30+
#[prost(uint32, tag = "1")]
31+
pub dimension: u32,
32+
/// MSE bits per coordinate (1-8).
33+
#[prost(uint32, tag = "2")]
34+
pub bit_width: u32,
35+
/// Whether QJL correction children are present.
36+
#[prost(bool, tag = "3")]
37+
pub has_qjl: bool,
38+
}
39+
40+
/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased
41+
/// inner product estimation. When present, adds 3 additional children.
42+
#[derive(Clone, Debug)]
43+
pub struct QjlCorrection {
44+
/// Sign bits: `BoolArray`, length `num_rows * padded_dim`.
45+
pub(crate) signs: ArrayRef,
46+
/// Residual norms: `PrimitiveArray<f32>`, length `num_rows`.
47+
pub(crate) residual_norms: ArrayRef,
48+
/// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order).
49+
pub(crate) rotation_signs: ArrayRef,
50+
}
51+
52+
impl QjlCorrection {
53+
/// The QJL sign bits.
54+
pub fn signs(&self) -> &ArrayRef {
55+
&self.signs
56+
}
57+
58+
/// The residual norms.
59+
pub fn residual_norms(&self) -> &ArrayRef {
60+
&self.residual_norms
61+
}
62+
63+
/// The QJL rotation signs (BoolArray, inverse application order).
64+
pub fn rotation_signs(&self) -> &ArrayRef {
65+
&self.rotation_signs
66+
}
67+
}
68+
69+
/// Positions of the children inside a TurboQuant array's `slots` vector.
///
/// Each variant's discriminant is the index of that child, so a variant can be
/// cast with `as usize` and used directly to index the slot vector. The values
/// are spelled out explicitly because they define the child layout.
#[repr(usize)]
#[derive(Clone, Copy, Debug)]
pub(crate) enum Slot {
    /// Quantized coordinate codes.
    Codes = 0,
    /// Per-row vector norms.
    Norms = 1,
    /// Codebook of centroids.
    Centroids = 2,
    /// Rotation signs for the MSE rotation.
    RotationSigns = 3,
    /// QJL sign bits (optional).
    QjlSigns = 4,
    /// QJL residual norms (optional).
    QjlResidualNorms = 5,
    /// QJL rotation signs (optional).
    QjlRotationSigns = 6,
}
81+
82+
impl Slot {
83+
pub(crate) const COUNT: usize = 7;
84+
85+
pub(crate) fn name(self) -> &'static str {
86+
match self {
87+
Self::Codes => "codes",
88+
Self::Norms => "norms",
89+
Self::Centroids => "centroids",
90+
Self::RotationSigns => "rotation_signs",
91+
Self::QjlSigns => "qjl_signs",
92+
Self::QjlResidualNorms => "qjl_residual_norms",
93+
Self::QjlRotationSigns => "qjl_rotation_signs",
94+
}
95+
}
96+
97+
pub(crate) fn from_index(idx: usize) -> Self {
98+
match idx {
99+
0 => Self::Codes,
100+
1 => Self::Norms,
101+
2 => Self::Centroids,
102+
3 => Self::RotationSigns,
103+
4 => Self::QjlSigns,
104+
5 => Self::QjlResidualNorms,
105+
6 => Self::QjlRotationSigns,
106+
_ => vortex_error::vortex_panic!("invalid slot index {idx}"),
107+
}
108+
}
109+
}
110+
111+
/// TurboQuant array.
112+
///
113+
/// Slots (always present):
114+
/// - 0: `codes` — `FixedSizeListArray<u8>` (quantized indices, list_size=padded_dim)
115+
/// - 1: `norms` — `PrimitiveArray<f32>` (one per vector row)
116+
/// - 2: `centroids` — `PrimitiveArray<f32>` (codebook, length 2^bit_width)
117+
/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order)
118+
///
119+
/// Optional QJL slots (None when MSE-only):
120+
/// - 4: `qjl_signs` — `FixedSizeListArray<u8>` (num_rows * padded_dim, 1-bit)
121+
/// - 5: `qjl_residual_norms` — `PrimitiveArray<f32>` (one per row)
122+
/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation)
123+
#[derive(Clone, Debug)]
124+
pub struct TurboQuantData {
125+
pub(crate) dtype: DType,
126+
pub(crate) slots: Vec<Option<ArrayRef>>,
127+
pub(crate) dimension: u32,
128+
pub(crate) bit_width: u8,
129+
pub(crate) stats_set: ArrayStats,
130+
}
131+
132+
impl TurboQuantData {
133+
/// Build a TurboQuant array with MSE-only encoding (no QJL correction).
134+
#[allow(clippy::too_many_arguments)]
135+
pub fn try_new_mse(
136+
dtype: DType,
137+
codes: ArrayRef,
138+
norms: ArrayRef,
139+
centroids: ArrayRef,
140+
rotation_signs: ArrayRef,
141+
dimension: u32,
142+
bit_width: u8,
143+
) -> VortexResult<Self> {
144+
vortex_ensure!(
145+
(1..=8).contains(&bit_width),
146+
"MSE bit_width must be 1-8, got {bit_width}"
147+
);
148+
let mut slots = vec![None; Slot::COUNT];
149+
slots[Slot::Codes as usize] = Some(codes);
150+
slots[Slot::Norms as usize] = Some(norms);
151+
slots[Slot::Centroids as usize] = Some(centroids);
152+
slots[Slot::RotationSigns as usize] = Some(rotation_signs);
153+
Ok(Self {
154+
dtype,
155+
slots,
156+
dimension,
157+
bit_width,
158+
stats_set: Default::default(),
159+
})
160+
}
161+
162+
/// Build a TurboQuant array with QJL correction (MSE + QJL).
163+
#[allow(clippy::too_many_arguments)]
164+
pub fn try_new_qjl(
165+
dtype: DType,
166+
codes: ArrayRef,
167+
norms: ArrayRef,
168+
centroids: ArrayRef,
169+
rotation_signs: ArrayRef,
170+
qjl: QjlCorrection,
171+
dimension: u32,
172+
bit_width: u8,
173+
) -> VortexResult<Self> {
174+
vortex_ensure!(
175+
(1..=8).contains(&bit_width),
176+
"MSE bit_width must be 1-8, got {bit_width}"
177+
);
178+
let mut slots = vec![None; Slot::COUNT];
179+
slots[Slot::Codes as usize] = Some(codes);
180+
slots[Slot::Norms as usize] = Some(norms);
181+
slots[Slot::Centroids as usize] = Some(centroids);
182+
slots[Slot::RotationSigns as usize] = Some(rotation_signs);
183+
slots[Slot::QjlSigns as usize] = Some(qjl.signs);
184+
slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms);
185+
slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs);
186+
Ok(Self {
187+
dtype,
188+
slots,
189+
dimension,
190+
bit_width,
191+
stats_set: Default::default(),
192+
})
193+
}
194+
195+
/// The vector dimension d.
196+
pub fn dimension(&self) -> u32 {
197+
self.dimension
198+
}
199+
200+
/// MSE bits per coordinate.
201+
pub fn bit_width(&self) -> u8 {
202+
self.bit_width
203+
}
204+
205+
/// Padded dimension (next power of 2 >= dimension).
206+
pub fn padded_dim(&self) -> u32 {
207+
self.dimension.next_power_of_two()
208+
}
209+
210+
/// Whether QJL correction is present.
211+
pub fn has_qjl(&self) -> bool {
212+
self.slots[Slot::QjlSigns as usize].is_some()
213+
}
214+
215+
fn slot(&self, idx: usize) -> &ArrayRef {
216+
self.slots[idx]
217+
.as_ref()
218+
.vortex_expect("required slot is None")
219+
}
220+
221+
/// The quantized codes child (FixedSizeListArray).
222+
pub fn codes(&self) -> &ArrayRef {
223+
self.slot(Slot::Codes as usize)
224+
}
225+
226+
/// The norms child (`PrimitiveArray<f32>`).
227+
pub fn norms(&self) -> &ArrayRef {
228+
self.slot(Slot::Norms as usize)
229+
}
230+
231+
/// The centroids (codebook) child (`PrimitiveArray<f32>`).
232+
pub fn centroids(&self) -> &ArrayRef {
233+
self.slot(Slot::Centroids as usize)
234+
}
235+
236+
/// The MSE rotation signs child (BitPackedArray, length 3 * padded_dim).
237+
pub fn rotation_signs(&self) -> &ArrayRef {
238+
self.slot(Slot::RotationSigns as usize)
239+
}
240+
241+
/// The optional QJL correction fields, reconstructed from slots.
242+
pub fn qjl(&self) -> Option<QjlCorrection> {
243+
Some(QjlCorrection {
244+
signs: self.slots[Slot::QjlSigns as usize].clone()?,
245+
residual_norms: self.slots[Slot::QjlResidualNorms as usize].clone()?,
246+
rotation_signs: self.slots[Slot::QjlRotationSigns as usize].clone()?,
247+
})
248+
}
249+
250+
/// Set the QJL correction fields on this array.
251+
pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) {
252+
self.slots[Slot::QjlSigns as usize] = Some(qjl.signs);
253+
self.slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms);
254+
self.slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs);
255+
}
256+
}

0 commit comments

Comments
 (0)