Skip to content

Commit 7b599d9

Browse files
committed
add inner product and cosine similarity optimizations
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent c660607 commit 7b599d9

4 files changed

Lines changed: 399 additions & 24 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt
412412

413413
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
414414

415-
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
415+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
416416

417417
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
418418

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 192 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ use num_traits::Zero;
99
use vortex_array::ArrayRef;
1010
use vortex_array::ExecutionCtx;
1111
use vortex_array::IntoArray;
12-
use vortex_array::arrays::ExtensionArray;
1312
use vortex_array::arrays::PrimitiveArray;
1413
use vortex_array::arrays::ScalarFnArray;
14+
use vortex_array::arrays::scalar_fn::ExactScalarFn;
15+
use vortex_array::builtins::ArrayBuiltins;
1516
use vortex_array::dtype::DType;
1617
use vortex_array::dtype::Nullability;
1718
use vortex_array::expr::Expression;
@@ -29,7 +30,9 @@ use vortex_error::vortex_ensure;
2930

3031
use crate::scalar_fns::ApproxOptions;
3132
use crate::scalar_fns::inner_product::InnerProduct;
33+
use crate::scalar_fns::l2_denorm::L2Denorm;
3234
use crate::scalar_fns::l2_norm::L2Norm;
35+
use crate::utils::extract_l2_denorm_children;
3336
use crate::utils::validate_tensor_float_input;
3437

