Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vortex-turboquant/src/scalar_fns/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,21 @@ fn build_empty_vector(
})
}

/// Borrowed bundle of the per-array decode inputs passed to the typed inner loop.
///
/// Packaged as a struct rather than positional arguments because `decode_typed` runs through
/// [`vortex_array::match_each_float_ptype!`] which expands once per supported element ptype.
/// Each expansion takes the same set of inputs, and the struct keeps the call site short.
struct DecodeInputs<'a> {
/// TurboQuant metadata recovered from the input extension dtype.
metadata: &'a TurboQuantMetadata,
/// SORF transform reconstructed from `metadata.seed` and `metadata.num_rounds`.
sorf_matrix: &'a SorfMatrix,
/// Centroid codebook for `(padded_dim, bit_width)`, in f32.
centroids: &'a [f32],
/// Per-row stored L2 norm of the original input vector, in the element ptype.
norms: &'a PrimitiveArray,
/// Flat per-row centroid indices, `num_vectors * padded_dim` bytes.
codes: &'a PrimitiveArray,
}

Expand Down
4 changes: 3 additions & 1 deletion vortex-turboquant/src/sorf/splitmix64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9;
/// Second SplitMix64 mixing multiplier from the reference implementation.
const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB;

/// Frozen local SplitMix64 stream used to define SORF sign diagonals.
/// Frozen local SplitMix64 stream used to define SORF sign diagonals. Bit-identical to the
/// reference implementation linked at the module top, which makes the sign stream part of the
/// encoding's wire contract.
pub(crate) struct SplitMix64 {
state: u64,
}
Expand Down
8 changes: 8 additions & 0 deletions vortex-turboquant/src/vector/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Vector-side helpers: normalization, quantization, and physical storage layout.

pub(crate) mod normalize;
pub(crate) mod quantize;
pub(crate) mod storage;
Expand All @@ -9,6 +11,12 @@ use vortex_error::VortexResult;
use vortex_error::vortex_err;

/// Compute the padded SORF dimension for an original vector dimension.
///
/// The SORF transform requires a power-of-two width, so non-power-of-two input dimensions are
/// padded with zeros up to the next power of two. The padded dimension is stored implicitly via
/// [`TurboQuantMetadata::dimensions`](crate::TurboQuantMetadata) plus the codes child's
/// `FixedSizeList` width and recovered at decode time via this function. Returns an error when
/// the next power of two overflows the input integer type.
pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult<usize> {
let padded_dim = dimensions
.checked_next_power_of_two()
Expand Down
17 changes: 11 additions & 6 deletions vortex-turboquant/src/vector/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the
//! field-level validity records that those rows were not quantized.
//!
//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than the
//! struct validity, for example after a generic mask only updates the struct validity, but each
//! child must be valid wherever the struct row is valid.
//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than
//! the struct validity (for example after a generic mask only updates the struct validity), but
//! each child must be valid wherever the struct row is valid.

use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
Expand Down Expand Up @@ -48,14 +48,19 @@ pub(crate) const NORMS_FIELD: &str = "norms";
/// Name of the stored quantized-code child.
pub(crate) const CODES_FIELD: &str = "codes";

/// Parsed TurboQuant storage arrays.
///
/// We use this as a helper struct for working with a TurboQuant extension array.
/// Executed storage children of a TurboQuant extension array plus the authoritative outer
/// struct validity. Every child is row-aligned to `len` and every child's validity covers
/// `vector_validity`.
pub(crate) struct TurboQuantParsedStorage {
/// Metadata recovered from the input extension dtype.
pub(crate) metadata: TurboQuantMetadata,
/// Authoritative row validity for the quantized vectors, taken from the outer struct.
pub(crate) vector_validity: Validity,
/// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`.
pub(crate) norms: PrimitiveArray,
/// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long.
pub(crate) codes: PrimitiveArray,
/// Row count.
pub(crate) len: usize,
}

Expand Down
18 changes: 13 additions & 5 deletions vortex-turboquant/src/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ use crate::vector::storage::CODES_FIELD;
use crate::vector::storage::NORMS_FIELD;
use crate::vector::tq_padded_dim;

/// TurboQuant logical extension type.
/// TurboQuant logical extension type. Per-array configuration lives in [`TurboQuantMetadata`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct TurboQuant;

/// Serialized metadata for a TurboQuant extension array.
/// Serialized metadata for a TurboQuant extension array. The fields together suffice to
/// reconstruct the SORF transform and centroid codebook at decode time.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TurboQuantMetadata {
/// Original vector element type and stored norm type.
/// Original vector element ptype and stored row-norm ptype. Restricted to `f16` / `f32` /
/// `f64`.
pub element_ptype: PType,
/// Original vector dimension before SORF padding.
/// Original vector dimension before SORF padding to the next power of two.
pub dimensions: u32,
/// Bits per coordinate in the scalar quantizer codebook.
/// Bits per coordinate in the scalar quantizer codebook (`1..=8`).
pub bit_width: u8,
/// Seed used to derive the deterministic SORF transform.
pub seed: u64,
Expand Down Expand Up @@ -106,6 +108,8 @@ impl ExtVTable for TurboQuant {
}
}

/// Wire-format representation of [`TurboQuantMetadata`]. Field tags MUST NOT change once
/// shipped; new fields must use unused tags and remain optional.
#[derive(Clone, PartialEq, Message)]
struct TurboQuantMetadataProto {
#[prost(enumeration = "PType", tag = "1")]
Expand Down Expand Up @@ -158,6 +162,8 @@ pub(crate) fn tq_storage_dtype(
))
}

/// Validate [`TurboQuantMetadata`] invariants. Called on both serialize and deserialize so a
/// corrupted on-disk metadata block errors out rather than decoding into nonsense.
fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> {
vortex_ensure!(
metadata.dimensions >= MIN_DIMENSION,
Expand All @@ -175,6 +181,8 @@ fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> {
TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ())
}

/// Validate that `dtype` matches the storage shape produced by [`tq_storage_dtype`] for
/// `metadata`. Called from [`TurboQuant::validate_dtype`].
fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> {
let DType::Struct(fields, _) = dtype else {
vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}");
Expand Down
Loading