Skip to content

Commit c3bd48c

Browse files
committed
change nullability semantics
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent e1a9726 commit c3bd48c

12 files changed

Lines changed: 499 additions & 143 deletions

File tree

vortex-turboquant/src/centroids.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
//!
1212
//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this
1313
//! distribution.
14+
//!
15+
//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from
16+
//! `(padded_dim, bit_width)` and cached process-locally.
17+
//!
18+
//! The centroid model follows the random-rotation marginal used by the TurboQuant paper. This
19+
//! encoder applies a SORF-style structured rotation instead of a dense random Gaussian or
20+
//! orthogonal matrix, so paper-level error bounds should not be treated as verified for this
21+
//! implementation without separate empirical validation.
1422
1523
use std::sync::LazyLock;
1624

vortex-turboquant/src/lib.rs

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,60 +25,36 @@
2525
//! [`turboquant_pack()`]: crate::turboquant_pack
2626
//! [`turboquant_unpack()`]: crate::turboquant_unpack
2727
//!
28-
//! The full packed tree is:
28+
//! The packed storage is a row-aligned extension tree:
2929
//!
3030
//! ```text
3131
//! Extension<TurboQuant>(
3232
//! Struct {
33-
//! norms: Primitive<element_ptype>,
34-
//! codes: FixedSizeList<Primitive<u8>, padded_dim>,
33+
//! norms: Primitive<element_ptype, row_validity>,
34+
//! codes: FixedSizeList<Primitive<u8>, padded_dim, row_validity>,
3535
//! }
3636
//! )
3737
//! ```
3838
//!
39-
//! Row validity is stored on the `StructArray`, preserving original vector nulls. The `norms` and
40-
//! `codes` children are non-nullable and may contain deterministic placeholder values for null
41-
//! rows. Centroids are not stored; they are deterministically derived from the padded dimension and
42-
//! bit width, and cached process-locally.
43-
//!
4439
//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized
4540
//! directions are not guaranteed to have unit norm after scalar quantization and inverse rotation.
4641
//!
47-
//! The TurboQuant paper analyzes a full random orthogonal rotation. The current Vortex
48-
//! implementation instead uses a fixed 3-round Walsh-Hadamard-based structured transform with
49-
//! random sign diagonals generated by Vortex's frozen local SplitMix64 stream. This is a practical
50-
//! approximation chosen for encode/decode efficiency, and should be understood as an
51-
//! implementation choice rather than the exact construction used in the paper's proofs.
52-
//!
53-
//! The current encoding is also intentionally MSE-only. It does not yet implement the paper's QJL
54-
//! residual correction for unbiased inner-product estimation, and it still uses internal
55-
//! power-of-2 padding rather than the block decomposition proposed in RFC 0033.
56-
//!
57-
//! # Theoretical error bounds
42+
//! # Source map
5843
//!
59-
//! For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1
60-
//! guarantees normalized MSE distortion:
44+
//! Implementation details are documented next to the code that owns them:
6145
//!
62-
//! > `E[||x - x_hat||² / ||x||²] <= (√3 · π / 2) / 4^b`
46+
//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level
47+
//! validity for null vectors.
48+
//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor
49+
//! crate's null-row zeroing helper.
50+
//! - `vector/quantize.rs`: SORF rotation, centroid lookup, and why invalid rows are skipped rather
51+
//! than quantized.
52+
//! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching.
53+
//! - `sorf/`: the Walsh-Hadamard-based structured rotation and the stable SplitMix64 sign stream.
6354
//!
64-
//! | Bits | MSE bound | Quality |
65-
//! |------|------------|-------------------|
66-
//! | 1 | 6.80e-01 | Poor |
67-
//! | 2 | 1.70e-01 | Usable for ANN |
68-
//! | 3 | 4.25e-02 | Good |
69-
//! | 4 | 1.06e-02 | Very good |
70-
//! | 5 | 2.66e-03 | Excellent |
71-
//! | 6 | 6.64e-04 | Near-lossless |
72-
//! | 7 | 1.66e-04 | Near-lossless |
73-
//! | 8 | 4.15e-05 | Near-lossless |
74-
//!
75-
//! # Storage notes
76-
//!
77-
//! Each vector is logically stored as `padded_dim` u8 quantized codes plus one stored norm in the
78-
//! vector's element float type. Non-power-of-2 dimensions are padded to the next power of 2 for
79-
//! the structured rotation, which affects the storage size. Physical compression of those child
80-
//! arrays is left to the normal Vortex compressor rather than implemented as a TurboQuant-specific
81-
//! compressor scheme.
55+
//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL
56+
//! residual correction for unbiased inner-product estimation, and it still uses internal
57+
//! power-of-2 padding rather than the block decomposition proposed in RFC 0033.
8258
8359
mod centroids;
8460
mod config;

