Skip to content

Commit 213d665

Browse files
committed
add l2 norm scalar fn
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 16ed7db commit 213d665

5 files changed

Lines changed: 404 additions & 68 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 9 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,15 @@ use std::fmt::Formatter;
99

1010
use num_traits::Float;
1111
use vortex::array::ArrayRef;
12-
use vortex::array::DynArray;
1312
use vortex::array::ExecutionCtx;
1413
use vortex::array::IntoArray;
15-
use vortex::array::arrays::Constant;
16-
use vortex::array::arrays::ConstantArray;
17-
use vortex::array::arrays::Extension;
1814
use vortex::array::arrays::PrimitiveArray;
1915
use vortex::array::match_each_float_ptype;
2016
use vortex::dtype::DType;
2117
use vortex::dtype::NativePType;
2218
use vortex::dtype::Nullability;
2319
use vortex::dtype::extension::Matcher;
2420
use vortex::error::VortexResult;
25-
use vortex::error::vortex_bail;
2621
use vortex::error::vortex_ensure;
2722
use vortex::error::vortex_err;
2823
use vortex::expr::Expression;
@@ -34,9 +29,11 @@ use vortex::scalar_fn::ScalarFnId;
3429
use vortex::scalar_fn::ScalarFnVTable;
3530

3631
use crate::matcher::AnyTensor;
32+
use crate::scalar_fns::utils::extension_element_ptype;
33+
use crate::scalar_fns::utils::extension_list_size;
34+
use crate::scalar_fns::utils::extension_storage;
35+
use crate::scalar_fns::utils::extract_flat_elements;
3736

38-
// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
39-
// encoded in spherical coordinates.
4037
/// Cosine similarity between two columns.
4138
///
4239
/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
@@ -101,33 +98,18 @@ impl ScalarFnVTable for CosineSimilarity {
10198
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
10299
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
103100
})?;
101+
104102
vortex_ensure!(
105103
AnyTensor::matches(lhs_ext),
106104
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107105
);
108106

109-
// Extract the element dtype from the storage FixedSizeList.
110-
let element_dtype = lhs_ext
111-
.storage_dtype()
112-
.as_fixed_size_list_element_opt()
113-
.ok_or_else(|| {
114-
vortex_err!(
115-
"cosine_similarity storage dtype must be a FixedSizeList, got {}",
116-
lhs_ext.storage_dtype()
117-
)
118-
})?;
119-
120-
// Element dtype must be a non-nullable float primitive.
121-
vortex_ensure!(
122-
element_dtype.is_float(),
123-
"cosine_similarity element dtype must be a float primitive, got {element_dtype}"
124-
);
107+
let ptype = extension_element_ptype(lhs_ext)?;
125108
vortex_ensure!(
126-
!element_dtype.is_nullable(),
127-
"cosine_similarity element dtype must be non-nullable"
109+
ptype.is_float(),
110+
"cosine_similarity element dtype must be a float primitive, got {ptype}"
128111
);
129112

130-
let ptype = element_dtype.as_ptype();
131113
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
132114
Ok(DType::Primitive(ptype, nullability))
133115
}
@@ -149,10 +131,7 @@ impl ScalarFnVTable for CosineSimilarity {
149131
lhs.dtype()
150132
)
151133
})?;
152-
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
153-
vortex_bail!("expected FixedSizeList storage dtype");
154-
};
155-
let list_size = *list_size as usize;
134+
let list_size = extension_list_size(ext)?;
156135

157136
// Extract the storage array from each extension input. We pass the storage (FSL) rather
158137
// than the extension array to avoid canonicalizing the extension wrapper.
@@ -203,38 +182,6 @@ impl ScalarFnVTable for CosineSimilarity {
203182
}
204183
}
205184

206-
/// Extracts the storage array from an extension array without canonicalizing.
207-
fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> {
208-
let ext = array
209-
.as_opt::<Extension>()
210-
.ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?;
211-
Ok(ext.storage_array().clone())
212-
}
213-
214-
/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
215-
///
216-
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
217-
/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
218-
/// where `stride` is `list_size` for a full array and `0` for a constant.
219-
fn extract_flat_elements(
220-
storage: &ArrayRef,
221-
list_size: usize,
222-
) -> VortexResult<(PrimitiveArray, usize)> {
223-
if let Some(constant) = storage.as_opt::<Constant>() {
224-
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a
225-
// huge amount of data.
226-
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
227-
let fsl = single.to_canonical()?.into_fixed_size_list();
228-
let elems = fsl.elements().to_canonical()?.into_primitive();
229-
Ok((elems, 0))
230-
} else {
231-
// Otherwise we have to fully expand all of the data.
232-
let fsl = storage.to_canonical()?.into_fixed_size_list();
233-
let elems = fsl.elements().to_canonical()?.into_primitive();
234-
Ok((elems, list_size))
235-
}
236-
}
237-
238185
// TODO(connor): We should try to use a more performant library instead of doing this ourselves.
239186
/// Computes cosine similarity between two equal-length float slices.
240187
///

0 commit comments

Comments
 (0)