Skip to content

Commit 849d6a5

Browse files
committed
fix cosine similarity and dot product
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 3d7dfed commit 849d6a5

4 files changed

Lines changed: 198 additions & 59 deletions

File tree

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 70 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -27,60 +27,56 @@
2727
//! distortion: at 4 bits the error is typically < 0.1, at 8 bits < 0.001.
2828
//!
2929
//! For approximate nearest neighbor (ANN) search, biased-but-accurate ranking is
30-
//! usually sufficient the relative ordering of cosine similarities is preserved
30+
//! usually sufficient -- the relative ordering of cosine similarities is preserved
3131
//! even if the absolute values have bounded error.
3232
33+
use num_traits::FromPrimitive;
34+
use num_traits::Zero;
3335
use vortex_array::ArrayRef;
3436
use vortex_array::ArrayView;
3537
use vortex_array::ExecutionCtx;
3638
use vortex_array::IntoArray;
3739
use vortex_array::arrays::FixedSizeListArray;
3840
use vortex_array::arrays::PrimitiveArray;
41+
use vortex_array::match_each_float_ptype;
3942
use vortex_array::validity::Validity;
4043
use vortex_buffer::BufferMut;
4144
use vortex_error::VortexResult;
42-
use vortex_error::vortex_ensure;
45+
use vortex_error::vortex_ensure_eq;
4346

4447
use crate::encodings::turboquant::TurboQuant;
48+
use crate::utils::extension_element_ptype;
4549

46-
/// Shared helper: read codes, norms, and centroids from two TurboQuant arrays,
47-
/// then compute per-row quantized unit-norm dot products.
50+
/// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails.
4851
///
49-
/// Both arrays must have the same dimension (vector length) and row count.
50-
/// They may have different codebooks (e.g., different bit widths), in which
51-
/// case each array's own centroids are used for its code lookups.
52+
/// This helper exists because `half::f16` has an inherent `from_f32` method that shadows
53+
/// the [`FromPrimitive`] trait method, causing compilation errors when used inside
54+
/// [`match_each_float_ptype!`].
55+
#[inline]
56+
fn f32_to_t<T: FromPrimitive + Zero>(v: f32) -> T {
57+
FromPrimitive::from_f32(v).unwrap_or_else(T::zero)
58+
}
59+
60+
/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
5261
///
53-
/// Returns `(norms_a, norms_b, unit_dots)` where `unit_dots[i]` is the dot product
54-
/// of the unit-norm quantized vectors for row i.
55-
fn quantized_unit_dots(
56-
lhs: ArrayView<TurboQuant>,
57-
rhs: ArrayView<TurboQuant>,
62+
/// Returns a `Vec<f32>` of length `num_rows`.
63+
fn compute_unit_dots(
64+
lhs: &ArrayView<TurboQuant>,
65+
rhs: &ArrayView<TurboQuant>,
5866
ctx: &mut ExecutionCtx,
59-
) -> VortexResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
60-
vortex_ensure!(
61-
lhs.dimension() == rhs.dimension(),
62-
"TurboQuant quantized dot product requires matching dimensions, got {} and {}",
63-
lhs.dimension(),
64-
rhs.dimension()
65-
);
66-
67+
) -> VortexResult<Vec<f32>> {
6768
let pd = lhs.padded_dim() as usize;
6869
let num_rows = lhs.norms().len();
6970

70-
let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?;
71-
let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?;
72-
let na = lhs_norms.as_slice::<f32>();
73-
let nb = rhs_norms.as_slice::<f32>();
74-
7571
let lhs_codes_fsl: FixedSizeListArray = lhs.codes().clone().execute(ctx)?;
7672
let rhs_codes_fsl: FixedSizeListArray = rhs.codes().clone().execute(ctx)?;
7773
let lhs_codes = lhs_codes_fsl.elements().to_canonical()?.into_primitive();
7874
let rhs_codes = rhs_codes_fsl.elements().to_canonical()?.into_primitive();
7975
let ca = lhs_codes.as_slice::<u8>();
8076
let cb = rhs_codes.as_slice::<u8>();
8177

82-
// Read centroids from both arrays — they may have different codebooks
83-
// (e.g., different bit widths).
78+
// Read centroids from both arrays. They may have different codebooks (e.g., different bit
79+
// widths).
8480
let lhs_centroids: PrimitiveArray = lhs.centroids().clone().execute(ctx)?;
8581
let rhs_centroids: PrimitiveArray = rhs.centroids().clone().execute(ctx)?;
8682
let cl = lhs_centroids.as_slice::<f32>();
@@ -98,49 +94,75 @@ fn quantized_unit_dots(
9894
dots.push(dot);
9995
}
10096

101-
Ok((na.to_vec(), nb.to_vec(), dots))
97+
Ok(dots)
10298
}
10399