vortex-turboquant/src/sorf/transform.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the
88
//! Fast Walsh-Hadamard Transform (FWHT).
99
//!
10+
//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf
11+
//!
1012
//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the
1113
//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round
1214
//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost
1315
//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix.
1416
//!
15-
//! Vortex defines those sign diagonals using a frozen local SplitMix64 stream rather than an
17+
//! This implementation defines those sign diagonals using a frozen local SplitMix64 stream rather
18+
//! than an
1619
//! external RNG crate. The contract is:
1720
//!
1821
//! - state is a single `u64` seed,
@@ -22,10 +25,13 @@
2225
//! - each generated `u64` contributes 64 signs in least-significant-bit-first order,
2326
//! - bit `1` means `+1` and bit `0` means `-1`.
2427
//!
25-
//! This makes SORF sign generation stable as a Vortex format contract even if external RNG
28+
//! This makes SORF sign generation stable as an extension format contract even if external RNG
2629
//! implementations change.
2730
//!
28-
//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf
31+
//! This transform is the crate's practical structured-rotation choice for TurboQuant. It is not
32+
//! the dense random Gaussian or orthogonal matrix used by some theoretical analyses, so theoretical
33+
//! bounds from those models need separate validation before being presented as implementation
34+
//! guarantees.
2935
//!
3036
//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2
3137
//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n)

vortex-turboquant/src/tests/malformed.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use rstest::rstest;
45
use vortex_array::IntoArray;
56
use vortex_array::VortexSessionExecute;
67
use vortex_array::arrays::ExtensionArray;
78
use vortex_array::arrays::FixedSizeListArray;
89
use vortex_array::arrays::PrimitiveArray;
910
use vortex_array::arrays::StructArray;
1011
use vortex_array::dtype::FieldNames;
12+
use vortex_array::dtype::Nullability;
1113
use vortex_array::dtype::PType;
1214
use vortex_array::validity::Validity;
1315
use vortex_buffer::Buffer;
@@ -17,6 +19,76 @@ use crate::TurboQuant;
1719
use crate::TurboQuantMetadata;
1820
use crate::turboquant_unpack;
1921

