@@ -35,10 +35,9 @@ use vortex_session::VortexSession;
3535
3636use crate :: scalar_fns:: inner_product:: BinaryTensorOpMetadata ;
3737use crate :: scalar_fns:: inner_product:: InnerProduct ;
38- use crate :: scalar_fns:: l2_denorm:: DenormOrientation ;
38+ use crate :: scalar_fns:: l2_denorm:: NormalForm ;
3939use crate :: scalar_fns:: l2_denorm:: try_build_constant_l2_denorm;
4040use crate :: scalar_fns:: l2_norm:: L2Norm ;
41- use crate :: utils:: extract_l2_denorm_children;
4241use crate :: utils:: validate_binary_tensor_float_inputs;
4342
4443/// Cosine similarity between two columns.
@@ -141,15 +140,21 @@ impl ScalarFnVTable for CosineSimilarity {
141140 rhs_ref = sfn. into_array ( ) ;
142141 }
143142
144- // Take any L2Denorm-wrapped fast path that applies.
145- match DenormOrientation :: classify ( & lhs_ref, & rhs_ref) {
146- DenormOrientation :: Both { lhs, rhs } => {
147- return self . execute_both_denorm ( lhs, rhs, len) ;
143+ // Classify each operand by its normal form. When both operands carry a known unit-norm
144+ // representation, cosine similarity collapses to the dot product of the unit vectors.
145+ let lhs_form = NormalForm :: classify ( & lhs_ref) ;
146+ let rhs_form = NormalForm :: classify ( & rhs_ref) ;
147+ match ( lhs_form. unit_array ( ) , rhs_form. unit_array ( ) ) {
148+ ( Some ( unit_lhs) , Some ( unit_rhs) ) => {
149+ return self . execute_both_unit ( unit_lhs, unit_rhs, & lhs_ref, & rhs_ref, len) ;
148150 }
149- DenormOrientation :: One { denorm , plain } => {
150- return self . execute_one_denorm ( denorm , plain , len, ctx) ;
151+ ( Some ( unit_lhs ) , None ) => {
152+ return self . execute_one_unit ( unit_lhs , & rhs_ref , & lhs_ref , len, ctx) ;
151153 }
152- DenormOrientation :: Neither => { }
154+ ( None , Some ( unit_rhs) ) => {
155+ return self . execute_one_unit ( unit_rhs, & lhs_ref, & rhs_ref, len, ctx) ;
156+ }
157+ ( None , None ) => { }
153158 }
154159
155160 // Compute combined validity.
@@ -242,22 +247,20 @@ impl ScalarFnArrayVTable for CosineSimilarity {
242247}
243248
244249impl CosineSimilarity {
245- /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so
246- /// `cosine_similarity = dot(n_l, n_r)` .
247- fn execute_both_denorm (
250+ /// Both sides carry a known unit-norm representation: cosine similarity collapses to the
251+ /// dot product of the unit children.
252+ fn execute_both_unit (
248253 & self ,
254+ unit_lhs : & ArrayRef ,
255+ unit_rhs : & ArrayRef ,
249256 lhs_ref : & ArrayRef ,
250257 rhs_ref : & ArrayRef ,
251258 len : usize ,
252259 ) -> VortexResult < ArrayRef > {
253260 let validity = lhs_ref. validity ( ) ?. and ( rhs_ref. validity ( ) ?) ?;
254261
255- let ( normalized_l, _) = extract_l2_denorm_children ( lhs_ref) ;
256- let ( normalized_r, _) = extract_l2_denorm_children ( rhs_ref) ;
257-
258- // `L2Denorm` makes the normalized children authoritative, so their dot product is the
259- // cosine similarity even for lossy storage wrappers.
260- let dot = InnerProduct :: try_new_array ( normalized_l, normalized_r, len) ?. into_array ( ) ;
262+ let dot =
263+ InnerProduct :: try_new_array ( unit_lhs. clone ( ) , unit_rhs. clone ( ) , len) ?. into_array ( ) ;
261264
262265 if !matches ! ( validity, Validity :: NonNullable ) {
263266 // Masking always changes the nullability to nullable.
@@ -267,22 +270,21 @@ impl CosineSimilarity {
267270 }
268271 }
269272
270- /// One side is `L2Denorm`: treat the normalized child as authoritative, so
271- /// `cosine_similarity = dot(n, b ) / ||b ||`.
272- ///
273- /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref` .
274- fn execute_one_denorm (
274+ /// Exactly one side carries a unit-norm representation: cosine similarity reduces to
275+ /// `dot(unit, other) / ||other||`. The norm of the unit side is either implicitly `1.0` (naked
276+ /// `NormalizedVector`) or stored separately by the outer `L2Denorm` wrapper; either way it is
277+ /// not needed here since cosine ignores magnitude.
277+ fn execute_one_unit (
275278 & self ,
276- denorm_ref : & ArrayRef ,
279+ unit : & ArrayRef ,
277280 plain_ref : & ArrayRef ,
281+ unit_ref : & ArrayRef ,
278282 len : usize ,
279283 ctx : & mut ExecutionCtx ,
280284 ) -> VortexResult < ArrayRef > {
281- let validity = denorm_ref . validity ( ) ?. and ( plain_ref. validity ( ) ?) ?;
285+ let validity = unit_ref . validity ( ) ?. and ( plain_ref. validity ( ) ?) ?;
282286
283- let ( normalized, _) = extract_l2_denorm_children ( denorm_ref) ;
284-
285- let dot_arr = InnerProduct :: try_new_array ( normalized, plain_ref. clone ( ) , len) ?;
287+ let dot_arr = InnerProduct :: try_new_array ( unit. clone ( ) , plain_ref. clone ( ) , len) ?;
286288 let dot: PrimitiveArray = dot_arr. into_array ( ) . execute ( ctx) ?;
287289
288290 let norm_arr = L2Norm :: try_new_array ( plain_ref. clone ( ) , len) ?;
@@ -331,6 +333,7 @@ mod tests {
331333 use crate :: utils:: test_helpers:: assert_close;
332334 use crate :: utils:: test_helpers:: constant_tensor_array;
333335 use crate :: utils:: test_helpers:: l2_denorm_array;
336+ use crate :: utils:: test_helpers:: normalized_vector_array;
334337 use crate :: utils:: test_helpers:: tensor_array;
335338 use crate :: utils:: test_helpers:: vector_array;
336339
@@ -519,13 +522,25 @@ mod tests {
519522 Ok ( ( ) )
520523 }
521524
525+ /// Naked [`NormalizedVector`](crate::normalized_vector::NormalizedVector) operands on both sides
526+ /// take the fast path: cosine similarity collapses to the dot product without computing norms.
527+ #[ test]
528+ fn naked_normalized_vector_cosine ( ) -> VortexResult < ( ) > {
529+ let mut ctx = SESSION . create_execution_ctx ( ) ;
530+ let lhs = normalized_vector_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & mut ctx) ?;
531+ let rhs = normalized_vector_array ( 2 , & [ 0.6 , 0.8 , 0.0 , 1.0 ] , & mut ctx) ?;
532+ // Row 0: [0.6, 0.8] vs. itself -> 1.0; row 1: [1.0, 0.0] vs. [0.0, 1.0] (orthogonal) -> 0.0.
533+ assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
534+ Ok ( ( ) )
535+ }
536+
522537 #[ test]
523538 fn both_denorm_self_similarity ( ) -> VortexResult < ( ) > {
524539 // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
525540 // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
526541 let mut ctx = SESSION . create_execution_ctx ( ) ;
527- let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
528- let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
542+ let lhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
543+ let rhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
529544
530545 // Self-similarity should always be 1.0.
531546 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 1.0 ] ) ;
@@ -537,8 +552,8 @@ mod tests {
537552 // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
538553 // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
539554 let mut ctx = SESSION . create_execution_ctx ( ) ;
540- let lhs = l2_denorm_array ( & [ 2 ] , & [ 1.0 , 0.0 ] , & [ 3.0 ] , & mut ctx) ?;
541- let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.0 , 1.0 ] , & [ 4.0 ] , & mut ctx) ?;
555+ let lhs = l2_denorm_array ( 2 , & [ 1.0 , 0.0 ] , & [ 3.0 ] , & mut ctx) ?;
556+ let rhs = l2_denorm_array ( 2 , & [ 0.0 , 1.0 ] , & [ 4.0 ] , & mut ctx) ?;
542557
543558 assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 0.0 ] ) ;
544559 Ok ( ( ) )
@@ -548,8 +563,8 @@ mod tests {
548563 fn both_denorm_zero_norm ( ) -> VortexResult < ( ) > {
549564 // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
550565 let mut ctx = SESSION . create_execution_ctx ( ) ;
551- let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 0.0 , 0.0 ] , & [ 5.0 , 0.0 ] , & mut ctx) ?;
552- let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
566+ let lhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 , 0.0 , 0.0 ] , & [ 5.0 , 0.0 ] , & mut ctx) ?;
567+ let rhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
553568
554569 // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0.
555570 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
@@ -562,8 +577,8 @@ mod tests {
562577 // RHS is plain [3.0, 4.0].
563578 // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
564579 let mut ctx = SESSION . create_execution_ctx ( ) ;
565- let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 ] , & [ 5.0 ] , & mut ctx) ?;
566- let rhs = tensor_array ( & [ 2 ] , & [ 3.0 , 4.0 ] ) ?;
580+ let lhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 ] , & [ 5.0 ] , & mut ctx) ?;
581+ let rhs = vector_array ( 2 , & [ 3.0 , 4.0 ] ) ?;
567582
568583 assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 1.0 ] ) ;
569584 Ok ( ( ) )
@@ -574,8 +589,8 @@ mod tests {
574589 // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
575590 // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
576591 let mut ctx = SESSION . create_execution_ctx ( ) ;
577- let lhs = tensor_array ( & [ 2 ] , & [ 1.0 , 0.0 ] ) ?;
578- let rhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 ] , & [ 5.0 ] , & mut ctx) ?;
592+ let lhs = vector_array ( 2 , & [ 1.0 , 0.0 ] ) ?;
593+ let rhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 ] , & [ 5.0 ] , & mut ctx) ?;
579594
580595 assert_close ( & eval_cosine_similarity ( lhs, rhs, 1 ) ?, & [ 0.6 ] ) ;
581596 Ok ( ( ) )
@@ -585,9 +600,9 @@ mod tests {
585600 fn both_denorm_null_norms ( ) -> VortexResult < ( ) > {
586601 // Row 0: valid, row 1: null (via nullable norms on rhs).
587602 let mut ctx = SESSION . create_execution_ctx ( ) ;
588- let lhs = l2_denorm_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
603+ let lhs = l2_denorm_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & [ 5.0 , 1.0 ] , & mut ctx) ?;
589604
590- let normalized_r = tensor_array ( & [ 2 ] , & [ 0.6 , 0.8 , 1.0 , 0.0 ] ) ?;
605+ let normalized_r = normalized_vector_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & mut ctx ) ?;
591606 let norms_r = PrimitiveArray :: from_option_iter ( [ Some ( 5.0f64 ) , None ] ) . into_array ( ) ;
592607 let rhs = L2Denorm :: try_new_array ( normalized_r, norms_r, 2 , & mut ctx) ?. into_array ( ) ;
593608
@@ -703,6 +718,34 @@ mod tests {
703718 Ok ( ( ) )
704719 }
705720
721+ #[ test] // Metadata round trip for CosineSimilarity over one normalized-vector and one plain-vector child.
722+ fn serde_round_trip_mixed_vector_and_normalized_vector ( ) -> VortexResult < ( ) > {
723+ let mut ctx = SESSION . create_execution_ctx ( ) ;
724+ let lhs = normalized_vector_array ( 2 , & [ 0.6 , 0.8 , 1.0 , 0.0 ] , & mut ctx) ?;
725+ let rhs = vector_array ( 2 , & [ 3.0 , 4.0 , 0.0 , 1.0 ] ) ?;
726+ let original = CosineSimilarity :: try_new_array ( lhs. clone ( ) , rhs. clone ( ) , 2 ) ?. into_array ( ) ;
727+ // Serialize through the plugin; `expect` panics if no metadata is produced.
728+ let plugin = ScalarFnArrayPlugin :: new ( CosineSimilarity ) ;
729+ let metadata = plugin
730+ . serialize ( & original, & SESSION ) ?
731+ . expect ( "CosineSimilarity serialize must produce metadata" ) ;
732+ // Rebuild from the metadata plus the same children, in the (lhs, rhs) order used above.
733+ let children = vec ! [ lhs, rhs] ;
734+ let recovered = plugin. deserialize (
735+ original. dtype ( ) ,
736+ original. len ( ) ,
737+ & metadata,
738+ & [ ] , // NOTE(review): presumably the buffers argument, empty here — confirm against `deserialize`.
739+ & children,
740+ & SESSION ,
741+ ) ?;
742+ // Only dtype, length, and encoding id are compared; element values are not checked.
743+ assert_eq ! ( recovered. dtype( ) , original. dtype( ) ) ;
744+ assert_eq ! ( recovered. len( ) , original. len( ) ) ;
745+ assert_eq ! ( recovered. encoding_id( ) , original. encoding_id( ) ) ;
746+ Ok ( ( ) )
747+ }
748+
706749 #[ rstest]
707750 #[ case:: vector(
708751 vector_array( 3 , & [ 1.0 , 0.0 , 0.0 , 3.0 , 4.0 , 0.0 ] ) . unwrap( ) ,
0 commit comments