104100
/// Compute approximate cosine similarity for all rows between two TurboQuant
105101
/// arrays (same rotation matrix and codebook) without full decompression.
102+
///
103+
/// Since TurboQuant stores unit-normalized rotated vectors, the dot product of the quantized
104+
/// codes directly approximates cosine similarity without needing the stored norms.
105+
///
106+
/// The output dtype matches the Vector's element type (f16, f32, or f64).
106107
pub fn cosine_similarity_quantized_column(
107108
lhs: ArrayView<TurboQuant>,
108109
rhs: ArrayView<TurboQuant>,
109110
ctx: &mut ExecutionCtx,
110111
) -> VortexResult<ArrayRef> {
111-
let num_rows = lhs.norms().len();
112-
let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?;
112+
vortex_ensure_eq!(
113+
lhs.dimension(),
114+
rhs.dimension(),
115+
"TurboQuant quantized dot product requires matching dimensions",
116+
);
113117

114-
let mut result = BufferMut::<f32>::with_capacity(num_rows);
115-
for row in 0..num_rows {
116-
if na[row] == 0.0 || nb[row] == 0.0 {
117-
result.push(0.0);
118-
} else {
119-
// Unit-norm dot product IS the cosine similarity.
120-
result.push(dots[row]);
121-
}
122-
}
118+
let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
119+
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
123120

124-
Ok(PrimitiveArray::new::<f32>(result.freeze(), Validity::NonNullable).into_array())
121+
// The unit-norm dot product IS the cosine similarity. Cast from f32 to the native type.
122+
match_each_float_ptype!(element_ptype, |T| {
123+
let mut result = BufferMut::<T>::with_capacity(dots.len());
124+
for &dot in &dots {
125+
result.push(f32_to_t(dot));
126+
}
127+
Ok(PrimitiveArray::new::<T>(result.freeze(), Validity::NonNullable).into_array())
128+
})
125129
}
126130

127131
/// Compute approximate dot product for all rows between two TurboQuant
128132
/// arrays (same rotation matrix and codebook) without full decompression.
129133
///
130-
/// `dot_product(a, b) ≈ ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])`
134+
/// `dot_product(a, b) = ||a|| * ||b|| * sum(c[code_a[j]] * c[code_b[j]])`
135+
///
136+
/// The output dtype matches the Vector's element type (f16, f32, or f64).
131137
pub fn dot_product_quantized_column(
132138
lhs: ArrayView<TurboQuant>,
133139
rhs: ArrayView<TurboQuant>,
134140
ctx: &mut ExecutionCtx,
135141
) -> VortexResult<ArrayRef> {
142+
vortex_ensure_eq!(
143+
lhs.dimension(),
144+
rhs.dimension(),
145+
"TurboQuant quantized dot product requires matching dimensions",
146+
);
147+
148+
let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
149+
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
136150
let num_rows = lhs.norms().len();
137-
let (na, nb, dots) = quantized_unit_dots(lhs, rhs, ctx)?;
138151

139-
let mut result = BufferMut::<f32>::with_capacity(num_rows);
140-
for row in 0..num_rows {
141-
// Scale the unit-norm dot product by both norms to get the actual dot product.
142-
result.push(na[row] * nb[row] * dots[row]);
143-
}
152+
let lhs_norms: PrimitiveArray = lhs.norms().clone().execute(ctx)?;
153+
let rhs_norms: PrimitiveArray = rhs.norms().clone().execute(ctx)?;
154+
155+
// Scale the f32 unit-norm dot product by native-precision norms.
156+
match_each_float_ptype!(element_ptype, |T| {
157+
let na = lhs_norms.as_slice::<T>();
158+
let nb = rhs_norms.as_slice::<T>();
159+
160+
let mut result = BufferMut::<T>::with_capacity(num_rows);
161+
for row in 0..num_rows {
162+
let dot_t: T = f32_to_t(dots[row]);
163+
result.push(na[row] * nb[row] * dot_t);
164+
}
144165

145-
Ok(PrimitiveArray::new::<f32>(result.freeze(), Validity::NonNullable).into_array())
166+
Ok(PrimitiveArray::new::<T>(result.freeze(), Validity::NonNullable).into_array())
167+
})
146168
}

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -90,14 +90,6 @@
9090
//! assert!(encoded.nbytes() < 51200);
9191
//! ```
9292
93-
use vortex_array::session::ArraySessionExt;
94-
use vortex_session::VortexSession;
95-
96-
/// Initialize the TurboQuant encoding in the given session.
97-
pub fn initialize(session: &mut VortexSession) {
98-
session.arrays().register(TurboQuant);
99-
}
100-
10193
mod array;
10294
pub use array::data::TurboQuantData;
10395
pub use array::scheme::TurboQuantScheme;

