Skip to content

Commit b47bd61

Browse files
committed
add l2 norm scalar fn
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent a2ec5b4 commit b47bd61

5 files changed

Lines changed: 404 additions & 68 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 9 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,15 @@ use std::fmt::Formatter;
99

1010
use num_traits::Float;
1111
use vortex::array::ArrayRef;
12-
use vortex::array::DynArray;
1312
use vortex::array::ExecutionCtx;
1413
use vortex::array::IntoArray;
15-
use vortex::array::arrays::Constant;
16-
use vortex::array::arrays::ConstantArray;
17-
use vortex::array::arrays::Extension;
1814
use vortex::array::arrays::PrimitiveArray;
1915
use vortex::array::match_each_float_ptype;
2016
use vortex::dtype::DType;
2117
use vortex::dtype::NativePType;
2218
use vortex::dtype::Nullability;
2319
use vortex::dtype::extension::Matcher;
2420
use vortex::error::VortexResult;
25-
use vortex::error::vortex_bail;
2621
use vortex::error::vortex_ensure;
2722
use vortex::error::vortex_err;
2823
use vortex::expr::Expression;
@@ -34,9 +29,11 @@ use vortex::scalar_fn::ScalarFnId;
3429
use vortex::scalar_fn::ScalarFnVTable;
3530

3631
use crate::matcher::AnyTensor;
32+
use crate::scalar_fns::utils::extension_element_ptype;
33+
use crate::scalar_fns::utils::extension_list_size;
34+
use crate::scalar_fns::utils::extension_storage;
35+
use crate::scalar_fns::utils::extract_flat_elements;
3736

38-
// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
39-
// encoded in spherical coordinates.
4037
/// Cosine similarity between two columns.
4138
///
4239
/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
@@ -101,33 +98,18 @@ impl ScalarFnVTable for CosineSimilarity {
10198
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
10299
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
103100
})?;
101+
104102
vortex_ensure!(
105103
AnyTensor::matches(lhs_ext),
106104
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
107105
);
108106

109-
// Extract the element dtype from the storage FixedSizeList.
110-
let element_dtype = lhs_ext
111-
.storage_dtype()
112-
.as_fixed_size_list_element_opt()
113-
.ok_or_else(|| {
114-
vortex_err!(
115-
"cosine_similarity storage dtype must be a FixedSizeList, got {}",
116-
lhs_ext.storage_dtype()
117-
)
118-
})?;
119-
120-
// Element dtype must be a non-nullable float primitive.
121-
vortex_ensure!(
122-
element_dtype.is_float(),
123-
"cosine_similarity element dtype must be a float primitive, got {element_dtype}"
124-
);
107+
let ptype = extension_element_ptype(lhs_ext)?;
125108
vortex_ensure!(
126-
!element_dtype.is_nullable(),
127-
"cosine_similarity element dtype must be non-nullable"
109+
ptype.is_float(),
110+
"cosine_similarity element dtype must be a float primitive, got {ptype}"
128111
);
129112

130-
let ptype = element_dtype.as_ptype();
131113
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
132114
Ok(DType::Primitive(ptype, nullability))
133115
}
@@ -149,10 +131,7 @@ impl ScalarFnVTable for CosineSimilarity {
149131
lhs.dtype()
150132
)
151133
})?;
152-
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
153-
vortex_bail!("expected FixedSizeList storage dtype");
154-
};
155-
let list_size = *list_size as usize;
134+
let list_size = extension_list_size(ext)?;
156135

157136
// Extract the storage array from each extension input. We pass the storage (FSL) rather
158137
// than the extension array to avoid canonicalizing the extension wrapper.
@@ -202,38 +181,6 @@ impl ScalarFnVTable for CosineSimilarity {
202181
}
203182
}
204183

205-
/// Extracts the storage array from an extension array without canonicalizing.
206-
fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> {
207-
let ext = array
208-
.as_opt::<Extension>()
209-
.ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?;
210-
Ok(ext.storage_array().clone())
211-
}
212-
213-
/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
214-
///
215-
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
216-
/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
217-
/// where `stride` is `list_size` for a full array and `0` for a constant.
218-
fn extract_flat_elements(
219-
storage: &ArrayRef,
220-
list_size: usize,
221-
) -> VortexResult<(PrimitiveArray, usize)> {
222-
if let Some(constant) = storage.as_opt::<Constant>() {
223-
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a
224-
// huge amount of data.
225-
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
226-
let fsl = single.to_canonical()?.into_fixed_size_list();
227-
let elems = fsl.elements().to_canonical()?.into_primitive();
228-
Ok((elems, 0))
229-
} else {
230-
// Otherwise we have to fully expand all of the data.
231-
let fsl = storage.to_canonical()?.into_fixed_size_list();
232-
let elems = fsl.elements().to_canonical()?.into_primitive();
233-
Ok((elems, list_size))
234-
}
235-
}
236-
237184
// TODO(connor): We should try to use a more performant library instead of doing this ourselves.
238185
/// Computes cosine similarity between two equal-length float slices.
239186
///

0 commit comments

Comments
 (0)