Skip to content

Commit 32ea0f0

Browse files
committed
add normalized vector type
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 3f6f342 commit 32ea0f0

17 files changed

Lines changed: 1174 additions & 393 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 61 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -250,6 +250,64 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::
250250

251251
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
252252

253+
pub mod vortex_tensor::normalized_vector
254+
255+
pub struct vortex_tensor::normalized_vector::AnyNormalizedVector
256+
257+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector
258+
259+
pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata
260+
261+
pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
262+
263+
pub struct vortex_tensor::normalized_vector::NormalizedVector
264+
265+
impl vortex_tensor::normalized_vector::NormalizedVector
266+
267+
pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
268+
269+
pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(storage: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
270+
271+
impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector
272+
273+
pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector
274+
275+
impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector
276+
277+
impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector
278+
279+
pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool
280+
281+
impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector
282+
283+
pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector
284+
285+
impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector
286+
287+
pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
288+
289+
impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector
290+
291+
pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
292+
293+
impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector
294+
295+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector
296+
297+
pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata
298+
299+
pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
300+
301+
pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
302+
303+
pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId
304+
305+
pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
306+
307+
pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
308+
309+
pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
310+
253311
pub mod vortex_tensor::scalar_fns
254312

255313
pub mod vortex_tensor::scalar_fns::cosine_similarity
@@ -382,8 +440,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382440

383441
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384442

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386-
387443
pub mod vortex_tensor::scalar_fns::l2_norm
388444

389445
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
@@ -574,7 +630,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
574630

575631
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
576632

577-
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
633+
pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(&self) -> bool
634+
635+
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult<Self>
578636

579637
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata
580638

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ pub fn turboquant_encode(
101101
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
102102

103103
// SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally
104-
// bypass the strict normalized-row validation when reattaching the stored norms.
104+
// bypass the strict normalized-row and zero-row validation when reattaching the stored norms.
105105
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
106106
}
107107

vortex-tensor/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ use crate::scalar_fns::l2_denorm::L2Denorm;
1717
use crate::scalar_fns::l2_norm::L2Norm;
1818
use crate::scalar_fns::sorf_transform::SorfTransform;
1919
use crate::types::fixed_shape::FixedShapeTensor;
20+
use crate::types::normalized_vector::NormalizedVector;
2021
use crate::types::vector::Vector;
2122

2223
pub mod matcher;
@@ -25,6 +26,7 @@ pub mod scalar_fns;
2526
mod types;
2627

2728
pub use types::fixed_shape;
29+
pub use types::normalized_vector;
2830
pub use types::vector;
2931

3032
pub mod encodings;
@@ -43,6 +45,7 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P
4345
/// Initialize the Vortex tensor library with a Vortex session.
4446
pub fn initialize(session: &VortexSession) {
4547
session.dtypes().register(Vector);
48+
session.dtypes().register(NormalizedVector);
4649
session.dtypes().register(FixedShapeTensor);
4750

4851
let session_fns = session.scalar_fns();

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 83 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -35,10 +35,9 @@ use vortex_session::VortexSession;
3535

3636
use crate::scalar_fns::inner_product::BinaryTensorOpMetadata;
3737
use crate::scalar_fns::inner_product::InnerProduct;
38-
use crate::scalar_fns::l2_denorm::DenormOrientation;
38+
use crate::scalar_fns::l2_denorm::NormalForm;
3939
use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm;
4040
use crate::scalar_fns::l2_norm::L2Norm;
41-
use crate::utils::extract_l2_denorm_children;
4241
use crate::utils::validate_binary_tensor_float_inputs;
4342

4443
/// Cosine similarity between two columns.
@@ -141,15 +140,21 @@ impl ScalarFnVTable for CosineSimilarity {
141140
rhs_ref = sfn.into_array();
142141
}
143142

144-
// Take any L2Denorm-wrapped fast path that applies.
145-
match DenormOrientation::classify(&lhs_ref, &rhs_ref) {
146-
DenormOrientation::Both { lhs, rhs } => {
147-
return self.execute_both_denorm(lhs, rhs, len);
143+
// Classify each operand by its normal form. When both operands carry a known unit-norm
144+
// representation, cosine similarity collapses to the dot product of the unit vectors.
145+
let lhs_form = NormalForm::classify(&lhs_ref);
146+
let rhs_form = NormalForm::classify(&rhs_ref);
147+
match (lhs_form.unit_array(), rhs_form.unit_array()) {
148+
(Some(unit_lhs), Some(unit_rhs)) => {
149+
return self.execute_both_unit(unit_lhs, unit_rhs, &lhs_ref, &rhs_ref, len);
148150
}
149-
DenormOrientation::One { denorm, plain } => {
150-
return self.execute_one_denorm(denorm, plain, len, ctx);
151+
(Some(unit_lhs), None) => {
152+
return self.execute_one_unit(unit_lhs, &rhs_ref, &lhs_ref, len, ctx);
151153
}
152-
DenormOrientation::Neither => {}
154+
(None, Some(unit_rhs)) => {
155+
return self.execute_one_unit(unit_rhs, &lhs_ref, &rhs_ref, len, ctx);
156+
}
157+
(None, None) => {}
153158
}
154159

155160
// Compute combined validity.
@@ -242,22 +247,20 @@ impl ScalarFnArrayVTable for CosineSimilarity {
242247
}
243248

244249
impl CosineSimilarity {
245-
/// Both sides are `L2Denorm`: treat the normalized children as authoritative, so
246-
/// `cosine_similarity = dot(n_l, n_r)`.
247-
fn execute_both_denorm(
250+
/// Both sides carry a known unit-norm representation: cosine similarity collapses to the
251+
/// dot product of the unit children.
252+
fn execute_both_unit(
248253
&self,
254+
unit_lhs: &ArrayRef,
255+
unit_rhs: &ArrayRef,
249256
lhs_ref: &ArrayRef,
250257
rhs_ref: &ArrayRef,
251258
len: usize,
252259
) -> VortexResult<ArrayRef> {
253260
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
254261

255-
let (normalized_l, _) = extract_l2_denorm_children(lhs_ref);
256-
let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
257-
258-
// `L2Denorm` makes the normalized children authoritative, so their dot product is the
259-
// cosine similarity even for lossy storage wrappers.
260-
let dot = InnerProduct::try_new_array(normalized_l, normalized_r, len)?.into_array();
262+
let dot =
263+
InnerProduct::try_new_array(unit_lhs.clone(), unit_rhs.clone(), len)?.into_array();
261264

262265
if !matches!(validity, Validity::NonNullable) {
263266
// Masking always changes the nullability to nullable.
@@ -267,22 +270,21 @@ impl CosineSimilarity {
267270
}
268271
}
269272

270-
/// One side is `L2Denorm`: treat the normalized child as authoritative, so
271-
/// `cosine_similarity = dot(n, b) / ||b||`.
272-
///
273-
/// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
274-
fn execute_one_denorm(
273+
/// Exactly one side carries a unit-norm representation: cosine similarity reduces to
274+
/// `dot(unit, other) / ||other||`. The norms of the unit side are implicitly `1.0` (naked
275+
/// `NormalizedVector`) or stored separately (the outer `L2Denorm` wrapper, which is not
276+
/// needed here since cosine ignores magnitude).
277+
fn execute_one_unit(
275278
&self,
276-
denorm_ref: &ArrayRef,
279+
unit: &ArrayRef,
277280
plain_ref: &ArrayRef,
281+
unit_ref: &ArrayRef,
278282
len: usize,
279283
ctx: &mut ExecutionCtx,
280284
) -> VortexResult<ArrayRef> {
281-
let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
285+
let validity = unit_ref.validity()?.and(plain_ref.validity()?)?;
282286

283-
let (normalized, _) = extract_l2_denorm_children(denorm_ref);
284-
285-
let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?;
287+
let dot_arr = InnerProduct::try_new_array(unit.clone(), plain_ref.clone(), len)?;
286288
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
287289

288290
let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?;
@@ -331,6 +333,7 @@ mod tests {
331333
use crate::utils::test_helpers::assert_close;
332334
use crate::utils::test_helpers::constant_tensor_array;
333335
use crate::utils::test_helpers::l2_denorm_array;
336+
use crate::utils::test_helpers::normalized_vector_array;
334337
use crate::utils::test_helpers::tensor_array;
335338
use crate::utils::test_helpers::vector_array;
336339

@@ -519,13 +522,25 @@ mod tests {
519522
Ok(())
520523
}
521524

525+
/// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands take the
526+
/// fast path: cosine similarity collapses to the dot product without computing norms.
527+
#[test]
528+
fn naked_normalized_vector_cosine() -> VortexResult<()> {
529+
let mut ctx = SESSION.create_execution_ctx();
530+
let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?;
531+
let rhs = normalized_vector_array(2, &[0.6, 0.8, 0.0, 1.0], &mut ctx)?;
532+
// Row 0: identical -> 1.0, Row 1: orthogonal -> 0.0.
533+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
534+
Ok(())
535+
}
536+
522537
#[test]
523538
fn both_denorm_self_similarity() -> VortexResult<()> {
524539
// [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
525540
// [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
526541
let mut ctx = SESSION.create_execution_ctx();
527-
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
528-
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
542+
let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
543+
let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
529544

530545
// Self-similarity should always be 1.0.
531546
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]);
@@ -537,8 +552,8 @@ mod tests {
537552
// [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
538553
// [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
539554
let mut ctx = SESSION.create_execution_ctx();
540-
let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?;
541-
let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?;
555+
let lhs = l2_denorm_array(2, &[1.0, 0.0], &[3.0], &mut ctx)?;
556+
let rhs = l2_denorm_array(2, &[0.0, 1.0], &[4.0], &mut ctx)?;
542557

543558
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
544559
Ok(())
@@ -548,8 +563,8 @@ mod tests {
548563
fn both_denorm_zero_norm() -> VortexResult<()> {
549564
// Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
550565
let mut ctx = SESSION.create_execution_ctx();
551-
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?;
552-
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
566+
let lhs = l2_denorm_array(2, &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?;
567+
let rhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
553568

554569
// Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0.
555570
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
@@ -562,8 +577,8 @@ mod tests {
562577
// RHS is plain [3.0, 4.0].
563578
// cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
564579
let mut ctx = SESSION.create_execution_ctx();
565-
let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?;
566-
let rhs = tensor_array(&[2], &[3.0, 4.0])?;
580+
let lhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?;
581+
let rhs = vector_array(2, &[3.0, 4.0])?;
567582

568583
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
569584
Ok(())
@@ -574,8 +589,8 @@ mod tests {
574589
// LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
575590
// cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
576591
let mut ctx = SESSION.create_execution_ctx();
577-
let lhs = tensor_array(&[2], &[1.0, 0.0])?;
578-
let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?;
592+
let lhs = vector_array(2, &[1.0, 0.0])?;
593+
let rhs = l2_denorm_array(2, &[0.6, 0.8], &[5.0], &mut ctx)?;
579594

580595
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]);
581596
Ok(())
@@ -585,9 +600,9 @@ mod tests {
585600
fn both_denorm_null_norms() -> VortexResult<()> {
586601
// Row 0: valid, row 1: null (via nullable norms on rhs).
587602
let mut ctx = SESSION.create_execution_ctx();
588-
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
603+
let lhs = l2_denorm_array(2, &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?;
589604

590-
let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
605+
let normalized_r = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?;
591606
let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
592607
let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array();
593608

@@ -703,6 +718,34 @@ mod tests {
703718
Ok(())
704719
}
705720

721+
#[test]
722+
fn serde_round_trip_mixed_vector_and_normalized_vector() -> VortexResult<()> {
723+
let mut ctx = SESSION.create_execution_ctx();
724+
let lhs = normalized_vector_array(2, &[0.6, 0.8, 1.0, 0.0], &mut ctx)?;
725+
let rhs = vector_array(2, &[3.0, 4.0, 0.0, 1.0])?;
726+
let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), 2)?.into_array();
727+
728+
let plugin = ScalarFnArrayPlugin::new(CosineSimilarity);
729+
let metadata = plugin
730+
.serialize(&original, &SESSION)?
731+
.expect("CosineSimilarity serialize must produce metadata");
732+
733+
let children = vec![lhs, rhs];
734+
let recovered = plugin.deserialize(
735+
original.dtype(),
736+
original.len(),
737+
&metadata,
738+
&[],
739+
&children,
740+
&SESSION,
741+
)?;
742+
743+
assert_eq!(recovered.dtype(), original.dtype());
744+
assert_eq!(recovered.len(), original.len());
745+
assert_eq!(recovered.encoding_id(), original.encoding_id());
746+
Ok(())
747+
}
748+
706749
#[rstest]
707750
#[case::vector(
708751
vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(),

0 commit comments

Comments
 (0)