@@ -9,9 +9,10 @@ use num_traits::Zero;
99use vortex_array:: ArrayRef ;
1010use vortex_array:: ExecutionCtx ;
1111use vortex_array:: IntoArray ;
12- use vortex_array:: arrays:: ExtensionArray ;
1312use vortex_array:: arrays:: PrimitiveArray ;
1413use vortex_array:: arrays:: ScalarFnArray ;
14+ use vortex_array:: arrays:: scalar_fn:: ExactScalarFn ;
15+ use vortex_array:: builtins:: ArrayBuiltins ;
1516use vortex_array:: dtype:: DType ;
1617use vortex_array:: dtype:: Nullability ;
1718use vortex_array:: expr:: Expression ;
@@ -29,7 +30,9 @@ use vortex_error::vortex_ensure;
2930
3031use crate :: scalar_fns:: ApproxOptions ;
3132use crate :: scalar_fns:: inner_product:: InnerProduct ;
33+ use crate :: scalar_fns:: l2_denorm:: L2Denorm ;
3234use crate :: scalar_fns:: l2_norm:: L2Norm ;
35+ use crate :: utils:: extract_l2_denorm_children;
3336use crate :: utils:: validate_tensor_float_input;
3437
3538/// Cosine similarity between two columns.
@@ -126,35 +129,47 @@ impl ScalarFnVTable for CosineSimilarity {
126129 args : & dyn ExecutionArgs ,
127130 ctx : & mut ExecutionCtx ,
128131 ) -> VortexResult < ArrayRef > {
129- let lhs = args. get ( 0 ) ? . execute :: < ExtensionArray > ( ctx ) ?;
130- let rhs = args. get ( 1 ) ? . execute :: < ExtensionArray > ( ctx ) ?;
132+ let mut lhs_ref = args. get ( 0 ) ?;
133+ let mut rhs_ref = args. get ( 1 ) ?;
131134 let len = args. row_count ( ) ;
132135
133- // Compute combined validity.
134- let validity = lhs. as_ref ( ) . validity ( ) ?. and ( rhs. as_ref ( ) . validity ( ) ?) ?;
136+ // Check if any of our children have already been normalized.
137+ {
138+ let lhs_is_denorm = lhs_ref. is :: < ExactScalarFn < L2Denorm > > ( ) ;
139+ let rhs_is_denorm = rhs_ref. is :: < ExactScalarFn < L2Denorm > > ( ) ;
140+
141+ if lhs_is_denorm && rhs_is_denorm {
142+ return self . execute_both_denorm ( options, & lhs_ref, & rhs_ref, len, ctx) ;
143+ } else if lhs_is_denorm || rhs_is_denorm {
144+ if rhs_is_denorm {
145+ ( lhs_ref, rhs_ref) = ( rhs_ref, lhs_ref) ;
146+ }
147+ return self . execute_one_denorm ( options, & lhs_ref, & rhs_ref, len, ctx) ;
148+ }
149+ }
135150
136- let lhs = lhs . into_array ( ) ;
137- let rhs = rhs . into_array ( ) ;
151+ // Compute combined validity.
152+ let validity = lhs_ref . validity ( ) ? . and ( rhs_ref . validity ( ) ? ) ? ;
138153
139154 // Compute inner product and norms as columnar operations, and propagate the options.
140- let norm_lhs_arr = L2Norm :: try_new_array ( options, lhs . clone ( ) , len) ?;
141- let norm_rhs_arr = L2Norm :: try_new_array ( options, rhs . clone ( ) , len) ?;
142- let dot_arr = InnerProduct :: try_new_array ( options, lhs , rhs , len) ?;
155+ let norm_lhs_arr = L2Norm :: try_new_array ( options, lhs_ref . clone ( ) , len) ?;
156+ let norm_rhs_arr = L2Norm :: try_new_array ( options, rhs_ref . clone ( ) , len) ?;
157+ let dot_arr = InnerProduct :: try_new_array ( options, lhs_ref , rhs_ref , len) ?;
143158
144- // Execute to get PrimitiveArrays.
159+ // Execute to get the inner product and norms of the arrays. We only fully decompress
160+ // because we need to perform special logic (guard against 0) during division.
145161 let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
146162 let norm_l: PrimitiveArray = norm_lhs_arr. into_array ( ) . execute ( ctx) ?;
147163 let norm_r: PrimitiveArray = norm_rhs_arr. into_array ( ) . execute ( ctx) ?;
148164
149- // Divide element-wise, guarding against zero norms.
165+ // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
166+ // TODO(connor): This can be written in a more SIMD-friendly manner.
150167 match_each_float_ptype ! ( dot. ptype( ) , |T | {
151168 let dots = dot. as_slice:: <T >( ) ;
152169 let norms_l = norm_l. as_slice:: <T >( ) ;
153170 let norms_r = norm_r. as_slice:: <T >( ) ;
154171 let buffer: Buffer <T > = ( 0 ..len)
155172 . map( |i| {
156- // TODO(connor): Would it be better to make this a binary multiply?
157- // What happens when this overflows???
158173 let denom = norms_l[ i] * norms_r[ i] ;
159174
160175 if denom == T :: zero( ) {
@@ -191,6 +206,68 @@ impl ScalarFnVTable for CosineSimilarity {
191206 }
192207}
193208
209+ impl CosineSimilarity {
210+ /// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`.
211+ fn execute_both_denorm (
212+ & self ,
213+ options : & ApproxOptions ,
214+ lhs_ref : & ArrayRef ,
215+ rhs_ref : & ArrayRef ,
216+ len : usize ,
217+ _ctx : & mut ExecutionCtx ,
218+ ) -> VortexResult < ArrayRef > {
219+ let validity = lhs_ref. validity ( ) ?. and ( rhs_ref. validity ( ) ?) ?;
220+
221+ let ( normalized_l, _) = extract_l2_denorm_children ( lhs_ref) ;
222+ let ( normalized_r, _) = extract_l2_denorm_children ( rhs_ref) ;
223+
224+ // Dot product of already-normalized children IS the cosine similarity.
225+ InnerProduct :: try_new_array ( options, normalized_l, normalized_r, len) ?
226+ . into_array ( )
227+ . mask ( validity. to_array ( len) )
228+ }
229+
230+ /// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.
231+ ///
232+ /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
233+ fn execute_one_denorm (
234+ & self ,
235+ options : & ApproxOptions ,
236+ denorm_ref : & ArrayRef ,
237+ plain_ref : & ArrayRef ,
238+ len : usize ,
239+ ctx : & mut ExecutionCtx ,
240+ ) -> VortexResult < ArrayRef > {
241+ let validity = denorm_ref. validity ( ) ?. and ( plain_ref. validity ( ) ?) ?;
242+
243+ let ( normalized, _) = extract_l2_denorm_children ( denorm_ref) ;
244+
245+ let dot_arr = InnerProduct :: try_new_array ( options, normalized, plain_ref. clone ( ) , len) ?;
246+ let norm_arr = L2Norm :: try_new_array ( options, plain_ref. clone ( ) , len) ?;
247+ let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
248+ let plain_norm: PrimitiveArray = norm_arr. into_array ( ) . execute ( ctx) ?;
249+
250+ // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
251+ // TODO(connor): This can be written in a more SIMD-friendly manner.
252+ match_each_float_ptype ! ( dot. ptype( ) , |T | {
253+ let dots = dot. as_slice:: <T >( ) ;
254+ let norms = plain_norm. as_slice:: <T >( ) ;
255+ let buffer: Buffer <T > = ( 0 ..len)
256+ . map( |i| {
257+ if norms[ i] == T :: zero( ) {
258+ T :: zero( )
259+ } else {
260+ dots[ i] / norms[ i]
261+ }
262+ } )
263+ . collect( ) ;
264+
265+ // SAFETY: The buffer length equals `len`, which matches the source validity length.
266+ Ok ( unsafe { PrimitiveArray :: new_unchecked( buffer, validity) } . into_array( ) )
267+ } )
268+ }
269+ }
270+
194271#[ cfg( test) ]
195272mod tests {
196273 use std:: sync:: LazyLock ;
@@ -403,4 +480,105 @@ mod tests {
403480 assert_close ( & [ prim. as_slice :: < f64 > ( ) [ 0 ] ] , & [ 1.0 ] ) ;
404481 Ok ( ( ) )
405482 }
483+
484+ /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
485+ fn l2_denorm_array (
486+ shape : & [ usize ] ,
487+ normalized_elements : & [ f64 ] ,
488+ norms : & [ f64 ] ,
489+ ) -> VortexResult < ArrayRef > {
490+ let len = norms. len ( ) ;
491+ let normalized = tensor_array ( shape, normalized_elements) ?;
492+ let norms = PrimitiveArray :: from_iter ( norms. iter ( ) . copied ( ) ) . into_array ( ) ;
493+ Ok ( crate :: scalar_fns:: l2_denorm:: L2Denorm :: try_new_array (
494+ & ApproxOptions :: Exact ,
495+ normalized,
496+ norms,
497+ len,
498+ ) ?
499+ . into_array ( ) )
500+ }
501+
#[test]
fn both_denorm_self_similarity() -> VortexResult<()> {
    // Row 0 represents [3.0, 4.0]: unit vector [0.6, 0.8] with norm 5.0.
    // Row 1 represents [1.0, 0.0]: unit vector [1.0, 0.0] with norm 1.0.
    let unit_rows = [0.6, 0.8, 1.0, 0.0];
    let row_norms = [5.0, 1.0];

    let lhs = l2_denorm_array(&[2], &unit_rows, &row_norms)?;
    let rhs = l2_denorm_array(&[2], &unit_rows, &row_norms)?;

    // Comparing a vector with itself must give a similarity of 1.0 for every row.
    let similarity = eval_cosine_similarity(lhs, rhs, 2)?;
    assert_close(&similarity, &[1.0, 1.0]);
    Ok(())
}
513+
#[test]
fn both_denorm_orthogonal() -> VortexResult<()> {
    // [3.0, 0.0] is stored as unit vector [1.0, 0.0] with norm 3.0.
    let along_x = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
    // [0.0, 4.0] is stored as unit vector [0.0, 1.0] with norm 4.0.
    let along_y = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;

    // Perpendicular vectors have zero similarity.
    let similarity = eval_cosine_similarity(along_x, along_y, 1)?;
    assert_close(&similarity, &[0.0]);
    Ok(())
}
524+
#[test]
fn both_denorm_zero_norm() -> VortexResult<()> {
    // The second row of `with_zero_row` is a zero vector: unit part [0.0, 0.0], norm 0.0.
    let with_zero_row = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
    let non_zero = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

    // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0.
    // Row 1: dot([0.0, 0.0], [1.0, 0.0]) = 0.0.
    let similarity = eval_cosine_similarity(with_zero_row, non_zero, 2)?;
    assert_close(&similarity, &[1.0, 0.0]);
    Ok(())
}
535+
#[test]
fn one_side_denorm_lhs() -> VortexResult<()> {
    // Left operand: L2Denorm([0.6, 0.8], 5.0), which stands for [3.0, 4.0].
    let denorm_side = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
    // Right operand: the same vector [3.0, 4.0] as a plain tensor.
    let plain_side = tensor_array(&[2], &[3.0, 4.0])?;

    // Both operands describe [3.0, 4.0], so the similarity is 1.0.
    let similarity = eval_cosine_similarity(denorm_side, plain_side, 1)?;
    assert_close(&similarity, &[1.0]);
    Ok(())
}
547+
#[test]
fn one_side_denorm_rhs() -> VortexResult<()> {
    // Left operand: plain [1.0, 0.0].
    let plain_side = tensor_array(&[2], &[1.0, 0.0])?;
    // Right operand: L2Denorm([0.6, 0.8], 5.0), which stands for [3.0, 4.0].
    let denorm_side = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;

    // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
    let similarity = eval_cosine_similarity(plain_side, denorm_side, 1)?;
    assert_close(&similarity, &[0.6]);
    Ok(())
}
558+
#[test]
fn both_denorm_null_norms() -> VortexResult<()> {
    // Both sides share the same unit vectors; only the right side has a null norm (row 1).
    let unit_rows = [0.6, 0.8, 1.0, 0.0];
    let lhs = l2_denorm_array(&[2], &unit_rows, &[5.0, 1.0])?;

    let rhs_normalized = tensor_array(&[2], &unit_rows)?;
    let rhs_norms = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
    let rhs = crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
        &ApproxOptions::Exact,
        rhs_normalized,
        rhs_norms,
        2,
    )?
    .into_array();

    // Build the cosine-similarity expression by hand so the validity of the result is visible.
    let cosine_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
    let lazy = ScalarFnArray::try_new(cosine_fn, vec![lhs, rhs], 2)?;
    let mut exec_ctx = SESSION.create_execution_ctx();
    let executed: PrimitiveArray = lazy.into_array().execute(&mut exec_ctx)?;

    // Row 0 is valid with similarity 1.0; row 1 is null because its right-hand norm is null.
    assert!(executed.is_valid(0)?);
    assert!(!executed.is_valid(1)?);
    assert_close(&[executed.as_slice::<f64>()[0]], &[1.0]);
    Ok(())
}
406584}
0 commit comments