@@ -19,6 +19,7 @@ use vortex::dtype::Nullability;
1919use vortex:: dtype:: extension:: Matcher ;
2020use vortex:: error:: VortexResult ;
2121use vortex:: error:: vortex_ensure;
22+ use vortex:: error:: vortex_ensure_eq;
2223use vortex:: error:: vortex_err;
2324use vortex:: expr:: Expression ;
2425use vortex:: scalar_fn:: Arity ;
@@ -81,33 +82,38 @@ impl ScalarFnVTable for CosineSimilarity {
8182 }
8283
8384 fn return_dtype ( & self , _options : & Self :: Options , arg_dtypes : & [ DType ] ) -> VortexResult < DType > {
84- debug_assert_eq ! ( arg_dtypes. len( ) , 2 ) ;
85+ vortex_ensure_eq ! (
86+ arg_dtypes. len( ) ,
87+ 2 ,
88+ "CosineSimilarity requires exactly 2 arguments, got {}" ,
89+ arg_dtypes. len( )
90+ ) ;
8591
8692 let lhs = & arg_dtypes[ 0 ] ;
8793 let rhs = & arg_dtypes[ 1 ] ;
8894
8995 // Both must have the same dtype (ignoring top-level nullability).
9096 vortex_ensure ! (
9197 lhs. eq_ignore_nullability( rhs) ,
92- "cosine_similarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
98+ "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}"
9399 ) ;
94100
95101 // We don't need to look at rhs anymore since we know lhs and rhs are equal.
96102
97103 // Both inputs must be tensor-like extension types.
98104 let lhs_ext = lhs. as_extension_opt ( ) . ok_or_else ( || {
99- vortex_err ! ( "cosine_similarity lhs must be an extension type, got {lhs}" )
105+ vortex_err ! ( "CosineSimilarity lhs must be an extension type, got {lhs}" )
100106 } ) ?;
101107
102108 vortex_ensure ! (
103109 AnyTensor :: matches( lhs_ext) ,
104- "cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
110+ "CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
105111 ) ;
106112
107113 let ptype = extension_element_ptype ( lhs_ext) ?;
108114 vortex_ensure ! (
109115 ptype. is_float( ) ,
110- "cosine_similarity element dtype must be a float primitive, got {ptype}"
116+ "CosineSimilarity element dtype must be a float primitive, got {ptype}"
111117 ) ;
112118
113119 let nullability = Nullability :: from ( lhs. is_nullable ( ) || rhs. is_nullable ( ) ) ;
@@ -190,79 +196,32 @@ fn cosine_similarity_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
190196
191197#[ cfg( test) ]
192198mod tests {
193- use vortex:: array:: ArrayRef ;
194- use vortex:: array:: IntoArray ;
199+ use rstest:: rstest;
195200 use vortex:: array:: ToCanonical ;
196- use vortex:: array:: arrays:: ConstantArray ;
197- use vortex:: array:: arrays:: ExtensionArray ;
198- use vortex:: array:: arrays:: FixedSizeListArray ;
199201 use vortex:: array:: arrays:: ScalarFnArray ;
200- use vortex:: array:: validity:: Validity ;
201- use vortex:: buffer:: Buffer ;
202- use vortex:: dtype:: DType ;
203- use vortex:: dtype:: Nullability ;
204- use vortex:: dtype:: extension:: ExtDType ;
205202 use vortex:: error:: VortexResult ;
206- use vortex:: extension:: EmptyMetadata ;
207- use vortex:: scalar:: Scalar ;
208203 use vortex:: scalar_fn:: EmptyOptions ;
209204 use vortex:: scalar_fn:: ScalarFn ;
210205
211- use crate :: fixed_shape:: FixedShapeTensor ;
212- use crate :: fixed_shape:: FixedShapeTensorMetadata ;
213206 use crate :: scalar_fns:: cosine_similarity:: CosineSimilarity ;
214- use crate :: vector:: Vector ;
215-
216- /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
217- ///
218- /// The number of rows is inferred from the total element count divided by the product of the
219- /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row.
220- fn tensor_array ( shape : & [ usize ] , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
221- let list_size: u32 = shape. iter ( ) . product :: < usize > ( ) . max ( 1 ) . try_into ( ) . unwrap ( ) ;
222- let row_count = elements. len ( ) / list_size as usize ;
223-
224- let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
225- let fsl = FixedSizeListArray :: new ( elems, list_size, Validity :: NonNullable , row_count) ;
226-
227- let metadata = FixedShapeTensorMetadata :: new ( shape. to_vec ( ) ) ;
228- let ext_dtype =
229- ExtDType :: < FixedShapeTensor > :: try_new ( metadata, fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
230-
231- Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
232- }
207+ use crate :: scalar_fns:: utils:: test_helpers:: assert_close;
208+ use crate :: scalar_fns:: utils:: test_helpers:: constant_tensor_array;
209+ use crate :: scalar_fns:: utils:: test_helpers:: constant_vector_array;
210+ use crate :: scalar_fns:: utils:: test_helpers:: tensor_array;
211+ use crate :: scalar_fns:: utils:: test_helpers:: vector_array;
233212
234213 /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
235- fn eval_cosine_similarity ( lhs : ArrayRef , rhs : ArrayRef , len : usize ) -> VortexResult < Vec < f64 > > {
214+ fn eval_cosine_similarity (
215+ lhs : vortex:: array:: ArrayRef ,
216+ rhs : vortex:: array:: ArrayRef ,
217+ len : usize ,
218+ ) -> VortexResult < Vec < f64 > > {
236219 let scalar_fn = ScalarFn :: new ( CosineSimilarity , EmptyOptions ) . erased ( ) ;
237220 let result = ScalarFnArray :: try_new ( scalar_fn, vec ! [ lhs, rhs] , len) ?;
238221 let prim = result. to_primitive ( ) ;
239222 Ok ( prim. as_slice :: < f64 > ( ) . to_vec ( ) )
240223 }
241224
242- /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
243- /// value, with support for NaN (NaN == NaN is considered equal).
244- #[ track_caller]
245- fn assert_close ( actual : & [ f64 ] , expected : & [ f64 ] ) {
246- assert_eq ! (
247- actual. len( ) ,
248- expected. len( ) ,
249- "length mismatch: got {} elements, expected {}" ,
250- actual. len( ) ,
251- expected. len( )
252- ) ;
253-
254- for ( i, ( a, e) ) in actual. iter ( ) . zip ( expected) . enumerate ( ) {
255- if a. is_nan ( ) && e. is_nan ( ) {
256- continue ;
257- }
258- assert ! (
259- ( a - e) . abs( ) < 1e-10 ,
260- "element {i}: got {a}, expected {e} (diff = {})" ,
261- ( a - e) . abs( )
262- ) ;
263- }
264- }
265-
266225 #[ test]
267226 fn unit_vectors_1d ( ) -> VortexResult < ( ) > {
268227 let lhs = tensor_array (
@@ -280,20 +239,18 @@ mod tests {
280239 ] ,
281240 ) ?;
282241
283- // Row 0: identical → 1.0, row 1: orthogonal → 0.0.
242+ // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0.
284243 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , 0.0 ] ) ;
285244 Ok ( ( ) )
286245 }
287246
288- use rstest:: rstest;
289-
290247 /// Single-row cosine similarity for various vector pairs.
291248 #[ rstest]
292- // Antiparallel → -1.0.
249+ // Antiparallel -> -1.0.
293250 #[ case:: opposite( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , & [ -1.0 , 0.0 , 0.0 ] , & [ -1.0 ] ) ]
294- // dot=24, both magnitudes=5 → 24/25 = 0.96.
251+ // dot=24, both magnitudes=5 -> 24/25 = 0.96.
295252 #[ case:: non_unit( & [ 2 ] , & [ 3.0 , 4.0 ] , & [ 4.0 , 3.0 ] , & [ 0.96 ] ) ]
296- // Zero vector → 0/0 → NaN.
253+ // Zero vector -> 0/0 -> NaN.
297254 #[ case:: zero_norm( & [ 2 ] , & [ 0.0 , 0.0 ] , & [ 1.0 , 0.0 ] , & [ f64 :: NAN ] ) ]
298255 fn single_row (
299256 #[ case] shape : & [ usize ] ,
@@ -332,14 +289,14 @@ mod tests {
332289 let lhs = tensor_array ( & [ ] , & [ 5.0 , 3.0 ] ) ?;
333290 let rhs = tensor_array ( & [ ] , & [ 5.0 , -3.0 ] ) ?;
334291
335- // Same sign → 1.0, opposite sign → -1.0.
292+ // Same sign -> 1.0, opposite sign -> -1.0.
336293 assert_close ( & eval_cosine_similarity ( lhs, rhs, 2 ) ?, & [ 1.0 , -1.0 ] ) ;
337294 Ok ( ( ) )
338295 }
339296
340297 #[ test]
341298 fn many_rows ( ) -> VortexResult < ( ) > {
342- // 5 tensors of shape [4] compared against themselves → all 1.0.
299+ // 5 tensors of shape [4] compared against themselves -> all 1.0.
343300 let lhs = tensor_array (
344301 & [ 4 ] ,
345302 & [
@@ -359,35 +316,8 @@ mod tests {
359316 Ok ( ( ) )
360317 }
361318
362- /// Builds an extension array whose storage is a [`ConstantArray`], representing a single
363- /// query tensor broadcast to `len` rows.
364- fn constant_tensor_array (
365- shape : & [ usize ] ,
366- elements : & [ f64 ] ,
367- len : usize ,
368- ) -> VortexResult < ArrayRef > {
369- let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
370-
371- // Build the FSL storage scalar from individual element scalars.
372- let children: Vec < Scalar > = elements
373- . iter ( )
374- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
375- . collect ( ) ;
376- let storage_scalar =
377- Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
378-
379- // Wrap the FSL scalar in a ConstantArray to avoid materializing `len` copies.
380- let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
381-
382- let metadata = FixedShapeTensorMetadata :: new ( shape. to_vec ( ) ) ;
383- let ext_dtype =
384- ExtDType :: < FixedShapeTensor > :: try_new ( metadata, storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
385-
386- Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
387- }
388-
389319 #[ test]
390- fn constant_query_vector ( ) -> VortexResult < ( ) > {
320+ fn constant_query_tensor ( ) -> VortexResult < ( ) > {
391321 // Compare 4 tensors of shape [3] against a single constant query tensor [1,0,0].
392322 let data = tensor_array (
393323 & [ 3 ] ,
@@ -400,26 +330,13 @@ mod tests {
400330 ) ?;
401331 let query = constant_tensor_array ( & [ 3 ] , & [ 1.0 , 0.0 , 0.0 ] , 4 ) ?;
402332
403- // Only tensor 0 is aligned with the query.
404333 assert_close (
405334 & eval_cosine_similarity ( data, query, 4 ) ?,
406335 & [ 1.0 , 0.0 , 0.0 , 1.0 ] ,
407336 ) ;
408337 Ok ( ( ) )
409338 }
410339
411- /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size.
412- fn vector_array ( dim : u32 , elements : & [ f64 ] ) -> VortexResult < ArrayRef > {
413- let row_count = elements. len ( ) / dim as usize ;
414-
415- let elems: ArrayRef = Buffer :: copy_from ( elements) . into_array ( ) ;
416- let fsl = FixedSizeListArray :: new ( elems, dim, Validity :: NonNullable , row_count) ;
417-
418- let ext_dtype = ExtDType :: < Vector > :: try_new ( EmptyMetadata , fsl. dtype ( ) . clone ( ) ) ?. erased ( ) ;
419-
420- Ok ( ExtensionArray :: new ( ext_dtype, fsl. into_array ( ) ) . into_array ( ) )
421- }
422-
423340 #[ test]
424341 fn vector_unit_vectors ( ) -> VortexResult < ( ) > {
425342 let lhs = vector_array (
@@ -442,43 +359,6 @@ mod tests {
442359 Ok ( ( ) )
443360 }
444361
445- #[ test]
446- fn vector_self_similarity ( ) -> VortexResult < ( ) > {
447- let arr = vector_array (
448- 4 ,
449- & [
450- 1.0 , 2.0 , 3.0 , 4.0 , // vector 0
451- 0.0 , 1.0 , 0.0 , 0.0 , // vector 1
452- 5.0 , 0.0 , 5.0 , 0.0 , // vector 2
453- ] ,
454- ) ?;
455-
456- assert_close (
457- & eval_cosine_similarity ( arr. clone ( ) , arr, 3 ) ?,
458- & [ 1.0 , 1.0 , 1.0 ] ,
459- ) ;
460- Ok ( ( ) )
461- }
462-
463- /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`].
464- fn constant_vector_array ( elements : & [ f64 ] , len : usize ) -> VortexResult < ArrayRef > {
465- let element_dtype = DType :: Primitive ( vortex:: dtype:: PType :: F64 , Nullability :: NonNullable ) ;
466-
467- let children: Vec < Scalar > = elements
468- . iter ( )
469- . map ( |& v| Scalar :: primitive ( v, Nullability :: NonNullable ) )
470- . collect ( ) ;
471- let storage_scalar =
472- Scalar :: fixed_size_list ( element_dtype, children, Nullability :: NonNullable ) ;
473-
474- let storage = ConstantArray :: new ( storage_scalar, len) . into_array ( ) ;
475-
476- let ext_dtype =
477- ExtDType :: < Vector > :: try_new ( EmptyMetadata , storage. dtype ( ) . clone ( ) ) ?. erased ( ) ;
478-
479- Ok ( ExtensionArray :: new ( ext_dtype, storage) . into_array ( ) )
480- }
481-
482362 #[ test]
483363 fn vector_constant_query ( ) -> VortexResult < ( ) > {
484364 let data = vector_array (
0 commit comments