Skip to content

Commit d879349

Browse files
authored
Fix tensor expressions and add InnerProduct expression (#7277)
## Summary

Fixes a number of bugs and issues with tensor expressions (execution and validity handling) and adds an `InnerProduct` expression. `CosineSimilarity` now calls into both `InnerProduct` and `L2Norm`.

Fun side change: adds `std::result::Result::and` to the list of disallowed methods.

## API Changes

Adds some nice helpers for creating scalar fn expressions and arrays (`new` and `try_new_array` on `CosineSimilarity`, `InnerProduct`, and `L2Norm`).

## Testing

Adds some more tests, mainly to make sure that validity is now handled correctly.

Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 5d0d627 commit d879349

File tree

9 files changed

+540
-105
lines changed

9 files changed

+540
-105
lines changed

clippy.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ disallowed-types = [
1212

1313
disallowed-methods = [
1414
{ path = "itertools::Itertools::counts", reason = "It uses the default hasher which is slow for primitives. Just inline the loop for better performance.", allow-invalid = true },
15+
{ path = "std::result::Result::and", reason = "This method is a footgun, especially when working with `Result<Validity>`.", allow-invalid = true },
1516
]

vortex-tensor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ workspace = true
1818

1919
[dependencies]
2020
vortex-array = { workspace = true }
21+
vortex-buffer = { workspace = true }
2122
vortex-error = { workspace = true }
2223
vortex-session = { workspace = true }
2324

vortex-tensor/public-api.lock

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ pub mod vortex_tensor::scalar_fns::cosine_similarity
126126

127127
pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
128128

129+
impl vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
130+
131+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn<vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity>
132+
133+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
134+
129135
impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
130136

131137
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
@@ -138,7 +144,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se
138144

139145
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
140146

141-
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
147+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
142148

143149
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
144150

@@ -152,10 +158,52 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dt
152158

153159
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
154160

161+
pub mod vortex_tensor::scalar_fns::inner_product
162+
163+
pub struct vortex_tensor::scalar_fns::inner_product::InnerProduct
164+
165+
impl vortex_tensor::scalar_fns::inner_product::InnerProduct
166+
167+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn<vortex_tensor::scalar_fns::inner_product::InnerProduct>
168+
169+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
170+
171+
impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProduct
172+
173+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> vortex_tensor::scalar_fns::inner_product::InnerProduct
174+
175+
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct
176+
177+
pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_tensor::scalar_fns::ApproxOptions
178+
179+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
180+
181+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
182+
183+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
184+
185+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
186+
187+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::id(&self) -> vortex_array::scalar_fn::ScalarFnId
188+
189+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::is_fallible(&self, _options: &Self::Options) -> bool
190+
191+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::is_null_sensitive(&self, _options: &Self::Options) -> bool
192+
193+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>
194+
195+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
196+
155197
pub mod vortex_tensor::scalar_fns::l2_norm
156198

157199
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
158200

201+
impl vortex_tensor::scalar_fns::l2_norm::L2Norm
202+
203+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn<vortex_tensor::scalar_fns::l2_norm::L2Norm>
204+
205+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, child: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
206+
159207
impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm
160208

161209
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm

vortex-tensor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use vortex_session::VortexSession;
1111

1212
use crate::fixed_shape::FixedShapeTensor;
1313
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
14+
use crate::scalar_fns::inner_product::InnerProduct;
1415
use crate::scalar_fns::l2_norm::L2Norm;
1516
use crate::vector::Vector;
1617

@@ -29,5 +30,6 @@ pub fn initialize(session: &VortexSession) {
2930
session.dtypes().register(Vector);
3031
session.dtypes().register(FixedShapeTensor);
3132
session.scalar_fns().register(CosineSimilarity);
33+
session.scalar_fns().register(InnerProduct);
3234
session.scalar_fns().register(L2Norm);
3335
}

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 93 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
//! Cosine similarity expression for tensor-like extension arrays
5-
//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
6-
//! [`Vector`](crate::vector::Vector)).
4+
//! Cosine similarity expression for tensor-like types.
75
86
use std::fmt::Formatter;
97

10-
use num_traits::Float;
8+
use num_traits::Zero;
119
use vortex_array::ArrayRef;
1210
use vortex_array::ExecutionCtx;
1311
use vortex_array::IntoArray;
12+
use vortex_array::arrays::ExtensionArray;
1413
use vortex_array::arrays::PrimitiveArray;
14+
use vortex_array::arrays::ScalarFnArray;
1515
use vortex_array::dtype::DType;
16-
use vortex_array::dtype::NativePType;
1716
use vortex_array::dtype::Nullability;
1817
use vortex_array::expr::Expression;
1918
use vortex_array::expr::and;
2019
use vortex_array::match_each_float_ptype;
2120
use vortex_array::scalar_fn::Arity;
2221
use vortex_array::scalar_fn::ChildName;
2322
use vortex_array::scalar_fn::ExecutionArgs;
23+
use vortex_array::scalar_fn::ScalarFn;
2424
use vortex_array::scalar_fn::ScalarFnId;
2525
use vortex_array::scalar_fn::ScalarFnVTable;
26+
use vortex_buffer::Buffer;
2627
use vortex_error::VortexResult;
2728
use vortex_error::vortex_ensure;
2829
use vortex_error::vortex_err;
2930

3031
use crate::matcher::AnyTensor;
3132
use crate::scalar_fns::ApproxOptions;
33+
use crate::scalar_fns::inner_product::InnerProduct;
34+
use crate::scalar_fns::l2_norm::L2Norm;
3235
use crate::utils::extension_element_ptype;
33-
use crate::utils::extension_list_size;
34-
use crate::utils::extension_storage;
35-
use crate::utils::extract_flat_elements;
3636

3737
/// Cosine similarity between two columns.
3838
///
@@ -48,6 +48,30 @@ use crate::utils::extract_flat_elements;
4848
#[derive(Clone)]
4949
pub struct CosineSimilarity;
5050

51+
impl CosineSimilarity {
52+
/// Creates a new [`ScalarFn`] wrapping the cosine similarity operation with the given
53+
/// [`ApproxOptions`] controlling approximation behavior.
54+
pub fn new(options: &ApproxOptions) -> ScalarFn<CosineSimilarity> {
55+
ScalarFn::new(CosineSimilarity, options.clone())
56+
}
57+
58+
/// Constructs a [`ScalarFnArray`] that lazily computes the cosine similarity between `lhs` and
59+
/// `rhs`.
60+
///
61+
/// # Errors
62+
///
63+
/// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype
64+
/// mismatches).
65+
pub fn try_new_array(
66+
options: &ApproxOptions,
67+
lhs: ArrayRef,
68+
rhs: ArrayRef,
69+
len: usize,
70+
) -> VortexResult<ScalarFnArray> {
71+
ScalarFnArray::try_new(CosineSimilarity::new(options).erased(), vec![lhs, rhs], len)
72+
}
73+
}
74+
5175
impl ScalarFnVTable for CosineSimilarity {
5276
type Options = ApproxOptions;
5377

@@ -114,37 +138,49 @@ impl ScalarFnVTable for CosineSimilarity {
114138

115139
fn execute(
116140
&self,
117-
_options: &Self::Options,
141+
options: &Self::Options,
118142
args: &dyn ExecutionArgs,
119143
ctx: &mut ExecutionCtx,
120144
) -> VortexResult<ArrayRef> {
121-
let lhs = args.get(0)?;
122-
let rhs = args.get(1)?;
123-
let row_count = args.row_count();
124-
125-
// Get list size from the dtype. Both sides should have the same dtype.
126-
let ext = lhs.dtype().as_extension_opt().ok_or_else(|| {
127-
vortex_err!(
128-
"cosine_similarity input must be an extension type, got {}",
129-
lhs.dtype()
130-
)
131-
})?;
132-
let list_size = extension_list_size(ext)? as usize;
133-
134-
// Extract the storage array from each extension input. We pass the storage (FSL) rather
135-
// than the extension array to avoid canonicalizing the extension wrapper.
136-
let lhs_storage = extension_storage(&lhs)?;
137-
let rhs_storage = extension_storage(&rhs)?;
138-
139-
let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?;
140-
let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?;
141-
142-
match_each_float_ptype!(lhs_flat.ptype(), |T| {
143-
let result: PrimitiveArray = (0..row_count)
144-
.map(|i| cosine_similarity_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i)))
145+
let lhs = args.get(0)?.execute::<ExtensionArray>(ctx)?.into_array();
146+
let rhs = args.get(1)?.execute::<ExtensionArray>(ctx)?.into_array();
147+
148+
let len = args.row_count();
149+
150+
// Compute combined validity.
151+
let validity = lhs.validity()?.and(rhs.validity()?)?;
152+
153+
// Compute inner product and norms as columnar operations, and propagate the options.
154+
let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?;
155+
let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?;
156+
let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?;
157+
158+
// Execute to get PrimitiveArrays.
159+
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
160+
let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?;
161+
let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?;
162+
163+
// Divide element-wise, guarding against zero norms.
164+
match_each_float_ptype!(dot.ptype(), |T| {
165+
let dots = dot.as_slice::<T>();
166+
let norms_l = norm_l.as_slice::<T>();
167+
let norms_r = norm_r.as_slice::<T>();
168+
let buffer: Buffer<T> = (0..len)
169+
.map(|i| {
170+
// TODO(connor): Would it be better to make this a binary multiply?
171+
// What happens when this overflows???
172+
let denom = norms_l[i] * norms_r[i];
173+
174+
if denom == T::zero() {
175+
T::zero()
176+
} else {
177+
dots[i] / denom
178+
}
179+
})
145180
.collect();
146181

147-
Ok(result.into_array())
182+
// SAFETY: The buffer length equals `len`, which matches the source validity length.
183+
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
148184
})
149185
}
150186

@@ -169,30 +205,16 @@ impl ScalarFnVTable for CosineSimilarity {
169205
}
170206
}
171207

172-
// TODO(connor): We should try to use a more performant library instead of doing this ourselves.
173-
/// Computes cosine similarity between two equal-length float slices.
174-
///
175-
/// Returns `dot(a, b) / (||a|| * ||b||)`. When either vector has zero norm, this naturally
176-
/// produces `NaN` via `0.0 / 0.0`, matching standard floating-point semantics.
177-
fn cosine_similarity_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
178-
let mut dot = T::zero();
179-
let mut norm_a = T::zero();
180-
let mut norm_b = T::zero();
181-
for i in 0..a.len() {
182-
dot = dot + a[i] * b[i];
183-
norm_a = norm_a + a[i] * a[i];
184-
norm_b = norm_b + b[i] * b[i];
185-
}
186-
dot / (norm_a.sqrt() * norm_b.sqrt())
187-
}
188-
189208
#[cfg(test)]
190209
mod tests {
191210
use rstest::rstest;
192211
use vortex_array::ArrayRef;
212+
use vortex_array::IntoArray;
193213
use vortex_array::ToCanonical;
214+
use vortex_array::arrays::MaskedArray;
194215
use vortex_array::arrays::ScalarFnArray;
195216
use vortex_array::scalar_fn::ScalarFn;
217+
use vortex_array::validity::Validity;
196218
use vortex_error::VortexResult;
197219

198220
use crate::scalar_fns::ApproxOptions;
@@ -239,8 +261,8 @@ mod tests {
239261
#[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])]
240262
// dot=24, both magnitudes=5 -> 24/25 = 0.96.
241263
#[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])]
242-
// Zero vector -> 0/0 -> NaN.
243-
#[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])]
264+
// Zero vector -> guarded to 0.0.
265+
#[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[0.0])]
244266
fn single_row(
245267
#[case] shape: &[usize],
246268
#[case] lhs_elems: &[f64],
@@ -367,4 +389,22 @@ mod tests {
367389
);
368390
Ok(())
369391
}
392+
393+
#[test]
394+
fn null_input_row() -> VortexResult<()> {
395+
// 2 rows of dim-2 vectors. Row 1 of rhs is masked as null.
396+
let lhs = tensor_array(&[2], &[3.0, 4.0, 1.0, 0.0])?;
397+
let rhs = tensor_array(&[2], &[3.0, 4.0, 0.0, 1.0])?;
398+
let rhs = MaskedArray::try_new(rhs, Validity::from_iter([true, false]))?.into_array();
399+
400+
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
401+
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
402+
let prim = result.as_array().to_primitive();
403+
404+
// Row 0: self-similarity = 1.0, row 1: null.
405+
assert!(prim.is_valid(0)?);
406+
assert!(!prim.is_valid(1)?);
407+
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
408+
Ok(())
409+
}
370410
}

0 commit comments

Comments
 (0)