vortex-tensor/src/encodings/turboquant/tests.rs

Lines changed: 111 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -798,3 +798,114 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> {
798798
}
799799
Ok(())
800800
}
801+
802+
// -----------------------------------------------------------------------
803+
// Serde roundtrip tests
804+
// -----------------------------------------------------------------------
805+
806+
/// Verify that a TurboQuant array survives serialize/deserialize.
807+
#[test]
808+
fn serde_roundtrip() -> VortexResult<()> {
809+
use vortex_array::ArrayContext;
810+
use vortex_array::ArrayEq;
811+
use vortex_array::Precision;
812+
use vortex_array::serde::SerializeOptions;
813+
use vortex_array::serde::SerializedArray;
814+
use vortex_array::session::ArraySessionExt;
815+
use vortex_buffer::ByteBufferMut;
816+
use vortex_fastlanes::BitPacked;
817+
use vortex_session::registry::ReadContext;
818+
819+
let fsl = make_fsl(20, 128, 42);
820+
let ext = make_vector_ext(&fsl);
821+
let config = TurboQuantConfig {
822+
bit_width: 3,
823+
seed: Some(123),
824+
};
825+
let mut ctx = SESSION.create_execution_ctx();
826+
let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
827+
828+
let dtype = encoded.dtype().clone();
829+
let len = encoded.len();
830+
831+
// Serialize.
832+
let array_ctx = ArrayContext::empty();
833+
let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?;
834+
835+
let mut concat = ByteBufferMut::empty();
836+
for buf in serialized {
837+
concat.extend_from_slice(buf.as_ref());
838+
}
839+
840+
// Deserialize. The session needs TurboQuant and BitPacked (for rotation signs) registered.
841+
let serde_session = VortexSession::empty().with::<ArraySession>();
842+
serde_session.arrays().register(TurboQuant);
843+
serde_session.arrays().register(BitPacked);
844+
845+
let parts = SerializedArray::try_from(concat.freeze())?;
846+
let decoded = parts.decode(
847+
&dtype,
848+
len,
849+
&ReadContext::new(array_ctx.to_ids()),
850+
&serde_session,
851+
)?;
852+
853+
assert!(
854+
decoded.array_eq(&encoded, Precision::Value),
855+
"serde roundtrip did not preserve array equality"
856+
);
857+
Ok(())
858+
}
859+
860+
/// Verify that a degenerate (empty) TurboQuant array survives serialize/deserialize.
861+
#[test]
862+
fn serde_roundtrip_empty() -> VortexResult<()> {
863+
use vortex_array::ArrayContext;
864+
use vortex_array::ArrayEq;
865+
use vortex_array::Precision;
866+
use vortex_array::serde::SerializeOptions;
867+
use vortex_array::serde::SerializedArray;
868+
use vortex_array::session::ArraySessionExt;
869+
use vortex_buffer::ByteBufferMut;
870+
use vortex_fastlanes::BitPacked;
871+
use vortex_session::registry::ReadContext;
872+
873+
let fsl = make_fsl(0, 128, 42);
874+
let ext = make_vector_ext(&fsl);
875+
let config = TurboQuantConfig {
876+
bit_width: 2,
877+
seed: Some(123),
878+
};
879+
let mut ctx = SESSION.create_execution_ctx();
880+
let encoded = turboquant_encode(&ext, &config, &mut ctx)?;
881+
assert_eq!(encoded.len(), 0);
882+
883+
let dtype = encoded.dtype().clone();
884+
let len = encoded.len();
885+
886+
let array_ctx = ArrayContext::empty();
887+
let serialized = encoded.serialize(&array_ctx, &SerializeOptions::default())?;
888+
889+
let mut concat = ByteBufferMut::empty();
890+
for buf in serialized {
891+
concat.extend_from_slice(buf.as_ref());
892+
}
893+
894+
let serde_session = VortexSession::empty().with::<ArraySession>();
895+
serde_session.arrays().register(TurboQuant);
896+
serde_session.arrays().register(BitPacked);
897+
898+
let parts = SerializedArray::try_from(concat.freeze())?;
899+
let decoded = parts.decode(
900+
&dtype,
901+
len,
902+
&ReadContext::new(array_ctx.to_ids()),
903+
&serde_session,
904+
)?;
905+
906+
assert!(
907+
decoded.array_eq(&encoded, Precision::Value),
908+
"serde roundtrip did not preserve array equality"
909+
);
910+
Ok(())
911+
}

