@@ -9,9 +9,10 @@ use num_traits::Zero;
99use vortex_array:: ArrayRef ;
1010use vortex_array:: ExecutionCtx ;
1111use vortex_array:: IntoArray ;
12- use vortex_array:: arrays:: ExtensionArray ;
1312use vortex_array:: arrays:: PrimitiveArray ;
1413use vortex_array:: arrays:: ScalarFnArray ;
14+ use vortex_array:: arrays:: scalar_fn:: ExactScalarFn ;
15+ use vortex_array:: builtins:: ArrayBuiltins ;
1516use vortex_array:: dtype:: DType ;
1617use vortex_array:: dtype:: Nullability ;
1718use vortex_array:: expr:: Expression ;
@@ -23,13 +24,16 @@ use vortex_array::scalar_fn::ExecutionArgs;
2324use vortex_array:: scalar_fn:: ScalarFn ;
2425use vortex_array:: scalar_fn:: ScalarFnId ;
2526use vortex_array:: scalar_fn:: ScalarFnVTable ;
27+ use vortex_array:: validity:: Validity ;
2628use vortex_buffer:: Buffer ;
2729use vortex_error:: VortexResult ;
2830use vortex_error:: vortex_ensure;
2931
3032use crate :: scalar_fns:: ApproxOptions ;
3133use crate :: scalar_fns:: inner_product:: InnerProduct ;
34+ use crate :: scalar_fns:: l2_denorm:: L2Denorm ;
3235use crate :: scalar_fns:: l2_norm:: L2Norm ;
36+ use crate :: utils:: extract_l2_denorm_children;
3337use crate :: utils:: validate_tensor_float_input;
3438
3539/// Cosine similarity between two columns.
@@ -126,35 +130,47 @@ impl ScalarFnVTable for CosineSimilarity {
126130 args : & dyn ExecutionArgs ,
127131 ctx : & mut ExecutionCtx ,
128132 ) -> VortexResult < ArrayRef > {
129- let lhs = args. get ( 0 ) ? . execute :: < ExtensionArray > ( ctx ) ?;
130- let rhs = args. get ( 1 ) ? . execute :: < ExtensionArray > ( ctx ) ?;
133+ let mut lhs_ref = args. get ( 0 ) ?;
134+ let mut rhs_ref = args. get ( 1 ) ?;
131135 let len = args. row_count ( ) ;
132136
133- // Compute combined validity.
134- let validity = lhs. as_ref ( ) . validity ( ) ?. and ( rhs. as_ref ( ) . validity ( ) ?) ?;
137+ // Check if any of our children have already been normalized.
138+ {
139+ let lhs_is_denorm = lhs_ref. is :: < ExactScalarFn < L2Denorm > > ( ) ;
140+ let rhs_is_denorm = rhs_ref. is :: < ExactScalarFn < L2Denorm > > ( ) ;
141+
142+ if lhs_is_denorm && rhs_is_denorm {
143+ return self . execute_both_denorm ( options, & lhs_ref, & rhs_ref, len, ctx) ;
144+ } else if lhs_is_denorm || rhs_is_denorm {
145+ if rhs_is_denorm {
146+ ( lhs_ref, rhs_ref) = ( rhs_ref, lhs_ref) ;
147+ }
148+ return self . execute_one_denorm ( options, & lhs_ref, & rhs_ref, len, ctx) ;
149+ }
150+ }
135151
136- let lhs = lhs . into_array ( ) ;
137- let rhs = rhs . into_array ( ) ;
152+ // Compute combined validity.
153+ let validity = lhs_ref . validity ( ) ? . and ( rhs_ref . validity ( ) ? ) ? ;
138154
139155 // Compute inner product and norms as columnar operations, and propagate the options.
140- let norm_lhs_arr = L2Norm :: try_new_array ( options, lhs . clone ( ) , len) ?;
141- let norm_rhs_arr = L2Norm :: try_new_array ( options, rhs . clone ( ) , len) ?;
142- let dot_arr = InnerProduct :: try_new_array ( options, lhs , rhs , len) ?;
156+ let norm_lhs_arr = L2Norm :: try_new_array ( options, lhs_ref . clone ( ) , len) ?;
157+ let norm_rhs_arr = L2Norm :: try_new_array ( options, rhs_ref . clone ( ) , len) ?;
158+ let dot_arr = InnerProduct :: try_new_array ( options, lhs_ref , rhs_ref , len) ?;
143159
144- // Execute to get PrimitiveArrays.
160+ // Execute to get the inner product and norms of the arrays. We only fully decompress
161+ // because we need to perform special logic (guard against 0) during division.
145162 let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
146163 let norm_l: PrimitiveArray = norm_lhs_arr. into_array ( ) . execute ( ctx) ?;
147164 let norm_r: PrimitiveArray = norm_rhs_arr. into_array ( ) . execute ( ctx) ?;
148165
149- // Divide element-wise, guarding against zero norms.
166+ // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
167+ // TODO(connor): This can be written in a more SIMD-friendly manner.
150168 match_each_float_ptype ! ( dot. ptype( ) , |T | {
151169 let dots = dot. as_slice:: <T >( ) ;
152170 let norms_l = norm_l. as_slice:: <T >( ) ;
153171 let norms_r = norm_r. as_slice:: <T >( ) ;
154172 let buffer: Buffer <T > = ( 0 ..len)
155173 . map( |i| {
156- // TODO(connor): Would it be better to make this a binary multiply?
157- // What happens when this overflows???
158174 let denom = norms_l[ i] * norms_r[ i] ;
159175
160176 if denom == T :: zero( ) {
@@ -191,6 +207,74 @@ impl ScalarFnVTable for CosineSimilarity {
191207 }
192208}
193209
210+ impl CosineSimilarity {
211+ /// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`.
212+ fn execute_both_denorm (
213+ & self ,
214+ options : & ApproxOptions ,
215+ lhs_ref : & ArrayRef ,
216+ rhs_ref : & ArrayRef ,
217+ len : usize ,
218+ _ctx : & mut ExecutionCtx ,
219+ ) -> VortexResult < ArrayRef > {
220+ let validity = lhs_ref. validity ( ) ?. and ( rhs_ref. validity ( ) ?) ?;
221+
222+ let ( normalized_l, _) = extract_l2_denorm_children ( lhs_ref) ;
223+ let ( normalized_r, _) = extract_l2_denorm_children ( rhs_ref) ;
224+
225+ // Dot product of already-normalized children IS the cosine similarity.
226+ let dot =
227+ InnerProduct :: try_new_array ( options, normalized_l, normalized_r, len) ?. into_array ( ) ;
228+
229+ if !matches ! ( validity, Validity :: NonNullable ) {
230+ // Masking always changes the nullability to nullable.
231+ dot. mask ( validity. to_array ( len) )
232+ } else {
233+ Ok ( dot)
234+ }
235+ }
236+
237+ /// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.
238+ ///
239+ /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
240+ fn execute_one_denorm (
241+ & self ,
242+ options : & ApproxOptions ,
243+ denorm_ref : & ArrayRef ,
244+ plain_ref : & ArrayRef ,
245+ len : usize ,
246+ ctx : & mut ExecutionCtx ,
247+ ) -> VortexResult < ArrayRef > {
248+ let validity = denorm_ref. validity ( ) ?. and ( plain_ref. validity ( ) ?) ?;
249+
250+ let ( normalized, _) = extract_l2_denorm_children ( denorm_ref) ;
251+
252+ let dot_arr = InnerProduct :: try_new_array ( options, normalized, plain_ref. clone ( ) , len) ?;
253+ let norm_arr = L2Norm :: try_new_array ( options, plain_ref. clone ( ) , len) ?;
254+ let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
255+ let plain_norm: PrimitiveArray = norm_arr. into_array ( ) . execute ( ctx) ?;
256+
257+ // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
258+ // TODO(connor): This can be written in a more SIMD-friendly manner.
259+ match_each_float_ptype ! ( dot. ptype( ) , |T | {
260+ let dots = dot. as_slice:: <T >( ) ;
261+ let norms = plain_norm. as_slice:: <T >( ) ;
262+ let buffer: Buffer <T > = ( 0 ..len)
263+ . map( |i| {
264+ if norms[ i] == T :: zero( ) {
265+ T :: zero( )
266+ } else {
267+ dots[ i] / norms[ i]
268+ }
269+ } )
270+ . collect( ) ;
271+
272+ // SAFETY: The buffer length equals `len`, which matches the source validity length.
273+ Ok ( unsafe { PrimitiveArray :: new_unchecked( buffer, validity) } . into_array( ) )
274+ } )
275+ }
276+ }
277+
194278#[ cfg( test) ]
195279mod tests {
196280 use std:: sync:: LazyLock ;
@@ -210,6 +294,7 @@ mod tests {
210294
211295 use crate :: scalar_fns:: ApproxOptions ;
212296 use crate :: scalar_fns:: cosine_similarity:: CosineSimilarity ;
297+ use crate :: scalar_fns:: l2_denorm:: L2Denorm ;
213298 use crate :: utils:: test_helpers:: assert_close;
214299 use crate :: utils:: test_helpers:: constant_tensor_array;
215300 use crate :: utils:: test_helpers:: constant_vector_array;
@@ -403,4 +488,99 @@ mod tests {
403488 assert_close ( & [ prim. as_slice :: < f64 > ( ) [ 0 ] ] , & [ 1.0 ] ) ;
404489 Ok ( ( ) )
405490 }
491+
492+ /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
493+ fn l2_denorm_array (
494+ shape : & [ usize ] ,
495+ normalized_elements : & [ f64 ] ,
496+ norms : & [ f64 ] ,
497+ ) -> VortexResult < ArrayRef > {
498+ let len = norms. len ( ) ;
499+ let normalized = tensor_array ( shape, normalized_elements) ?;
500+ let norms = PrimitiveArray :: from_iter ( norms. iter ( ) . copied ( ) ) . into_array ( ) ;
501+ let mut ctx = SESSION . create_execution_ctx ( ) ;
502+ Ok (
503+ L2Denorm :: try_new_array ( & ApproxOptions :: Exact , normalized, norms, len, & mut ctx) ?
504+ . into_array ( ) ,
505+ )
506+ }
507+
508+ #[ test]
509+ fn both_denorm_self_similarity ( ) -> VortexResult < ( ) > {
510+ // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
511+ // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
512+ let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] ) ?;
513+ let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] ) ?;
514+
515+ // Self-similarity should always be 1.0.
516+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 1.0 ] ) ;
517+ Ok ( ( ) )
518+ }
519+
520+ #[ test]
521+ fn both_denorm_orthogonal ( ) -> VortexResult < ( ) > {
522+ // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
523+ // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
524+ let lhs = l2_denorm_array ( & [ 2 ] , & [ 1.0 , 0.0 ] , & [ 3.0 ] ) ?;
525+ let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.0 , 1.0 ] , & [ 4.0 ] ) ?;
526+
527+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 0.0 ] ) ;
528+ Ok ( ( ) )
529+ }
530+
531+ #[ test]
532+ fn both_denorm_zero_norm ( ) -> VortexResult < ( ) > {
533+ // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
534+ let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 0.0 , 0.0 ] , & [ 5.0 , 0.0 ] ) ?;
535+ let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] ) ?;
536+
537+ // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0.
538+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
539+ Ok ( ( ) )
540+ }
541+
542+ #[ test]
543+ fn one_side_denorm_lhs ( ) -> VortexResult < ( ) > {
544+ // LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
545+ // RHS is plain [3.0, 4.0].
546+ // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
547+ let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 ] , & [ 5.0 ] ) ?;
548+ let rhs = tensor_array ( & [ 2 ] , & [ 3.0 , 4.0 ] ) ?;
549+
550+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 1.0 ] ) ;
551+ Ok ( ( ) )
552+ }
553+
554+ #[ test]
555+ fn one_side_denorm_rhs ( ) -> VortexResult < ( ) > {
556+ // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
557+ // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
558+ let lhs = tensor_array ( & [ 2 ] , & [ 1.0 , 0.0 ] ) ?;
559+ let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 ] , & [ 5.0 ] ) ?;
560+
561+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 0.6 ] ) ;
562+ Ok ( ( ) )
563+ }
564+
565+ #[ test]
566+ fn both_denorm_null_norms ( ) -> VortexResult < ( ) > {
567+ // Row 0: valid, row 1: null (via nullable norms on rhs).
568+ let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] ) ?;
569+
570+ let normalized_r = tensor_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] ) ?;
571+ let norms_r = PrimitiveArray :: from_option_iter ( [ Some ( 5.0f64 ) , None ] ) . into_array ( ) ;
572+ let mut ctx = SESSION . create_execution_ctx ( ) ;
573+ let rhs =
574+ L2Denorm :: try_new_array ( & ApproxOptions :: Exact , normalized_r, norms_r, 2 , & mut ctx) ?
575+ . into_array ( ) ;
576+
577+ let scalar_fn = ScalarFn :: new ( CosineSimilarity , ApproxOptions :: Exact ) . erased ( ) ;
578+ let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ lhs, rhs] , 2 ) ?;
579+ let prim: PrimitiveArray = result. into_array ( ) . execute ( & mut ctx) ?;
580+
581+ assert ! ( prim. is_valid( 0 ) ?) ;
582+ assert ! ( !prim. is_valid( 1 ) ?) ;
583+ assert_close ( & [ prim. as_slice :: < f64 > ( ) [ 0 ] ] , & [ 1.0 ] ) ;
584+ Ok ( ( ) )
585+ }
406586}
0 commit comments