Skip to content

Commit 6a3dd6e

Browse files
committed
add normalized vector type
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent ba90b10 commit 6a3dd6e

16 files changed

Lines changed: 1173 additions & 393 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,64 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::
250250

251251
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
252252

253+
pub mod vortex_tensor::normalized_vector
254+
255+
pub struct vortex_tensor::normalized_vector::AnyNormalizedVector
256+
257+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector
258+
259+
pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata
260+
261+
pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
262+
263+
pub struct vortex_tensor::normalized_vector::NormalizedVector
264+
265+
impl vortex_tensor::normalized_vector::NormalizedVector
266+
267+
pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
268+
269+
pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(storage: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
270+
271+
impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector
272+
273+
pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector
274+
275+
impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector
276+
277+
impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector
278+
279+
pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool
280+
281+
impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector
282+
283+
pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector
284+
285+
impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector
286+
287+
pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
288+
289+
impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector
290+
291+
pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
292+
293+
impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector
294+
295+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector
296+
297+
pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata
298+
299+
pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
300+
301+
pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
302+
303+
pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId
304+
305+
pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
306+
307+
pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
308+
309+
pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
310+
253311
pub mod vortex_tensor::scalar_fns
254312

255313
pub mod vortex_tensor::scalar_fns::cosine_similarity
@@ -382,8 +440,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382440

383441
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384442

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386-
387443
pub mod vortex_tensor::scalar_fns::l2_norm
388444

389445
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
@@ -574,7 +630,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
574630

575631
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
576632

577-
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
633+
pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(&self) -> bool
634+
635+
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult<Self>
578636

579637
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata
580638

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pub fn turboquant_encode(
101101
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
102102

103103
// SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally
104-
// bypass the strict normalized-row validation when reattaching the stored norms.
104+
// bypass the strict normalized-row and zero-row validation when reattaching the stored norms.
105105
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
106106
}
107107

vortex-tensor/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use vortex_array::session::ArraySessionExt;
1212
use vortex_session::VortexSession;
1313

1414
use crate::fixed_shape::FixedShapeTensor;
15+
use crate::normalized_vector::NormalizedVector;
1516
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
1617
use crate::scalar_fns::inner_product::InnerProduct;
1718
use crate::scalar_fns::l2_denorm::L2Denorm;
@@ -23,6 +24,7 @@ pub mod matcher;
2324
pub mod scalar_fns;
2425

2526
pub mod fixed_shape;
27+
pub mod normalized_vector;
2628
pub mod vector;
2729

2830
pub mod encodings;
@@ -41,6 +43,7 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P
4143
/// Initialize the Vortex tensor library with a Vortex session.
4244
pub fn initialize(session: &VortexSession) {
4345
session.dtypes().register(Vector);
46+
session.dtypes().register(NormalizedVector);
4447
session.dtypes().register(FixedShapeTensor);
4548

4649
let session_fns = session.scalar_fns();
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::dtype::DType;
5+
use vortex_array::dtype::extension::ExtDTypeRef;
6+
use vortex_array::dtype::extension::Matcher;
7+
use vortex_error::VortexExpect;
8+
use vortex_error::vortex_panic;
9+
10+
use crate::normalized_vector::NormalizedVector;
11+
use crate::vector::VectorMatcherMetadata;
12+
13+
/// Matcher that accepts only the [`NormalizedVector`] extension type.
14+
///
15+
/// Use this when a consumer must reject plain [`Vector`](crate::vector::Vector) inputs. Callers
16+
/// that can accept either should use [`AnyVector`](crate::vector::AnyVector) instead.
17+
pub struct AnyNormalizedVector;
18+
19+
impl Matcher for AnyNormalizedVector {
20+
type Match<'a> = VectorMatcherMetadata;
21+
22+
fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
23+
if !ext_dtype.is::<NormalizedVector>() {
24+
return None;
25+
}
26+
27+
let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
28+
vortex_panic!(
29+
"`NormalizedVector` type somehow did not have a `FixedSizeList` storage type"
30+
)
31+
};
32+
assert!(element_dtype.is_float(), "element dtype must be float");
33+
assert!(
34+
!element_dtype.is_nullable(),
35+
"element dtype must be non-nullable"
36+
);
37+
38+
let metadata = VectorMatcherMetadata::try_new(element_dtype.as_ptype(), *list_size, true)
39+
.vortex_expect("`NormalizedVector` type somehow did not have float elements");
40+
41+
Some(metadata)
42+
}
43+
}
44+
45+
#[cfg(test)]
46+
mod tests {
47+
use std::sync::Arc;
48+
49+
use vortex_array::dtype::DType;
50+
use vortex_array::dtype::Nullability;
51+
use vortex_array::dtype::PType;
52+
use vortex_array::dtype::extension::ExtDType;
53+
use vortex_array::extension::EmptyMetadata;
54+
use vortex_error::VortexResult;
55+
56+
use super::*;
57+
use crate::vector::AnyVector;
58+
use crate::vector::Vector;
59+
60+
fn storage_dtype(element_ptype: PType, dimensions: u32) -> DType {
61+
DType::FixedSizeList(
62+
Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)),
63+
dimensions,
64+
Nullability::NonNullable,
65+
)
66+
}
67+
68+
#[test]
69+
fn matches_normalized_vector_dtype() -> VortexResult<()> {
70+
let ext_dtype =
71+
ExtDType::<NormalizedVector>::try_new(EmptyMetadata, storage_dtype(PType::F32, 128))?
72+
.erased();
73+
74+
let metadata = ext_dtype.metadata::<AnyNormalizedVector>();
75+
assert_eq!(metadata.element_ptype(), PType::F32);
76+
assert_eq!(metadata.dimensions(), 128);
77+
assert!(metadata.is_normalized());
78+
Ok(())
79+
}
80+
81+
#[test]
82+
fn rejects_plain_vector() -> VortexResult<()> {
83+
let ext_dtype =
84+
ExtDType::<Vector>::try_new(EmptyMetadata, storage_dtype(PType::F32, 128))?.erased();
85+
86+
assert!(ext_dtype.metadata_opt::<AnyNormalizedVector>().is_none());
87+
Ok(())
88+
}
89+
90+
#[test]
91+
fn any_vector_matches_normalized_vector() -> VortexResult<()> {
92+
let ext_dtype =
93+
ExtDType::<NormalizedVector>::try_new(EmptyMetadata, storage_dtype(PType::F32, 128))?
94+
.erased();
95+
96+
let metadata = ext_dtype.metadata::<AnyVector>();
97+
assert_eq!(metadata.element_ptype(), PType::F32);
98+
assert_eq!(metadata.dimensions(), 128);
99+
assert!(metadata.is_normalized());
100+
Ok(())
101+
}
102+
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Normalized vector extension type: a refinement of [`Vector`](crate::vector::Vector) whose
5+
//! rows are guaranteed (or asserted, for lossy encodings) to have unit L2 norm.
6+
7+
use num_traits::ToPrimitive;
8+
use vortex_array::ArrayRef;
9+
use vortex_array::ExecutionCtx;
10+
use vortex_array::IntoArray;
11+
use vortex_array::arrays::ExtensionArray;
12+
use vortex_array::arrays::extension::ExtensionArrayExt;
13+
use vortex_array::dtype::PType;
14+
use vortex_array::extension::EmptyMetadata;
15+
use vortex_array::match_each_float_ptype;
16+
use vortex_error::VortexResult;
17+
use vortex_error::vortex_ensure;
18+
19+
use crate::utils::extract_flat_elements;
20+
use crate::utils::validate_tensor_float_input;
21+
22+
/// Refinement of [`Vector`](crate::vector::Vector) that asserts every valid row is L2-normalized
23+
/// (unit-norm) or the zero vector.
24+
///
25+
/// The storage shape is identical to [`Vector`](crate::vector::Vector): a `FixedSizeList<float,
26+
/// dim, nullability>` with non-nullable float elements. Downstream operators such as
27+
/// [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm),
28+
/// [`L2Norm`](crate::scalar_fns::l2_norm::L2Norm),
29+
/// [`InnerProduct`](crate::scalar_fns::inner_product::InnerProduct), and
30+
/// [`CosineSimilarity`](crate::scalar_fns::cosine_similarity::CosineSimilarity) short-circuit
31+
/// arithmetic when they see this refinement.
32+
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
33+
pub struct NormalizedVector;
34+
35+
impl NormalizedVector {
36+
/// Wraps `storage` as a [`NormalizedVector`] extension array after checking that every valid
37+
/// row is unit-norm or the zero vector.
38+
///
39+
/// # Errors
40+
///
41+
/// Returns an error if the extension dtype rejects `storage`, if `storage` is not a tensor
42+
/// with float elements, or if any valid row's L2 norm is not `1.0` (or `0.0`) within the
43+
/// tolerance implied by the element precision.
44+
pub fn try_new(storage: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
45+
let ext = ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, storage)?
46+
.into_array();
47+
validate_unit_norm_rows(&ext, ctx)?;
48+
Ok(ext)
49+
}
50+
51+
/// Wraps `storage` as a [`NormalizedVector`] extension array **without** validating that
52+
/// rows are unit-norm.
53+
///
54+
/// # Safety
55+
///
56+
/// Every valid row must be unit-norm or the zero vector. Lossy approximations (e.g.
57+
/// TurboQuant) deliberately relax this, but still treat the claim as authoritative
58+
/// downstream. Violating this does not cause memory unsafety but will produce silently
59+
/// incorrect results.
60+
///
61+
/// # Errors
62+
///
63+
/// Returns an error if the extension dtype rejects `storage` (e.g. non-FSL storage, wrong
64+
/// element dtype, or nullable elements).
65+
pub unsafe fn new_unchecked(storage: ArrayRef) -> VortexResult<ArrayRef> {
66+
Ok(
67+
ExtensionArray::try_new_from_vtable(NormalizedVector, EmptyMetadata, storage)?
68+
.into_array(),
69+
)
70+
}
71+
}
72+
73+
/// Returns the acceptable unit-norm drift for the given element precision.
74+
pub(crate) fn unit_norm_tolerance(element_ptype: PType) -> f64 {
75+
match element_ptype {
76+
PType::F16 => 2e-3,
77+
PType::F32 => 2e-6,
78+
PType::F64 => 1e-10,
79+
_ => unreachable!("NormalizedVector requires float elements, got {element_ptype:?}"),
80+
}
81+
}
82+
83+
/// Validates that every valid row of a [`NormalizedVector`] extension array has L2 norm `1.0`
84+
/// or `0.0` within the element-precision tolerance.
85+
fn validate_unit_norm_rows(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
86+
let row_count = array.len();
87+
if row_count == 0 {
88+
return Ok(());
89+
}
90+
91+
let tensor_match = validate_tensor_float_input(array.dtype())?;
92+
let element_ptype = tensor_match.element_ptype();
93+
let tolerance = unit_norm_tolerance(element_ptype);
94+
let tensor_flat_size = tensor_match.list_size() as usize;
95+
96+
let ext: ExtensionArray = array.clone().execute(ctx)?;
97+
let validity = ext.as_ref().validity()?;
98+
let flat = extract_flat_elements(ext.storage_array(), tensor_flat_size, ctx)?;
99+
100+
match_each_float_ptype!(element_ptype, |T| {
101+
for i in 0..row_count {
102+
if !validity.is_valid(i)? {
103+
continue;
104+
}
105+
106+
let row_norm_sq = flat.row::<T>(i).iter().fold(0.0f64, |sum_sq, x| {
107+
let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN);
108+
sum_sq + value * value
109+
});
110+
let row_norm = row_norm_sq.sqrt();
111+
112+
vortex_ensure!(
113+
row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance,
114+
"NormalizedVector row {i} has L2 norm {row_norm:.6}, expected 1.0 or 0.0",
115+
);
116+
}
117+
});
118+
119+
Ok(())
120+
}
121+
122+
mod matcher;
123+
mod vtable;
124+
125+
pub use matcher::AnyNormalizedVector;

0 commit comments

Comments
 (0)