vortex-tensor/src/encodings/turboquant/vtable.rs

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ impl TurboQuant {
7474

7575
/// Creates a new [`TurboQuantArray`].
7676
///
77-
/// Internallay calls [`TurboQuantData::try_new`].
77+
/// Internally calls [`TurboQuantData::try_new`].
7878
pub fn try_new_array(
7979
dtype: DType,
8080
codes: ArrayRef,
@@ -101,7 +101,7 @@ impl VTable for TurboQuant {
101101
Self::ID
102102
}
103103

104-
fn validate(&self, data: &Self::ArrayData, dtype: &DType, _len: usize) -> VortexResult<()> {
104+
fn validate(&self, data: &Self::ArrayData, dtype: &DType, len: usize) -> VortexResult<()> {
105105
let ext = dtype
106106
.as_extension_opt()
107107
.filter(|e| e.is::<Vector>())
@@ -117,8 +117,15 @@ impl VTable for TurboQuant {
117117

118118
vortex_ensure_eq!(data.dimension(), dimension);
119119

120-
// TODO(connor): In the future, we will not need to validate `len` on the array data because
120+
// TODO(connor): In the future, we may not need to validate `len` on the array data because
121121
// the child arrays will be located somewhere else.
122+
// bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with
123+
// bit_width == 0 would have zero centroids while codes reference centroid indices.
124+
vortex_ensure!(
125+
data.bit_width > 0 || len == 0,
126+
"bit_width == 0 is only valid for empty arrays, got len={len}"
127+
);
128+
122129
Ok(())
123130
}
124131

@@ -187,6 +194,13 @@ impl VTable for TurboQuant {
187194

188195
let bit_width = metadata[0];
189196

197+
// bit_width == 0 is only valid for degenerate (empty) arrays. A non-empty array with
198+
// bit_width == 0 would have zero centroids while codes reference centroid indices.
199+
vortex_ensure!(
200+
bit_width > 0 || len == 0,
201+
"bit_width == 0 is only valid for empty arrays, got len={len}"
202+
);
203+
190204
// Validate and derive dimension and element ptype from the Vector extension dtype.
191205
let ext = TurboQuant::validate_dtype(dtype)?;
192206
let dimension = extension_list_size(ext)?;

0 commit comments

Comments
 (0)