@@ -32,6 +32,7 @@ use vortex_error::vortex_ensure;
3232
3333use crate :: scalar_fns:: inner_product:: InnerProduct ;
3434use crate :: scalar_fns:: l2_denorm:: L2Denorm ;
35+ use crate :: scalar_fns:: l2_denorm:: try_build_constant_l2_denorm;
3536use crate :: scalar_fns:: l2_norm:: L2Norm ;
3637use crate :: utils:: extract_l2_denorm_children;
3738use crate :: utils:: validate_tensor_float_input;
@@ -133,6 +134,20 @@ impl ScalarFnVTable for CosineSimilarity {
133134 let mut rhs_ref = args. get ( 1 ) ?;
134135 let len = args. row_count ( ) ;
135136
137+ // If either side is a constant tensor-like extension array, eagerly normalize the single
138+ // stored row and re-wrap it as an `L2Denorm` whose children are both [`ConstantArray`]s.
139+ // The L2Denorm fast path below then picks it up.
140+ if let Some ( lhs_constant) =
141+ try_build_constant_l2_denorm ( & lhs_ref, len, ctx) ?. map ( |sfn| sfn. into_array ( ) )
142+ {
143+ lhs_ref = lhs_constant;
144+ }
145+ if let Some ( rhs_constant) =
146+ try_build_constant_l2_denorm ( & rhs_ref, len, ctx) ?. map ( |sfn| sfn. into_array ( ) )
147+ {
148+ rhs_ref = rhs_constant;
149+ }
150+
136151 // Check if any of our children have be already normalized.
137152 {
138153 let lhs_is_denorm = lhs_ref. is :: < ExactScalarFn < L2Denorm > > ( ) ;
@@ -249,8 +264,9 @@ impl CosineSimilarity {
249264 let ( normalized, _) = extract_l2_denorm_children ( denorm_ref) ;
250265
251266 let dot_arr = InnerProduct :: try_new_array ( normalized, plain_ref. clone ( ) , len) ?;
252- let norm_arr = L2Norm :: try_new_array ( plain_ref. clone ( ) , len) ?;
253267 let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
268+
269+ let norm_arr = L2Norm :: try_new_array ( plain_ref. clone ( ) , len) ?;
254270 let plain_norm: PrimitiveArray = norm_arr. into_array ( ) . execute ( ctx) ?;
255271
256272 // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
@@ -575,4 +591,106 @@ mod tests {
575591 assert_close ( & [ prim. as_slice :: < f64 > ( ) [ 0 ] ] , & [ 1.0 ] ) ;
576592 Ok ( ( ) )
577593 }
594+
595+ #[ test]
596+ fn constant_lhs_matches_plain_tensor ( ) -> VortexResult < ( ) > {
597+ // The constant query `[1, 2, 2]` has norm 3, so its normalized form is `[1/3, 2/3, 2/3]`.
598+ // Expected cosine similarity against each row is `dot([1, 2, 2], row) / (3 * ||row||)`.
599+ let lhs = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 2.0 , 2.0 ] , 4 ) ?;
600+ let rhs = tensor_array (
601+ & [ 3 ] ,
602+ & [
603+ 1.0 , 0.0 , 0.0 , // dot=1, ||rhs||=1, expected=1/3
604+ 1.0 , 2.0 , 2.0 , // dot=9, ||rhs||=3, expected=1
605+ 0.0 , 0.0 , 1.0 , // dot=2, ||rhs||=1, expected=2/3
606+ 2.0 , 1.0 , 2.0 , // dot=8, ||rhs||=3, expected=8/9
607+ ] ,
608+ ) ?;
609+ assert_close (
610+ & eval_cosine_similarity ( lhs, rhs, 4 ) ?,
611+ & [ 1.0 / 3.0 , 1.0 , 2.0 / 3.0 , 8.0 / 9.0 ] ,
612+ ) ;
613+ Ok ( ( ) )
614+ }
615+
616+ #[ test]
617+ fn constant_rhs_matches_plain_tensor ( ) -> VortexResult < ( ) > {
618+ // Mirror of `constant_lhs_matches_plain_tensor` with the constant on the right.
619+ let lhs = tensor_array (
620+ & [ 3 ] ,
621+ & [
622+ 1.0 , 0.0 , 0.0 , //
623+ 1.0 , 2.0 , 2.0 , //
624+ 0.0 , 0.0 , 1.0 , //
625+ 2.0 , 1.0 , 2.0 , //
626+ ] ,
627+ ) ?;
628+ let rhs = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 2.0 , 2.0 ] , 4 ) ?;
629+ assert_close (
630+ & eval_cosine_similarity ( lhs, rhs, 4 ) ?,
631+ & [ 1.0 / 3.0 , 1.0 , 2.0 / 3.0 , 8.0 / 9.0 ] ,
632+ ) ;
633+ Ok ( ( ) )
634+ }
635+
636+ #[ test]
637+ fn both_constant_tensors ( ) -> VortexResult < ( ) > {
638+ // `[1, 0, 0]` vs `[1, 1, 0]`. dot=1, ||lhs||=1, ||rhs||=sqrt(2), expected=1/sqrt(2).
639+ let lhs = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , 3 ) ?;
640+ let rhs = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 1.0 , 0.0 ] , 3 ) ?;
641+ let expected = 1.0 / 2.0_f64 . sqrt ( ) ;
642+ assert_close (
643+ & eval_cosine_similarity ( lhs, rhs, 3 ) ?,
644+ & [ expected, expected, expected] ,
645+ ) ;
646+ Ok ( ( ) )
647+ }
648+
649+ #[ test]
650+ fn constant_zero_norm_query ( ) -> VortexResult < ( ) > {
651+ // A zero-norm constant query must produce `0.0` for every row via the zero-norm guard in
652+ // `execute_one_denorm` and `execute_both_denorm`.
653+ let lhs = constant_tensor_array ( & [ 3 ] , & [ 0.0 , 0.0 , 0.0 ] , 3 ) ?;
654+ let rhs = tensor_array (
655+ & [ 3 ] ,
656+ & [
657+ 1.0 , 2.0 , 3.0 , //
658+ 4.0 , 5.0 , 6.0 , //
659+ 7.0 , 8.0 , 9.0 , //
660+ ] ,
661+ ) ?;
662+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 3 ) ?, & [ 0.0 , 0.0 , 0.0 ] ) ;
663+ Ok ( ( ) )
664+ }
665+
666+ #[ test]
667+ fn constant_self_similarity_nonunit ( ) -> VortexResult < ( ) > {
668+ // A non-unit constant query compared to itself must produce `1.0`. This exercises the
669+ // helper's division: after normalization, both sides must be exactly unit so the
670+ // L2Denorm fast path's inner product yields 1.
671+ let lhs = constant_tensor_array ( & [ 3 ] , & [ 3.0 , 4.0 , 0.0 ] , 5 ) ?;
672+ let rhs = constant_tensor_array ( & [ 3 ] , & [ 3.0 , 4.0 , 0.0 ] , 5 ) ?;
673+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 5 ) ?, & [ 1.0 ; 5 ] ) ;
674+ Ok ( ( ) )
675+ }
676+
677+ #[ test]
678+ fn vector_constant_matches_plain ( ) -> VortexResult < ( ) > {
679+ // Exercise the `Vector` extension variant through the new pre-pass.
680+ let lhs = constant_vector_array ( & [ 1.0 , 2.0 , 2.0 ] , 4 ) ?;
681+ let rhs = vector_array (
682+ 3 ,
683+ & [
684+ 1.0 , 0.0 , 0.0 , //
685+ 1.0 , 2.0 , 2.0 , //
686+ 0.0 , 0.0 , 1.0 , //
687+ 2.0 , 1.0 , 2.0 , //
688+ ] ,
689+ ) ?;
690+ assert_close (
691+ & eval_cosine_similarity ( lhs, rhs, 4 ) ?,
692+ & [ 1.0 / 3.0 , 1.0 , 2.0 / 3.0 , 8.0 / 9.0 ] ,
693+ ) ;
694+ Ok ( ( ) )
695+ }
578696}
0 commit comments