@@ -9,20 +9,15 @@ use std::fmt::Formatter;
99
1010use num_traits:: Float ;
1111use vortex:: array:: ArrayRef ;
12- use vortex:: array:: DynArray ;
1312use vortex:: array:: ExecutionCtx ;
1413use vortex:: array:: IntoArray ;
15- use vortex:: array:: arrays:: Constant ;
16- use vortex:: array:: arrays:: ConstantArray ;
17- use vortex:: array:: arrays:: Extension ;
1814use vortex:: array:: arrays:: PrimitiveArray ;
1915use vortex:: array:: match_each_float_ptype;
2016use vortex:: dtype:: DType ;
2117use vortex:: dtype:: NativePType ;
2218use vortex:: dtype:: Nullability ;
2319use vortex:: dtype:: extension:: Matcher ;
2420use vortex:: error:: VortexResult ;
25- use vortex:: error:: vortex_bail;
2621use vortex:: error:: vortex_ensure;
2722use vortex:: error:: vortex_err;
2823use vortex:: expr:: Expression ;
@@ -34,9 +29,11 @@ use vortex::scalar_fn::ScalarFnId;
3429use vortex:: scalar_fn:: ScalarFnVTable ;
3530
3631use crate :: matcher:: AnyTensor ;
32+ use crate :: scalar_fns:: utils:: extension_element_ptype;
33+ use crate :: scalar_fns:: utils:: extension_list_size;
34+ use crate :: scalar_fns:: utils:: extension_storage;
35+ use crate :: scalar_fns:: utils:: extract_flat_elements;
3736
38- // TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
39- // encoded in spherical coordinates.
4037/// Cosine similarity between two columns.
4138///
4239/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
@@ -101,33 +98,18 @@ impl ScalarFnVTable for CosineSimilarity {
10198 let lhs_ext = lhs. as_extension_opt ( ) . ok_or_else ( || {
10299 vortex_err ! ( "cosine_similarity lhs must be an extension type, got {lhs}" )
103100 } ) ?;
101+
104102 vortex_ensure ! (
105103 AnyTensor :: matches( lhs_ext) ,
106104 "cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107105 ) ;
108106
109- // Extract the element dtype from the storage FixedSizeList.
110- let element_dtype = lhs_ext
111- . storage_dtype ( )
112- . as_fixed_size_list_element_opt ( )
113- . ok_or_else ( || {
114- vortex_err ! (
115- "cosine_similarity storage dtype must be a FixedSizeList, got {}" ,
116- lhs_ext. storage_dtype( )
117- )
118- } ) ?;
119-
120- // Element dtype must be a non-nullable float primitive.
121- vortex_ensure ! (
122- element_dtype. is_float( ) ,
123- "cosine_similarity element dtype must be a float primitive, got {element_dtype}"
124- ) ;
107+ let ptype = extension_element_ptype ( lhs_ext) ?;
125108 vortex_ensure ! (
126- !element_dtype . is_nullable ( ) ,
127- "cosine_similarity element dtype must be non-nullable "
109+ ptype . is_float ( ) ,
110+ "cosine_similarity element dtype must be a float primitive, got {ptype} "
128111 ) ;
129112
130- let ptype = element_dtype. as_ptype ( ) ;
131113 let nullability = Nullability :: from ( lhs. is_nullable ( ) || rhs. is_nullable ( ) ) ;
132114 Ok ( DType :: Primitive ( ptype, nullability) )
133115 }
@@ -149,10 +131,7 @@ impl ScalarFnVTable for CosineSimilarity {
149131 lhs. dtype( )
150132 )
151133 } ) ?;
152- let DType :: FixedSizeList ( _, list_size, _) = ext. storage_dtype ( ) else {
153- vortex_bail ! ( "expected FixedSizeList storage dtype" ) ;
154- } ;
155- let list_size = * list_size as usize ;
134+ let list_size = extension_list_size ( ext) ?;
156135
157136 // Extract the storage array from each extension input. We pass the storage (FSL) rather
158137 // than the extension array to avoid canonicalizing the extension wrapper.
@@ -202,38 +181,6 @@ impl ScalarFnVTable for CosineSimilarity {
202181 }
203182}
204183
205- /// Extracts the storage array from an extension array without canonicalizing.
206- fn extension_storage ( array : & ArrayRef ) -> VortexResult < ArrayRef > {
207- let ext = array
208- . as_opt :: < Extension > ( )
209- . ok_or_else ( || vortex_err ! ( "cosine_similarity input must be an extension array" ) ) ?;
210- Ok ( ext. storage_array ( ) . clone ( ) )
211- }
212-
213- /// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
214- ///
215- /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
216- /// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
217- /// where `stride` is `list_size` for a full array and `0` for a constant.
218- fn extract_flat_elements (
219- storage : & ArrayRef ,
220- list_size : usize ,
221- ) -> VortexResult < ( PrimitiveArray , usize ) > {
222- if let Some ( constant) = storage. as_opt :: < Constant > ( ) {
223- // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a
224- // huge amount of data.
225- let single = ConstantArray :: new ( constant. scalar ( ) . clone ( ) , 1 ) . into_array ( ) ;
226- let fsl = single. to_canonical ( ) ?. into_fixed_size_list ( ) ;
227- let elems = fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
228- Ok ( ( elems, 0 ) )
229- } else {
230- // Otherwise we have to fully expand all of the data.
231- let fsl = storage. to_canonical ( ) ?. into_fixed_size_list ( ) ;
232- let elems = fsl. elements ( ) . to_canonical ( ) ?. into_primitive ( ) ;
233- Ok ( ( elems, list_size) )
234- }
235- }
236-
237184// TODO(connor): We should try to use a more performant library instead of doing this ourselves.
238185/// Computes cosine similarity between two equal-length float slices.
239186///
0 commit comments