Skip to content

Commit a2ec5b4

Browse files
committed
add AnyTensor matcher and impl cosine similarity for vector
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 903eab9 commit a2ec5b4

4 files changed

Lines changed: 154 additions & 14 deletions

File tree

vortex-tensor/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
pub mod fixed_shape;
99
pub mod vector;
1010

11+
pub mod matcher;
1112
pub mod scalar_fns;

vortex-tensor/src/matcher.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Matcher for tensor-like extension types.
5+
6+
use vortex::dtype::extension::ExtDTypeRef;
7+
use vortex::dtype::extension::Matcher;
8+
9+
use crate::fixed_shape::FixedShapeTensor;
10+
use crate::fixed_shape::FixedShapeTensorMetadata;
11+
use crate::vector::Vector;
12+
13+
/// Matcher for any tensor-like extension type.
14+
///
15+
/// Currently the different kinds of tensors that are available are:
16+
///
17+
/// - `FixedShapeTensor`
18+
/// - `Vector`
19+
pub struct AnyTensor;
20+
21+
/// The matched variant of a tensor-like extension type.
22+
#[derive(Debug, PartialEq, Eq)]
23+
pub enum TensorMatch<'a> {
24+
/// A [`FixedShapeTensor`] extension type.
25+
FixedShapeTensor(&'a FixedShapeTensorMetadata),
26+
/// A [`Vector`] extension type.
27+
Vector,
28+
}
29+
30+
impl Matcher for AnyTensor {
31+
type Match<'a> = TensorMatch<'a>;
32+
33+
fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
34+
if let Some(metadata) = item.metadata_opt::<FixedShapeTensor>() {
35+
return Some(TensorMatch::FixedShapeTensor(metadata));
36+
}
37+
if item.metadata_opt::<Vector>().is_some() {
38+
return Some(TensorMatch::Vector);
39+
}
40+
None
41+
}
42+
}

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 110 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
//! Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor)
5-
//! arrays.
4+
//! Cosine similarity expression for tensor-like extension arrays
5+
//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
6+
//! [`Vector`](crate::vector::Vector)).
67
78
use std::fmt::Formatter;
89

@@ -19,6 +20,7 @@ use vortex::array::match_each_float_ptype;
1920
use vortex::dtype::DType;
2021
use vortex::dtype::NativePType;
2122
use vortex::dtype::Nullability;
23+
use vortex::dtype::extension::Matcher;
2224
use vortex::error::VortexResult;
2325
use vortex::error::vortex_bail;
2426
use vortex::error::vortex_ensure;
@@ -31,18 +33,21 @@ use vortex::scalar_fn::ExecutionArgs;
3133
use vortex::scalar_fn::ScalarFnId;
3234
use vortex::scalar_fn::ScalarFnVTable;
3335

36+
use crate::matcher::AnyTensor;
37+
3438
// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
3539
// encoded in spherical coordinates.
3640
/// Cosine similarity between two columns.
3741
///
38-
/// For [`FixedShapeTensor`], computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of
39-
/// each tensor. The shape and permutation do not affect the result because cosine similarity only
40-
/// depends on the element values, not their logical arrangement.
42+
/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
43+
/// The shape and permutation do not affect the result because cosine similarity only depends on the
44+
/// element values, not their logical arrangement.
4145
///
42-
/// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a
43-
/// float element type. The output is a float column of the same float type.
46+
/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the
47+
/// same dtype and a float element type. The output is a float column of the same float type.
4448
///
4549
/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor
50+
/// [`Vector`]: crate::vector::Vector
4651
#[derive(Clone)]
4752
pub struct CosineSimilarity;
4853

@@ -92,10 +97,14 @@ impl ScalarFnVTable for CosineSimilarity {
9297

9398
// We don't need to look at rhs anymore since we know lhs and rhs are equal.
9499

95-
// Both inputs must be extension types.
100+
// Both inputs must be tensor-like extension types.
96101
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
97102
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
98103
})?;
104+
vortex_ensure!(
105+
AnyTensor::matches(lhs_ext),
106+
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107+
);
99108

