11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- //! Cosine similarity expression for tensor-like extension arrays
5- //! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
6- //! [`Vector`](crate::vector::Vector)).
4+ //! Cosine similarity expression for tensor-like types.
75
86use std:: fmt:: Formatter ;
97
10- use num_traits:: Float ;
8+ use num_traits:: Zero ;
119use vortex_array:: ArrayRef ;
1210use vortex_array:: ExecutionCtx ;
1311use vortex_array:: IntoArray ;
12+ use vortex_array:: arrays:: ExtensionArray ;
1413use vortex_array:: arrays:: PrimitiveArray ;
14+ use vortex_array:: arrays:: ScalarFnArray ;
1515use vortex_array:: dtype:: DType ;
16- use vortex_array:: dtype:: NativePType ;
1716use vortex_array:: dtype:: Nullability ;
1817use vortex_array:: expr:: Expression ;
1918use vortex_array:: expr:: and;
2019use vortex_array:: match_each_float_ptype;
2120use vortex_array:: scalar_fn:: Arity ;
2221use vortex_array:: scalar_fn:: ChildName ;
2322use vortex_array:: scalar_fn:: ExecutionArgs ;
23+ use vortex_array:: scalar_fn:: ScalarFn ;
2424use vortex_array:: scalar_fn:: ScalarFnId ;
2525use vortex_array:: scalar_fn:: ScalarFnVTable ;
26+ use vortex_buffer:: Buffer ;
2627use vortex_error:: VortexResult ;
2728use vortex_error:: vortex_ensure;
2829use vortex_error:: vortex_err;
2930
3031use crate :: matcher:: AnyTensor ;
3132use crate :: scalar_fns:: ApproxOptions ;
33+ use crate :: scalar_fns:: inner_product:: InnerProduct ;
34+ use crate :: scalar_fns:: l2_norm:: L2Norm ;
3235use crate :: utils:: extension_element_ptype;
33- use crate :: utils:: extension_list_size;
34- use crate :: utils:: extension_storage;
35- use crate :: utils:: extract_flat_elements;
3636
3737/// Cosine similarity between two columns.
3838///
@@ -48,6 +48,30 @@ use crate::utils::extract_flat_elements;
4848#[ derive( Clone ) ]
4949pub struct CosineSimilarity ;
5050
51+ impl CosineSimilarity {
52+ /// Creates a new [`ScalarFn`] wrapping the cosine similarity operation with the given
53+ /// [`ApproxOptions`] controlling approximation behavior.
54+ pub fn new ( options : & ApproxOptions ) -> ScalarFn < CosineSimilarity > {
55+ ScalarFn :: new ( CosineSimilarity , options. clone ( ) )
56+ }
57+
58+ /// Constructs a [`ScalarFnArray`] that lazily computes the cosine similarity between `lhs` and
59+ /// `rhs`.
60+ ///
61+ /// # Errors
62+ ///
63+ /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype
64+ /// mismatches).
65+ pub fn try_new_array (
66+ options : & ApproxOptions ,
67+ lhs : ArrayRef ,
68+ rhs : ArrayRef ,
69+ len : usize ,
70+ ) -> VortexResult < ScalarFnArray > {
71+ ScalarFnArray :: try_new ( CosineSimilarity :: new ( options) . erased ( ) , vec ! [ lhs, rhs] , len)
72+ }
73+ }
74+
5175impl ScalarFnVTable for CosineSimilarity {
5276 type Options = ApproxOptions ;
5377
@@ -114,37 +138,49 @@ impl ScalarFnVTable for CosineSimilarity {
114138
115139 fn execute (
116140 & self ,
117- _options : & Self :: Options ,
141+ options : & Self :: Options ,
118142 args : & dyn ExecutionArgs ,
119143 ctx : & mut ExecutionCtx ,
120144 ) -> VortexResult < ArrayRef > {
121- let lhs = args. get ( 0 ) ?;
122- let rhs = args. get ( 1 ) ?;
123- let row_count = args. row_count ( ) ;
124-
125- // Get list size from the dtype. Both sides should have the same dtype.
126- let ext = lhs. dtype ( ) . as_extension_opt ( ) . ok_or_else ( || {
127- vortex_err ! (
128- "cosine_similarity input must be an extension type, got {}" ,
129- lhs. dtype( )
130- )
131- } ) ?;
132- let list_size = extension_list_size ( ext) ? as usize ;
133-
134- // Extract the storage array from each extension input. We pass the storage (FSL) rather
135- // than the extension array to avoid canonicalizing the extension wrapper.
136- let lhs_storage = extension_storage ( & lhs) ?;
137- let rhs_storage = extension_storage ( & rhs) ?;
138-
139- let lhs_flat = extract_flat_elements ( & lhs_storage, list_size, ctx) ?;
140- let rhs_flat = extract_flat_elements ( & rhs_storage, list_size, ctx) ?;
141-
142- match_each_float_ptype ! ( lhs_flat. ptype( ) , |T | {
143- let result: PrimitiveArray = ( 0 ..row_count)
144- . map( |i| cosine_similarity_row( lhs_flat. row:: <T >( i) , rhs_flat. row:: <T >( i) ) )
145+ let lhs = args. get ( 0 ) ?. execute :: < ExtensionArray > ( ctx) ?. into_array ( ) ;
146+ let rhs = args. get ( 1 ) ?. execute :: < ExtensionArray > ( ctx) ?. into_array ( ) ;
147+
148+ let len = args. row_count ( ) ;
149+
150+ // Compute combined validity.
151+ let validity = lhs. validity ( ) ?. and ( rhs. validity ( ) ?) ?;
152+
153+ // Compute inner product and norms as columnar operations, and propagate the options.
154+ let norm_lhs_arr = L2Norm :: try_new_array ( options, lhs. clone ( ) , len) ?;
155+ let norm_rhs_arr = L2Norm :: try_new_array ( options, rhs. clone ( ) , len) ?;
156+ let dot_arr = InnerProduct :: try_new_array ( options, lhs, rhs, len) ?;
157+
158+ // Execute to get PrimitiveArrays.
159+ let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
160+ let norm_l: PrimitiveArray = norm_lhs_arr. into_array ( ) . execute ( ctx) ?;
161+ let norm_r: PrimitiveArray = norm_rhs_arr. into_array ( ) . execute ( ctx) ?;
162+
163+ // Divide element-wise, guarding against zero norms.
164+ match_each_float_ptype ! ( dot. ptype( ) , |T | {
165+ let dots = dot. as_slice:: <T >( ) ;
166+ let norms_l = norm_l. as_slice:: <T >( ) ;
167+ let norms_r = norm_r. as_slice:: <T >( ) ;
168+ let buffer: Buffer <T > = ( 0 ..len)
169+ . map( |i| {
170+ // TODO(connor): Would it be better to make this a binary multiply?
171+ // What happens when this overflows???
172+ let denom = norms_l[ i] * norms_r[ i] ;
173+
174+ if denom == T :: zero( ) {
175+ T :: zero( )
176+ } else {
177+ dots[ i] / denom
178+ }
179+ } )
145180 . collect( ) ;
146181
147- Ok ( result. into_array( ) )
182+ // SAFETY: The buffer length equals `len`, which matches the source validity length.
183+ Ok ( unsafe { PrimitiveArray :: new_unchecked( buffer, validity) } . into_array( ) )
148184 } )
149185 }
150186
@@ -169,30 +205,16 @@ impl ScalarFnVTable for CosineSimilarity {
169205 }
170206}
171207
172- // TODO(connor): We should try to use a more performant library instead of doing this ourselves.
173- /// Computes cosine similarity between two equal-length float slices.
174- ///
175- /// Returns `dot(a, b) / (||a|| * ||b||)`. When either vector has zero norm, this naturally
176- /// produces `NaN` via `0.0 / 0.0`, matching standard floating-point semantics.
177- fn cosine_similarity_row < T : Float + NativePType > ( a : & [ T ] , b : & [ T ] ) -> T {
178- let mut dot = T :: zero ( ) ;
179- let mut norm_a = T :: zero ( ) ;
180- let mut norm_b = T :: zero ( ) ;
181- for i in 0 ..a. len ( ) {
182- dot = dot + a[ i] * b[ i] ;
183- norm_a = norm_a + a[ i] * a[ i] ;
184- norm_b = norm_b + b[ i] * b[ i] ;
185- }
186- dot / ( norm_a. sqrt ( ) * norm_b. sqrt ( ) )
187- }
188-
189208#[ cfg( test) ]
190209mod tests {
191210 use rstest:: rstest;
192211 use vortex_array:: ArrayRef ;
212+ use vortex_array:: IntoArray ;
193213 use vortex_array:: ToCanonical ;
214+ use vortex_array:: arrays:: MaskedArray ;
194215 use vortex_array:: arrays:: ScalarFnArray ;
195216 use vortex_array:: scalar_fn:: ScalarFn ;
217+ use vortex_array:: validity:: Validity ;
196218 use vortex_error:: VortexResult ;
197219
198220 use crate :: scalar_fns:: ApproxOptions ;
@@ -239,8 +261,8 @@ mod tests {
239261 #[ case:: opposite( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , & [ -1.0 , 0.0 , 0.0 ] , & [ -1.0 ] ) ]
240262 // dot=24, both magnitudes=5 -> 24/25 = 0.96.
241263 #[ case:: non_unit( & [ 2 ] , & [ 3.0 , 4.0 ] , & [ 4.0 , 3.0 ] , & [ 0.96 ] ) ]
242- // Zero vector -> 0/0 -> NaN .
243- #[ case:: zero_norm( & [ 2 ] , & [ 0.0 , 0.0 ] , & [ 1.0 , 0.0 ] , & [ f64 :: NAN ] ) ]
264+ // Zero vector -> guarded to 0.0 .
265+ #[ case:: zero_norm( & [ 2 ] , & [ 0.0 , 0.0 ] , & [ 1.0 , 0.0 ] , & [ 0.0 ] ) ]
244266 fn single_row (
245267 #[ case] shape : & [ usize ] ,
246268 #[ case] lhs_elems : & [ f64 ] ,
@@ -367,4 +389,22 @@ mod tests {
367389 ) ;
368390 Ok ( ( ) )
369391 }
392+
393+ #[ test]
394+ fn null_input_row ( ) -> VortexResult < ( ) > {
395+ // 2 rows of dim-2 vectors. Row 1 of rhs is masked as null.
396+ let lhs = tensor_array ( & [ 2 ] , & [ 3.0 , 4.0 , 1.0 , 0.0 ] ) ?;
397+ let rhs = tensor_array ( & [ 2 ] , & [ 3.0 , 4.0 , 0.0 , 1.0 ] ) ?;
398+ let rhs = MaskedArray :: try_new ( rhs, Validity :: from_iter ( [ true , false ] ) ) ?. into_array ( ) ;
399+
400+ let scalar_fn = ScalarFn :: new ( CosineSimilarity , ApproxOptions :: Exact ) . erased ( ) ;
401+ let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ lhs, rhs] , 2 ) ?;
402+ let prim = result. as_array ( ) . to_primitive ( ) ;
403+
404+ // Row 0: self-similarity = 1.0, row 1: null.
405+ assert ! ( prim. is_valid( 0 ) ?) ;
406+ assert ! ( !prim. is_valid( 1 ) ?) ;
407+ assert_close ( & [ prim. as_slice :: < f64 > ( ) [ 0 ] ] , & [ 1.0 ] ) ;
408+ Ok ( ( ) )
409+ }
370410}
0 commit comments