@@ -12,6 +12,9 @@ use vortex_array::IntoArray;
1212use vortex_array:: arrays:: PrimitiveArray ;
1313use vortex_array:: arrays:: ScalarFnArray ;
1414use vortex_array:: arrays:: scalar_fn:: ExactScalarFn ;
15+ use vortex_array:: arrays:: scalar_fn:: ScalarFnArrayView ;
16+ use vortex_array:: arrays:: scalar_fn:: plugin:: ScalarFnArrayParts ;
17+ use vortex_array:: arrays:: scalar_fn:: plugin:: ScalarFnArrayVTable ;
1518use vortex_array:: builtins:: ArrayBuiltins ;
1619use vortex_array:: dtype:: DType ;
1720use vortex_array:: dtype:: Nullability ;
@@ -25,11 +28,14 @@ use vortex_array::scalar_fn::ExecutionArgs;
2528use vortex_array:: scalar_fn:: ScalarFn ;
2629use vortex_array:: scalar_fn:: ScalarFnId ;
2730use vortex_array:: scalar_fn:: ScalarFnVTable ;
31+ use vortex_array:: serde:: ArrayChildren ;
2832use vortex_array:: validity:: Validity ;
2933use vortex_buffer:: Buffer ;
3034use vortex_error:: VortexResult ;
3135use vortex_error:: vortex_ensure;
36+ use vortex_session:: VortexSession ;
3237
38+ use crate :: scalar_fns:: inner_product:: BinaryTensorOpMetadata ;
3339use crate :: scalar_fns:: inner_product:: InnerProduct ;
3440use crate :: scalar_fns:: l2_denorm:: L2Denorm ;
3541use crate :: scalar_fns:: l2_denorm:: try_build_constant_l2_denorm;
@@ -221,6 +227,37 @@ impl ScalarFnVTable for CosineSimilarity {
221227 }
222228}
223229
230+ impl ScalarFnArrayVTable for CosineSimilarity {
231+ fn serialize (
232+ & self ,
233+ view : & ScalarFnArrayView < Self > ,
234+ _session : & VortexSession ,
235+ ) -> VortexResult < Option < Vec < u8 > > > {
236+ Ok ( Some ( BinaryTensorOpMetadata :: encode_from_view ( view) ?) )
237+ }
238+
239+ fn deserialize (
240+ & self ,
241+ _dtype : & DType ,
242+ len : usize ,
243+ metadata : & [ u8 ] ,
244+ children : & dyn ArrayChildren ,
245+ session : & VortexSession ,
246+ ) -> VortexResult < ScalarFnArrayParts < Self > > {
247+ let reconstructed = BinaryTensorOpMetadata :: decode_children (
248+ metadata,
249+ len,
250+ children,
251+ session,
252+ "CosineSimilarity" ,
253+ ) ?;
254+ Ok ( ScalarFnArrayParts {
255+ options : EmptyOptions ,
256+ children : reconstructed,
257+ } )
258+ }
259+ }
260+
224261impl CosineSimilarity {
225262 /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so
226263 /// `cosine_similarity = dot(n_l, n_r)`.
@@ -295,12 +332,14 @@ mod tests {
295332 use std:: sync:: LazyLock ;
296333
297334 use rstest:: rstest;
335+ use vortex_array:: ArrayPlugin ;
298336 use vortex_array:: ArrayRef ;
299337 use vortex_array:: IntoArray ;
300338 use vortex_array:: VortexSessionExecute ;
301339 use vortex_array:: arrays:: MaskedArray ;
302340 use vortex_array:: arrays:: PrimitiveArray ;
303341 use vortex_array:: arrays:: ScalarFnArray ;
342+ use vortex_array:: arrays:: scalar_fn:: plugin:: ScalarFnArrayPlugin ;
304343 use vortex_array:: session:: ArraySession ;
305344 use vortex_array:: validity:: Validity ;
306345 use vortex_error:: VortexResult ;
@@ -314,8 +353,11 @@ mod tests {
314353 use crate :: utils:: test_helpers:: tensor_array;
315354 use crate :: utils:: test_helpers:: vector_array;
316355
317- static SESSION : LazyLock < VortexSession > =
318- LazyLock :: new ( || VortexSession :: empty ( ) . with :: < ArraySession > ( ) ) ;
356+ static SESSION : LazyLock < VortexSession > = LazyLock :: new ( || {
357+ let session = VortexSession :: empty ( ) . with :: < ArraySession > ( ) ;
358+ crate :: initialize ( & session) ;
359+ session
360+ } ) ;
319361
320362 /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
321363 fn eval_cosine_similarity ( lhs : ArrayRef , rhs : ArrayRef , len : usize ) -> VortexResult < Vec < f64 > > {
@@ -693,4 +735,43 @@ mod tests {
693735 ) ;
694736 Ok ( ( ) )
695737 }
738+
739+ #[ rstest]
740+ #[ case:: vector(
741+ vector_array( 3 , & [ 1.0 , 0.0 , 0.0 , 3.0 , 4.0 , 0.0 ] ) . unwrap( ) ,
742+ vector_array( 3 , & [ 0.0 , 1.0 , 0.0 , 3.0 , 4.0 , 0.0 ] ) . unwrap( ) ,
743+ 2 ,
744+ ) ]
745+ #[ case:: fixed_shape_tensor(
746+ tensor_array( & [ 2 ] , & [ 1.0 , 0.0 , 3.0 , 4.0 ] ) . unwrap( ) ,
747+ tensor_array( & [ 2 ] , & [ 0.0 , 1.0 , 3.0 , 4.0 ] ) . unwrap( ) ,
748+ 2 ,
749+ ) ]
750+ fn serde_round_trip (
751+ #[ case] lhs : ArrayRef ,
752+ #[ case] rhs : ArrayRef ,
753+ #[ case] len : usize ,
754+ ) -> VortexResult < ( ) > {
755+ let original = CosineSimilarity :: try_new_array ( lhs. clone ( ) , rhs. clone ( ) , len) ?. into_array ( ) ;
756+
757+ let plugin = ScalarFnArrayPlugin :: new ( CosineSimilarity ) ;
758+ let metadata = plugin
759+ . serialize ( & original, & SESSION ) ?
760+ . expect ( "CosineSimilarity serialize must produce metadata" ) ;
761+
762+ let children = vec ! [ lhs, rhs] ;
763+ let recovered = plugin. deserialize (
764+ original. dtype ( ) ,
765+ original. len ( ) ,
766+ & metadata,
767+ & [ ] ,
768+ & children,
769+ & SESSION ,
770+ ) ?;
771+
772+ assert_eq ! ( recovered. dtype( ) , original. dtype( ) ) ;
773+ assert_eq ! ( recovered. len( ) , original. len( ) ) ;
774+ assert_eq ! ( recovered. encoding_id( ) , original. encoding_id( ) ) ;
775+ Ok ( ( ) )
776+ }
696777}
0 commit comments