100109
// Extract the element dtype from the storage FixedSizeList.
101110
let element_dtype = lhs_ext
@@ -257,13 +266,15 @@ mod tests {
257266
use vortex::dtype::Nullability;
258267
use vortex::dtype::extension::ExtDType;
259268
use vortex::error::VortexResult;
269+
use vortex::extension::EmptyMetadata;
260270
use vortex::scalar::Scalar;
261271
use vortex::scalar_fn::EmptyOptions;
262272
use vortex::scalar_fn::ScalarFn;
263273

264274
use crate::fixed_shape::FixedShapeTensor;
265275
use crate::fixed_shape::FixedShapeTensorMetadata;
266276
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
277+
use crate::vector::Vector;
267278

268279
/// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
269280
///
@@ -459,4 +470,95 @@ mod tests {
459470
);
460471
Ok(())
461472
}
473+
474+
/// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
475+
fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> {
476+
let row_count = elements.len() / dim as usize;
477+
478+
let elems: ArrayRef = Buffer::copy_from(elements).into_array();
479+
let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count);
480+
481+
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
482+
483+
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
484+
}
485+
486+
#[test]
487+
fn vector_unit_vectors() -> VortexResult<()> {
488+
let lhs = vector_array(
489+
3,
490+
&[
491+
1.0, 0.0, 0.0, // vector 0
492+
0.0, 1.0, 0.0, // vector 1
493+
],
494+
)?;
495+
let rhs = vector_array(
496+
3,
497+
&[
498+
1.0, 0.0, 0.0, // vector 0
499+
1.0, 0.0, 0.0, // vector 1
500+
],
501+
)?;
502+
503+
// Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
504+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
505+
Ok(())
506+
}
507+
508+
#[test]
509+
fn vector_self_similarity() -> VortexResult<()> {
510+
let arr = vector_array(
511+
4,
512+
&[
513+
1.0, 2.0, 3.0, 4.0, // vector 0
514+
0.0, 1.0, 0.0, 0.0, // vector 1
515+
5.0, 0.0, 5.0, 0.0, // vector 2
516+
],
517+
)?;
518+
519+
assert_close(
520+
&eval_cosine_similarity(arr.clone(), arr, 3)?,
521+
&[1.0, 1.0, 1.0],
522+
);
523+
Ok(())
524+
}
525+
526+
/// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
527+
fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> {
528+
let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);
529+
530+
let children: Vec<Scalar> = elements
531+
.iter()
532+
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
533+
.collect();
534+
let storage_scalar =
535+
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
536+
537+
let storage = ConstantArray::new(storage_scalar, len).into_array();
538+
539+
let ext_dtype =
540+
ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
541+
542+
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
543+
}
544+
545+
#[test]
546+
fn vector_constant_query() -> VortexResult<()> {
547+
let data = vector_array(
548+
3,
549+
&[
550+
1.0, 0.0, 0.0, // vector 0
551+
0.0, 1.0, 0.0, // vector 1
552+
0.0, 0.0, 1.0, // vector 2
553+
1.0, 0.0, 0.0, // vector 3
554+
],
555+
)?;
556+
let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?;
557+
558+
assert_close(
559+
&eval_cosine_similarity(data, query, 4)?,
560+
&[1.0, 0.0, 0.0, 1.0],
561+
);
562+
Ok(())
563+
}
462564
}

vortex-tensor/src/vector/vtable.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@ impl ExtVTable for Vector {
2727
Ok(Vec::new())
2828
}
2929

30-
fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult<Self::Metadata> {
31-
vortex_ensure!(
32-
metadata.is_empty(),
33-
"Vector metadata must be empty, got {} bytes",
34-
metadata.len()
35-
);
30+
fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult<Self::Metadata> {
3631
Ok(EmptyMetadata)
3732
}
3833

0 commit comments

Comments
 (0)