@@ -19,6 +19,7 @@ use vortex::dtype::Nullability;
1919use vortex:: dtype:: extension:: Matcher ;
2020use vortex:: error:: VortexResult ;
2121use vortex:: error:: vortex_ensure;
22+ use vortex:: error:: vortex_ensure_eq;
2223use vortex:: error:: vortex_err;
2324use vortex:: expr:: Expression ;
2425use vortex:: scalar_fn:: Arity ;
@@ -81,33 +82,38 @@ impl ScalarFnVTable for CosineSimilarity {
8182 }
8283
8384 fn return_dtype ( & self , _options : & Self :: Options , arg_dtypes : & [ DType ] ) -> VortexResult < DType > {
84- debug_assert_eq ! ( arg_dtypes. len( ) , 2 ) ;
85+ vortex_ensure_eq ! (
86+ arg_dtypes. len( ) ,
87+ 2 ,
88+ "CosineSimilarity requires exactly 2 arguments, got {}" ,
89+ arg_dtypes. len( )
90+ ) ;
8591
8692 let lhs = & arg_dtypes[ 0 ] ;
8793 let rhs = & arg_dtypes[ 1 ] ;
8894
8995 // Both must have the same dtype (ignoring top-level nullability).
9096 vortex_ensure ! (
9197 lhs. eq_ignore_nullability( rhs) ,
92- "cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
98+ "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
9399 ) ;
94100
95101 // We don't need to look at rhs anymore since we know lhs and rhs are equal.
96102
97103 // Both inputs must be tensor-like extension types.
98104 let lhs_ext = lhs. as_extension_opt ( ) . ok_or_else ( || {
99- vortex_err ! ( "cosine_similarity lhs must be an extension type, got {lhs}" )
105+ vortex_err ! ( "CosineSimilarity lhs must be an extension type, got {lhs}" )
100106 } ) ?;
101107
102108 vortex_ensure ! (
103109 AnyTensor :: matches( lhs_ext) ,
104- "cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
110+ "CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
105111 ) ;
106112
107113 let ptype = extension_element_ptype ( lhs_ext) ?;
108114 vortex_ensure ! (
109115 ptype. is_float( ) ,
110- "cosine_similarity element dtype must be a float primitive, got {ptype}"
116+ "CosineSimilarity element dtype must be a float primitive, got {ptype}"
111117 ) ;
112118
113119 let nullability = Nullability :: from ( lhs. is_nullable ( ) || rhs. is_nullable ( ) ) ;
@@ -191,79 +197,32 @@ fn cosine_similarity_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
191197
192198#[ cfg( test) ]
193199mod tests {
194- use vortex:: array:: ArrayRef ;
195- use vortex:: array:: IntoArray ;
200+ use rstest:: rstest;
196201 use vortex:: array:: ToCanonical ;
197- use vortex:: array:: arrays:: ConstantArray ;
198- use vortex:: array:: arrays:: ExtensionArray ;
199- use vortex:: array:: arrays:: FixedSizeListArray ;
200202 use vortex:: array:: arrays:: ScalarFnArray ;
201- use vortex:: array:: validity:: Validity ;
202- use vortex:: buffer:: Buffer ;
203- use vortex:: dtype:: DType ;
204- use vortex:: dtype:: Nullability ;
205- use vortex:: dtype:: extension:: ExtDType ;
206203 use vortex:: error:: VortexResult ;
207- use vortex:: extension:: EmptyMetadata ;
208- use vortex:: scalar:: Scalar ;
209204 use vortex:: scalar_fn:: EmptyOptions ;
210205 use vortex:: scalar_fn:: ScalarFn ;
211206
212- use crate :: fixed_shape:: FixedShapeTensor ;
213- use crate :: fixed_shape:: FixedShapeTensorMetadata ;
214207 use crate :: scalar_fns:: cosine_similarity:: CosineSimilarity ;
215- use crate :: vector:: Vector ;
216-
217- /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
218- ///
219- /// The number of rows is inferred from the total element count divided by the product of the
220- /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row.
221- fn tensor_array ( shape : & [ usize ] , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
222- let list_size: u32 = shape. iter ( ) . product :: < usize > ( ) . max ( 1 ) . try_into ( ) . unwrap ( ) ;
223- let row_count = elements. len ( ) / list_size as usize ;
224-
225- let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
226- let fsl = FixedSizeListArray :: new ( elems, list_size, Validity :: NonNullable , row_count) ;
227-
228- let metadata = FixedShapeTensorMetadata :: new ( shape. to_vec ( ) ) ;
229- let ext_dtype =
230- ExtDType :: < FixedShapeTensor > :: try_new ( metadata, fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
231-
232- Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
233- }
208+ use crate :: scalar_fns:: utils:: test_helpers:: assert_close;
209+ use crate :: scalar_fns:: utils:: test_helpers:: constant_tensor_array;
210+ use crate :: scalar_fns:: utils:: test_helpers:: constant_vector_array;
211+ use crate :: scalar_fns:: utils:: test_helpers:: tensor_array;
212+ use crate :: scalar_fns:: utils:: test_helpers:: vector_array;
234213
235214 /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
236- fn eval_cosine_similarity ( lhs : ArrayRef , rhs : ArrayRef , len : usize ) -> VortexResult < Vec < f64 > > {
215+ fn eval_cosine_similarity (
216+ lhs : vortex:: array:: ArrayRef ,
217+ rhs : vortex:: array:: ArrayRef ,
218+ len : usize ,
219+ ) -> VortexResult < Vec < f64 > > {
237220 let scalar_fn = ScalarFn :: new ( CosineSimilarity , EmptyOptions ) . erased ( ) ;
238221 let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ lhs, rhs] , len) ?;
239222 let prim = result. to_primitive ( ) ;
240223 Ok ( prim. as_slice :: < f64 > ( ) . to_vec ( ) )
241224 }
242225
243- /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
244- /// value, with support for NaN (NaN == NaN is considered equal).
245- #[ track_caller]
246- fn assert_close ( actual : & [ f64 ] , expected : & [ f64 ] ) {
247- assert_eq ! (
248- actual. len( ) ,
249- expected. len( ) ,
250- "length mismatch: got {} elements, expected {}" ,
251- actual. len( ) ,
252- expected. len( )
253- ) ;
254-
255- for ( i, ( a, e) ) in actual. iter ( ) . zip ( expected) . enumerate ( ) {
256- if a. is_nan ( ) && e. is_nan ( ) {
257- continue ;
258- }
259- assert ! (
260- ( a - e) . abs( ) < 1e-10 ,
261- "element {i}: got {a}, expected {e} (diff = {})" ,
262- ( a - e) . abs( )
263- ) ;
264- }
265- }
266-
267226 #[ test]
268227 fn unit_vectors_1d ( ) -> VortexResult < ( ) > {
269228 let lhs = tensor_array (
@@ -281,20 +240,18 @@ mod tests {
281240 ] ,
282241 ) ?;
283242
284- // Row 0: identical → 1.0, row 1: orthogonal → 0.0.
243+ // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
285244 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
286245 Ok ( ( ) )
287246 }
288247
289- use rstest:: rstest;
290-
291248 /// Single-row cosine similarity for various vector pairs.
292249 #[ rstest]
293- // Antiparallel → -1.0.
250+ // Antiparallel -> -1.0.
294251 #[ case:: opposite( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , & [ -1.0 , 0.0 , 0.0 ] , & [ -1.0 ] ) ]
295- // dot=24, both magnitudes=5 → 24/25 = 0.96.
252+ // dot=24, both magnitudes=5 -> 24/25 = 0.96.
296253 #[ case:: non_unit( & [ 2 ] , & [ 3.0 , 4.0 ] , & [ 4.0 , 3.0 ] , & [ 0.96 ] ) ]
297- // Zero vector → 0/0 → NaN.
254+ // Zero vector -> 0/0 -> NaN.
298255 #[ case:: zero_norm( & [ 2 ] , & [ 0.0 , 0.0 ] , & [ 1.0 , 0.0 ] , & [ f64 :: NAN ] ) ]
299256 fn single_row (
300257 #[ case] shape : & [ usize ] ,
@@ -333,14 +290,14 @@ mod tests {
333290 let lhs = tensor_array ( & [ ] , & [ 5.0 , 3.0 ] ) ?;
334291 let rhs = tensor_array ( & [ ] , & [ 5.0 , -3.0 ] ) ?;
335292
336- // Same sign → 1.0, opposite sign → -1.0.
293+ // Same sign -> 1.0, opposite sign -> -1.0.
337294 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , -1.0 ] ) ;
338295 Ok ( ( ) )
339296 }
340297
341298 #[ test]
342299 fn many_rows ( ) -> VortexResult < ( ) > {
343- // 5 tensors of shape [4] compared against themselves → all 1.0.
300+ // 5 tensors of shape [4] compared against themselves -> all 1.0.
344301 let lhs = tensor_array (
345302 & [ 4 ] ,
346303 & [
@@ -360,35 +317,8 @@ mod tests {
360317 Ok ( ( ) )
361318 }
362319
363- /// Builds an extension array whose storage is a [`ConstantArray`], representing a single
364- /// query tensor broadcast to `len` rows.
365- fn constant_tensor_array (
366- shape : & [ usize ] ,
367- elements : & [ f64 ] ,
368- len : usize ,
369- ) -> VortexResult < ArrayRef > {
370- let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
371-
372- // Build the FSL storage scalar from individual element scalars.
373- let children: Vec < Scalar > = elements
374- . iter ( )
375- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
376- . collect ( ) ;
377- let storage_scalar =
378- Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
379-
380- // Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies.
381- let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
382-
383- let metadata = FixedShapeTensorMetadata :: new ( shape. to_vec ( ) ) ;
384- let ext_dtype =
385- ExtDType :: < FixedShapeTensor > :: try_new ( metadata, storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
386-
387- Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
388- }
389-
390320 #[ test]
391- fn constant_query_vector ( ) -> VortexResult < ( ) > {
321+ fn constant_query_tensor ( ) -> VortexResult < ( ) > {
392322 // Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0].
393323 let data = tensor_array (
394324 & [ 3 ] ,
@@ -401,26 +331,13 @@ mod tests {
401331 ) ?;
402332 let query = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , 4 ) ?;
403333
404- // Only tensor 0 is aligned with the query.
405334 assert_close (
406335 & eval_cosine_similarity ( data, query, 4 ) ?,
407336 & [ 1.0 , 0.0 , 0.0 , 1.0 ] ,
408337 ) ;
409338 Ok ( ( ) )
410339 }
411340
412- /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
413- fn vector_array ( dim : u32 , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
414- let row_count = elements. len ( ) / dim as usize ;
415-
416- let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
417- let fsl = FixedSizeListArray :: new ( elems, dim, Validity :: NonNullable , row_count) ;
418-
419- let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
420-
421- Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
422- }
423-
424341 #[ test]
425342 fn vector_unit_vectors ( ) -> VortexResult < ( ) > {
426343 let lhs = vector_array (
@@ -443,43 +360,6 @@ mod tests {
443360 Ok ( ( ) )
444361 }
445362
446- #[ test]
447- fn vector_self_similarity ( ) -> VortexResult < ( ) > {
448- let arr = vector_array (
449- 4 ,
450- & [
451- 1.0 , 2.0 , 3.0 , 4.0 , // vector 0
452- 0.0 , 1.0 , 0.0 , 0.0 , // vector 1
453- 5.0 , 0.0 , 5.0 , 0.0 , // vector 2
454- ] ,
455- ) ?;
456-
457- assert_close (
458- & eval_cosine_similarity ( arr. clone ( ) , arr, 3 ) ?,
459- & [ 1.0 , 1.0 , 1.0 ] ,
460- ) ;
461- Ok ( ( ) )
462- }
463-
464- /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
465- fn constant_vector_array ( elements : & [ f64 ] , len : usize ) -> VortexResult < ArrayRef > {
466- let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
467-
468- let children: Vec < Scalar > = elements
469- . iter ( )
470- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
471- . collect ( ) ;
472- let storage_scalar =
473- Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
474-
475- let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
476-
477- let ext_dtype =
478- ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
479-
480- Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
481- }
482-
483363 #[ test]
484364 fn vector_constant_query ( ) -> VortexResult < ( ) > {
485365 let data = vector_array (
0 commit comments