11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- //! Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor)
5- //! arrays.
4+ //! Cosine similarity expression for tensor-like extension arrays
5+ //! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
6+ //! [`Vector`](crate::vector::Vector)).
67
78use std:: fmt:: Formatter ;
89
@@ -19,6 +20,7 @@ use vortex::array::match_each_float_ptype;
1920use vortex:: dtype:: DType ;
2021use vortex:: dtype:: NativePType ;
2122use vortex:: dtype:: Nullability ;
23+ use vortex:: dtype:: extension:: Matcher ;
2224use vortex:: error:: VortexResult ;
2325use vortex:: error:: vortex_bail;
2426use vortex:: error:: vortex_ensure;
@@ -31,18 +33,21 @@ use vortex::scalar_fn::ExecutionArgs;
3133use vortex:: scalar_fn:: ScalarFnId ;
3234use vortex:: scalar_fn:: ScalarFnVTable ;
3335
36+ use crate :: matcher:: AnyTensor ;
37+
3438// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
3539// encoded in spherical coordinates.
3640/// Cosine similarity between two columns.
3741///
38- /// For [`FixedShapeTensor`], computes ` dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of
39- /// each tensor. The shape and permutation do not affect the result because cosine similarity only
40- /// depends on the element values, not their logical arrangement.
42+ /// Computes ` dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
43+ /// The shape and permutation do not affect the result because cosine similarity only depends on the
44+ /// element values, not their logical arrangement.
4145///
42- /// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a
43- /// float element type. The output is a float column of the same float type.
46+ /// Both inputs must be tensor-like extension arrays ( [`FixedShapeTensor`] or [`Vector`]) with the
47+ /// same dtype and a float element type. The output is a float column of the same float type.
4448///
4549/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor
50+ /// [`Vector`]: crate::vector::Vector
4651#[ derive( Clone ) ]
4752pub struct CosineSimilarity ;
4853
@@ -92,10 +97,14 @@ impl ScalarFnVTable for CosineSimilarity {
9297
9398 // We don't need to look at rhs anymore since we know lhs and rhs are equal.
9499
95- // Both inputs must be extension types.
100+ // Both inputs must be tensor-like extension types.
96101 let lhs_ext = lhs. as_extension_opt ( ) . ok_or_else ( || {
97102 vortex_err ! ( "cosine_similarity lhs must be an extension type, got {lhs}" )
98103 } ) ?;
104+ vortex_ensure ! (
105+ AnyTensor :: matches( lhs_ext) ,
106+ "cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107+ ) ;
99108
100109 // Extract the element dtype from the storage FixedSizeList.
101110 let element_dtype = lhs_ext
@@ -257,13 +266,15 @@ mod tests {
257266 use vortex:: dtype:: Nullability ;
258267 use vortex:: dtype:: extension:: ExtDType ;
259268 use vortex:: error:: VortexResult ;
269+ use vortex:: extension:: EmptyMetadata ;
260270 use vortex:: scalar:: Scalar ;
261271 use vortex:: scalar_fn:: EmptyOptions ;
262272 use vortex:: scalar_fn:: ScalarFn ;
263273
264274 use crate :: fixed_shape:: FixedShapeTensor ;
265275 use crate :: fixed_shape:: FixedShapeTensorMetadata ;
266276 use crate :: scalar_fns:: cosine_similarity:: CosineSimilarity ;
277+ use crate :: vector:: Vector ;
267278
268279 /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
269280 ///
@@ -459,4 +470,95 @@ mod tests {
459470 ) ;
460471 Ok ( ( ) )
461472 }
473+
474+ /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
475+ fn vector_array ( dim : u32 , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
476+ let row_count = elements. len ( ) / dim as usize ;
477+
478+ let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
479+ let fsl = FixedSizeListArray :: new ( elems, dim, Validity :: NonNullable , row_count) ;
480+
481+ let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
482+
483+ Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
484+ }
485+
486+ #[ test]
487+ fn vector_unit_vectors ( ) -> VortexResult < ( ) > {
488+ let lhs = vector_array (
489+ 3 ,
490+ & [
491+ 1.0 , 0.0 , 0.0 , // vector 0
492+ 0.0 , 1.0 , 0.0 , // vector 1
493+ ] ,
494+ ) ?;
495+ let rhs = vector_array (
496+ 3 ,
497+ & [
498+ 1.0 , 0.0 , 0.0 , // vector 0
499+ 1.0 , 0.0 , 0.0 , // vector 1
500+ ] ,
501+ ) ?;
502+
503+ // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
504+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
505+ Ok ( ( ) )
506+ }
507+
508+ #[ test]
509+ fn vector_self_similarity ( ) -> VortexResult < ( ) > {
510+ let arr = vector_array (
511+ 4 ,
512+ & [
513+ 1.0 , 2.0 , 3.0 , 4.0 , // vector 0
514+ 0.0 , 1.0 , 0.0 , 0.0 , // vector 1
515+ 5.0 , 0.0 , 5.0 , 0.0 , // vector 2
516+ ] ,
517+ ) ?;
518+
519+ assert_close (
520+ & eval_cosine_similarity ( arr. clone ( ) , arr, 3 ) ?,
521+ & [ 1.0 , 1.0 , 1.0 ] ,
522+ ) ;
523+ Ok ( ( ) )
524+ }
525+
526+ /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
527+ fn constant_vector_array ( elements : & [ f64 ] , len : usize ) -> VortexResult < ArrayRef > {
528+ let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
529+
530+ let children: Vec < Scalar > = elements
531+ . iter ( )
532+ . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
533+ . collect ( ) ;
534+ let storage_scalar =
535+ Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
536+
537+ let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
538+
539+ let ext_dtype =
540+ ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
541+
542+ Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
543+ }
544+
545+ #[ test]
546+ fn vector_constant_query ( ) -> VortexResult < ( ) > {
547+ let data = vector_array (
548+ 3 ,
549+ & [
550+ 1.0 , 0.0 , 0.0 , // vector 0
551+ 0.0 , 1.0 , 0.0 , // vector 1
552+ 0.0 , 0.0 , 1.0 , // vector 2
553+ 1.0 , 0.0 , 0.0 , // vector 3
554+ ] ,
555+ ) ?;
556+ let query = constant_vector_array ( & [ 1.0 , 0.0 , 0.0 ] , 4 ) ?;
557+
558+ assert_close (
559+ & eval_cosine_similarity ( data, query, 4 ) ?,
560+ & [ 1.0 , 0.0 , 0.0 , 1.0 ] ,
561+ ) ;
562+ Ok ( ( ) )
563+ }
462564}
0 commit comments