11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
//! Cosine similarity expression for tensor-like extension arrays
//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
//! [`Vector`](crate::vector::Vector)).
67
78use std:: fmt:: Formatter ;
89
@@ -19,6 +20,7 @@ use vortex::array::match_each_float_ptype;
1920use vortex:: dtype:: DType ;
2021use vortex:: dtype:: NativePType ;
2122use vortex:: dtype:: Nullability ;
23+ use vortex:: dtype:: extension:: Matcher ;
2224use vortex:: error:: VortexResult ;
2325use vortex:: error:: vortex_bail;
2426use vortex:: error:: vortex_ensure;
@@ -31,18 +33,21 @@ use vortex::scalar_fn::ExecutionArgs;
3133use vortex:: scalar_fn:: ScalarFnId ;
3234use vortex:: scalar_fn:: ScalarFnVTable ;
3335
36+ use crate :: matcher:: AnyTensor ;
37+
// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
// encoded in spherical coordinates.
/// Cosine similarity between two columns.
///
/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
/// The shape and permutation do not affect the result because cosine similarity only depends on the
/// element values, not their logical arrangement.
///
/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the
/// same dtype and a float element type. The output is a float column of the same float type.
///
/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor
/// [`Vector`]: crate::vector::Vector
#[derive(Clone)]
pub struct CosineSimilarity;
4853
@@ -92,10 +97,14 @@ impl ScalarFnVTable for CosineSimilarity {
9297
9398 // We don't need to look at rhs anymore since we know lhs and rhs are equal.
9499
95- // Both inputs must be extension types.
100+ // Both inputs must be tensor-like extension types.
96101 let lhs_ext = lhs. as_extension_opt ( ) . ok_or_else ( || {
97102 vortex_err ! ( "cosine_similarity lhs must be an extension type, got {lhs}" )
98103 } ) ?;
104+ vortex_ensure ! (
105+ AnyTensor :: matches( lhs_ext) ,
106+ "cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107+ ) ;
99108
100109 // Extract the element dtype from the storage FixedSizeList.
101110 let element_dtype = lhs_ext
@@ -258,13 +267,15 @@ mod tests {
258267 use vortex:: dtype:: Nullability ;
259268 use vortex:: dtype:: extension:: ExtDType ;
260269 use vortex:: error:: VortexResult ;
270+ use vortex:: extension:: EmptyMetadata ;
261271 use vortex:: scalar:: Scalar ;
262272 use vortex:: scalar_fn:: EmptyOptions ;
263273 use vortex:: scalar_fn:: ScalarFn ;
264274
265275 use crate :: fixed_shape:: FixedShapeTensor ;
266276 use crate :: fixed_shape:: FixedShapeTensorMetadata ;
267277 use crate :: scalar_fns:: cosine_similarity:: CosineSimilarity ;
278+ use crate :: vector:: Vector ;
268279
269280 /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
270281 ///
@@ -460,4 +471,95 @@ mod tests {
460471 ) ;
461472 Ok ( ( ) )
462473 }
474+
475+ /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
476+ fn vector_array ( dim : u32 , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
477+ let row_count = elements. len ( ) / dim as usize ;
478+
479+ let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
480+ let fsl = FixedSizeListArray :: new ( elems, dim, Validity :: NonNullable , row_count) ;
481+
482+ let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
483+
484+ Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
485+ }
486+
/// Identical unit vectors must score 1.0 and orthogonal unit vectors must score 0.0.
#[test]
fn vector_unit_vectors() -> VortexResult<()> {
    let left = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // vector 0
            0.0, 1.0, 0.0, // vector 1
        ],
    )?;
    let right = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // vector 0
            1.0, 0.0, 0.0, // vector 1
        ],
    )?;

    // Row 0 pairs a vector with itself (1.0); row 1 pairs perpendicular axes (0.0).
    let similarities = eval_cosine_similarity(left, right, 2)?;
    assert_close(&similarities, &[1.0, 0.0]);
    Ok(())
}
508+
/// Comparing a [`Vector`] column against itself must yield 1.0 for every row.
#[test]
fn vector_self_similarity() -> VortexResult<()> {
    let vectors = vector_array(
        4,
        &[
            1.0, 2.0, 3.0, 4.0, // vector 0
            0.0, 1.0, 0.0, 0.0, // vector 1
            5.0, 0.0, 5.0, 0.0, // vector 2
        ],
    )?;

    // The same array is used on both sides, so each row is compared with itself.
    let similarities = eval_cosine_similarity(vectors.clone(), vectors, 3)?;
    assert_close(&similarities, &[1.0, 1.0, 1.0]);
    Ok(())
}
526+
527+ /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
528+ fn constant_vector_array ( elements : & [ f64 ] , len : usize ) -> VortexResult < ArrayRef > {
529+ let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
530+
531+ let children: Vec < Scalar > = elements
532+ . iter ( )
533+ . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
534+ . collect ( ) ;
535+ let storage_scalar =
536+ Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
537+
538+ let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
539+
540+ let ext_dtype =
541+ ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
542+
543+ Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
544+ }
545+
/// A constant-backed query vector must be comparable against a regular [`Vector`] column.
#[test]
fn vector_constant_query() -> VortexResult<()> {
    let column = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // vector 0
            0.0, 1.0, 0.0, // vector 1
            0.0, 0.0, 1.0, // vector 2
            1.0, 0.0, 0.0, // vector 3
        ],
    )?;
    // The same query vector is repeated once for each of the four data rows.
    let repeated_query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?;

    let similarities = eval_cosine_similarity(column, repeated_query, 4)?;
    assert_close(&similarities, &[1.0, 0.0, 0.0, 1.0]);
    Ok(())
}
463565}
0 commit comments