22+
fn validity_for_nullability(nullability: Nullability) -> Validity {
23+
match nullability {
24+
Nullability::NonNullable => Validity::NonNullable,
25+
Nullability::Nullable => Validity::AllValid,
26+
}
27+
}
28+
29+
#[rstest]
30+
#[case::nullable_norms_under_nonnullable_struct(
31+
Nullability::NonNullable,
32+
Nullability::Nullable,
33+
Nullability::NonNullable
34+
)]
35+
#[case::nullable_codes_under_nonnullable_struct(
36+
Nullability::NonNullable,
37+
Nullability::NonNullable,
38+
Nullability::Nullable
39+
)]
40+
#[case::nonnullable_norms_under_nullable_struct(
41+
Nullability::Nullable,
42+
Nullability::NonNullable,
43+
Nullability::Nullable
44+
)]
45+
#[case::nonnullable_codes_under_nullable_struct(
46+
Nullability::Nullable,
47+
Nullability::Nullable,
48+
Nullability::NonNullable
49+
)]
50+
fn unpack_rejects_row_nullability_mismatch(
51+
#[case] struct_nullability: Nullability,
52+
#[case] norms_nullability: Nullability,
53+
#[case] codes_nullability: Nullability,
54+
) {
55+
let session = test_session();
56+
let mut ctx = session.create_execution_ctx();
57+
let metadata = TurboQuantMetadata {
58+
element_ptype: PType::F32,
59+
dimensions: 128,
60+
bit_width: 1,
61+
seed: 42,
62+
num_rounds: 3,
63+
};
64+
let norms = PrimitiveArray::new::<f32>(
65+
Buffer::copy_from([1.0]),
66+
validity_for_nullability(norms_nullability),
67+
)
68+
.into_array();
69+
let codes = PrimitiveArray::new::<u8>(vec![0u8; 128], Validity::NonNullable);
70+
let codes = FixedSizeListArray::try_new(
71+
codes.into_array(),
72+
128,
73+
validity_for_nullability(codes_nullability),
74+
1,
75+
)
76+
.unwrap()
77+
.into_array();
78+
let storage = StructArray::try_new(
79+
FieldNames::from(["norms", "codes"]),
80+
vec![norms, codes],
81+
1,
82+
validity_for_nullability(struct_nullability),
83+
)
84+
.unwrap();
85+
let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())
86+
.unwrap()
87+
.into_array();
88+
89+
assert!(turboquant_unpack(tq, &mut ctx).is_err());
90+
}
91+
2092
#[test]
2193
#[should_panic(expected = "TurboQuant code exceeds centroid count")]
2294
fn unpack_panics_on_codes_outside_centroid_table() {

vortex-turboquant/src/tests/metadata.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
46
use prost::Message;
57
use rstest::rstest;
68
use vortex_array::dtype::DType;
@@ -11,10 +13,13 @@ use vortex_array::dtype::StructFields;
1113
use vortex_array::dtype::extension::ExtDType;
1214
use vortex_array::dtype::extension::ExtVTable;
1315
use vortex_error::VortexResult;
16+
use vortex_error::vortex_err;
1417

1518
use crate::TurboQuant;
1619
use crate::TurboQuantMetadata;
17-
use crate::vector::storage::tq_storage_dtype;
20+
use crate::vector::storage::CODES_FIELD;
21+
use crate::vector::storage::NORMS_FIELD;
22+
use crate::vector::tq_padded_dim;
1823

1924
#[derive(Clone, PartialEq, Message)]
2025
struct MetadataWire {
@@ -30,6 +35,28 @@ struct MetadataWire {
3035
num_rounds: u32,
3136
}
3237

38+
fn tq_storage_dtype(
39+
metadata: &TurboQuantMetadata,
40+
row_nullability: Nullability,
41+
) -> VortexResult<DType> {
42+
let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?)
43+
.map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?;
44+
Ok(DType::Struct(
45+
StructFields::new(
46+
FieldNames::from([NORMS_FIELD, CODES_FIELD]),
47+
vec![
48+
DType::Primitive(metadata.element_ptype, row_nullability),
49+
DType::FixedSizeList(
50+
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
51+
padded_dim,
52+
row_nullability,
53+
),
54+
],
55+
),
56+
row_nullability,
57+
))
58+
}
59+
3360
#[rstest]
3461
#[case::f16(PType::F16)]
3562
#[case::f32(PType::F32)]
@@ -94,7 +121,27 @@ fn dtype_validation_accepts_expected_storage() -> VortexResult<()> {
94121
num_rounds: 3,
95122
};
96123

97-
ExtDType::<TurboQuant>::try_new(metadata, tq_storage_dtype(&metadata)?)?;
124+
ExtDType::<TurboQuant>::try_new(
125+
metadata,
126+
tq_storage_dtype(&metadata, Nullability::Nullable)?,
127+
)?;
128+
Ok(())
129+
}
130+
131+
#[test]
132+
fn dtype_validation_accepts_nonnullable_storage() -> VortexResult<()> {
133+
let metadata = TurboQuantMetadata {
134+
element_ptype: PType::F32,
135+
dimensions: 129,
136+
bit_width: 2,
137+
seed: 42,
138+
num_rounds: 3,
139+
};
140+
141+
ExtDType::<TurboQuant>::try_new(
142+
metadata,
143+
tq_storage_dtype(&metadata, Nullability::NonNullable)?,
144+
)?;
98145
Ok(())
99146
}
100147

@@ -113,7 +160,7 @@ fn dtype_validation_rejects_malformed_storage() {
113160
vec![
114161
DType::Primitive(PType::F32, Nullability::Nullable),
115162
DType::FixedSizeList(
116-
DType::Primitive(PType::U8, Nullability::NonNullable).into(),
163+
DType::Primitive(PType::U8, Nullability::Nullable).into(),
117164
128,
118165
Nullability::NonNullable,
119166
),

0 commit comments

Comments
 (0)