3538
/// Cosine similarity between two columns.
@@ -126,35 +129,47 @@ impl ScalarFnVTable for CosineSimilarity {
126129
args: &dyn ExecutionArgs,
127130
ctx: &mut ExecutionCtx,
128131
) -> VortexResult<ArrayRef> {
129-
let lhs = args.get(0)?.execute::<ExtensionArray>(ctx)?;
130-
let rhs = args.get(1)?.execute::<ExtensionArray>(ctx)?;
132+
let mut lhs_ref = args.get(0)?;
133+
let mut rhs_ref = args.get(1)?;
131134
let len = args.row_count();
132135

133-
// Compute combined validity.
134-
let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?;
136+
// Check if any of our children have be already normalized.
137+
{
138+
let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
139+
let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
140+
141+
if lhs_is_denorm && rhs_is_denorm {
142+
return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
143+
} else if lhs_is_denorm || rhs_is_denorm {
144+
if rhs_is_denorm {
145+
(lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
146+
}
147+
return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
148+
}
149+
}
135150

136-
let lhs = lhs.into_array();
137-
let rhs = rhs.into_array();
151+
// Compute combined validity.
152+
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
138153

139154
// Compute inner product and norms as columnar operations, and propagate the options.
140-
let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?;
141-
let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?;
142-
let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?;
155+
let norm_lhs_arr = L2Norm::try_new_array(options, lhs_ref.clone(), len)?;
156+
let norm_rhs_arr = L2Norm::try_new_array(options, rhs_ref.clone(), len)?;
157+
let dot_arr = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?;
143158

144-
// Execute to get PrimitiveArrays.
159+
// Execute to get the inner product and norms of the arrays. We only fully decompress
160+
// because we need to perform special logic (guard against 0) during division.
145161
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
146162
let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?;
147163
let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?;
148164

149-
// Divide element-wise, guarding against zero norms.
165+
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
166+
// TODO(connor): This can be written in a more SIMD-friendly manner.
150167
match_each_float_ptype!(dot.ptype(), |T| {
151168
let dots = dot.as_slice::<T>();
152169
let norms_l = norm_l.as_slice::<T>();
153170
let norms_r = norm_r.as_slice::<T>();
154171
let buffer: Buffer<T> = (0..len)
155172
.map(|i| {
156-
// TODO(connor): Would it be better to make this a binary multiply?
157-
// What happens when this overflows???
158173
let denom = norms_l[i] * norms_r[i];
159174

160175
if denom == T::zero() {
@@ -191,6 +206,68 @@ impl ScalarFnVTable for CosineSimilarity {
191206
}
192207
}
193208

209+
impl CosineSimilarity {
210+
/// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`.
211+
fn execute_both_denorm(
212+
&self,
213+
options: &ApproxOptions,
214+
lhs_ref: &ArrayRef,
215+
rhs_ref: &ArrayRef,
216+
len: usize,
217+
_ctx: &mut ExecutionCtx,
218+
) -> VortexResult<ArrayRef> {
219+
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
220+
221+
let (normalized_l, _) = extract_l2_denorm_children(lhs_ref);
222+
let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
223+
224+
// Dot product of already-normalized children IS the cosine similarity.
225+
InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?
226+
.into_array()
227+
.mask(validity.to_array(len))
228+
}
229+
230+
/// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.
231+
///
232+
/// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
233+
fn execute_one_denorm(
234+
&self,
235+
options: &ApproxOptions,
236+
denorm_ref: &ArrayRef,
237+
plain_ref: &ArrayRef,
238+
len: usize,
239+
ctx: &mut ExecutionCtx,
240+
) -> VortexResult<ArrayRef> {
241+
let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
242+
243+
let (normalized, _) = extract_l2_denorm_children(denorm_ref);
244+
245+
let dot_arr = InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)?;
246+
let norm_arr = L2Norm::try_new_array(options, plain_ref.clone(), len)?;
247+
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
248+
let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;
249+
250+
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
251+
// TODO(connor): This can be written in a more SIMD-friendly manner.
252+
match_each_float_ptype!(dot.ptype(), |T| {
253+
let dots = dot.as_slice::<T>();
254+
let norms = plain_norm.as_slice::<T>();
255+
let buffer: Buffer<T> = (0..len)
256+
.map(|i| {
257+
if norms[i] == T::zero() {
258+
T::zero()
259+
} else {
260+
dots[i] / norms[i]
261+
}
262+
})
263+
.collect();
264+
265+
// SAFETY: The buffer length equals `len`, which matches the source validity length.
266+
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
267+
})
268+
}
269+
}
270+
194271
#[cfg(test)]
195272
mod tests {
196273
use std::sync::LazyLock;
@@ -403,4 +480,105 @@ mod tests {
403480
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
404481
Ok(())
405482
}
483+
484+
/// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
485+
fn l2_denorm_array(
486+
shape: &[usize],
487+
normalized_elements: &[f64],
488+
norms: &[f64],
489+
) -> VortexResult<ArrayRef> {
490+
let len = norms.len();
491+
let normalized = tensor_array(shape, normalized_elements)?;
492+
let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
493+
Ok(crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
494+
&ApproxOptions::Exact,
495+
normalized,
496+
norms,
497+
len,
498+
)?
499+
.into_array())
500+
}
501+
502+
#[test]
503+
fn both_denorm_self_similarity() -> VortexResult<()> {
504+
// [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
505+
// [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
506+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
507+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
508+
509+
// Self-similarity should always be 1.0.
510+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]);
511+
Ok(())
512+
}
513+
514+
#[test]
515+
fn both_denorm_orthogonal() -> VortexResult<()> {
516+
// [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
517+
// [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
518+
let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
519+
let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;
520+
521+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
522+
Ok(())
523+
}
524+
525+
#[test]
526+
fn both_denorm_zero_norm() -> VortexResult<()> {
527+
// Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
528+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
529+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
530+
531+
// Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0.
532+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
533+
Ok(())
534+
}
535+
536+
#[test]
537+
fn one_side_denorm_lhs() -> VortexResult<()> {
538+
// LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
539+
// RHS is plain [3.0, 4.0].
540+
// cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
541+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
542+
let rhs = tensor_array(&[2], &[3.0, 4.0])?;
543+
544+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
545+
Ok(())
546+
}
547+
548+
#[test]
549+
fn one_side_denorm_rhs() -> VortexResult<()> {
550+
// LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
551+
// cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
552+
let lhs = tensor_array(&[2], &[1.0, 0.0])?;
553+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
554+
555+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]);
556+
Ok(())
557+
}
558+
559+
#[test]
560+
fn both_denorm_null_norms() -> VortexResult<()> {
561+
// Row 0: valid, row 1: null (via nullable norms on rhs).
562+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
563+
564+
let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
565+
let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
566+
let rhs = crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
567+
&ApproxOptions::Exact,
568+
normalized_r,
569+
norms_r,
570+
2,
571+
)?
572+
.into_array();
573+
574+
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
575+
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
576+
let mut ctx = SESSION.create_execution_ctx();
577+
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
578+
579+
assert!(prim.is_valid(0)?);
580+
assert!(!prim.is_valid(1)?);
581+
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
582+
Ok(())
583+
}
406584
}

0 commit comments

Comments
 (0)