Skip to content

Commit 47f97f4

Browse files
authored
Add inner product and cosine similarity optimizations (#7364)
## Summary Adds compute optimizations for the `InnerProduct` and `CosineSimilarity` changes, mostly related to when a child has already been decomposed into normalized and norms via `L2Denorm` scalar fn array. ## Testing Adds some more tests. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent ccc6590 commit 47f97f4

File tree

5 files changed

+527
-71
lines changed

5 files changed

+527
-71
lines changed

vortex-tensor/public-api.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt
412412

413413
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
414414

415-
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
415+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
416416

417417
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
418418

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 194 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ use num_traits::Zero;
99
use vortex_array::ArrayRef;
1010
use vortex_array::ExecutionCtx;
1111
use vortex_array::IntoArray;
12-
use vortex_array::arrays::ExtensionArray;
1312
use vortex_array::arrays::PrimitiveArray;
1413
use vortex_array::arrays::ScalarFnArray;
14+
use vortex_array::arrays::scalar_fn::ExactScalarFn;
15+
use vortex_array::builtins::ArrayBuiltins;
1516
use vortex_array::dtype::DType;
1617
use vortex_array::dtype::Nullability;
1718
use vortex_array::expr::Expression;
@@ -23,13 +24,16 @@ use vortex_array::scalar_fn::ExecutionArgs;
2324
use vortex_array::scalar_fn::ScalarFn;
2425
use vortex_array::scalar_fn::ScalarFnId;
2526
use vortex_array::scalar_fn::ScalarFnVTable;
27+
use vortex_array::validity::Validity;
2628
use vortex_buffer::Buffer;
2729
use vortex_error::VortexResult;
2830
use vortex_error::vortex_ensure;
2931

3032
use crate::scalar_fns::ApproxOptions;
3133
use crate::scalar_fns::inner_product::InnerProduct;
34+
use crate::scalar_fns::l2_denorm::L2Denorm;
3235
use crate::scalar_fns::l2_norm::L2Norm;
36+
use crate::utils::extract_l2_denorm_children;
3337
use crate::utils::validate_tensor_float_input;
3438

3539
/// Cosine similarity between two columns.
@@ -126,35 +130,47 @@ impl ScalarFnVTable for CosineSimilarity {
126130
args: &dyn ExecutionArgs,
127131
ctx: &mut ExecutionCtx,
128132
) -> VortexResult<ArrayRef> {
129-
let lhs = args.get(0)?.execute::<ExtensionArray>(ctx)?;
130-
let rhs = args.get(1)?.execute::<ExtensionArray>(ctx)?;
133+
let mut lhs_ref = args.get(0)?;
134+
let mut rhs_ref = args.get(1)?;
131135
let len = args.row_count();
132136

133-
// Compute combined validity.
134-
let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?;
137+
// Check if any of our children have be already normalized.
138+
{
139+
let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
140+
let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
141+
142+
if lhs_is_denorm && rhs_is_denorm {
143+
return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
144+
} else if lhs_is_denorm || rhs_is_denorm {
145+
if rhs_is_denorm {
146+
(lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
147+
}
148+
return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
149+
}
150+
}
135151

136-
let lhs = lhs.into_array();
137-
let rhs = rhs.into_array();
152+
// Compute combined validity.
153+
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
138154

139155
// Compute inner product and norms as columnar operations, and propagate the options.
140-
let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?;
141-
let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?;
142-
let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?;
156+
let norm_lhs_arr = L2Norm::try_new_array(options, lhs_ref.clone(), len)?;
157+
let norm_rhs_arr = L2Norm::try_new_array(options, rhs_ref.clone(), len)?;
158+
let dot_arr = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?;
143159

144-
// Execute to get PrimitiveArrays.
160+
// Execute to get the inner product and norms of the arrays. We only fully decompress
161+
// because we need to perform special logic (guard against 0) during division.
145162
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
146163
let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?;
147164
let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?;
148165

149-
// Divide element-wise, guarding against zero norms.
166+
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
167+
// TODO(connor): This can be written in a more SIMD-friendly manner.
150168
match_each_float_ptype!(dot.ptype(), |T| {
151169
let dots = dot.as_slice::<T>();
152170
let norms_l = norm_l.as_slice::<T>();
153171
let norms_r = norm_r.as_slice::<T>();
154172
let buffer: Buffer<T> = (0..len)
155173
.map(|i| {
156-
// TODO(connor): Would it be better to make this a binary multiply?
157-
// What happens when this overflows???
158174
let denom = norms_l[i] * norms_r[i];
159175

160176
if denom == T::zero() {
@@ -191,6 +207,74 @@ impl ScalarFnVTable for CosineSimilarity {
191207
}
192208
}
193209

210+
impl CosineSimilarity {
211+
/// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`.
212+
fn execute_both_denorm(
213+
&self,
214+
options: &ApproxOptions,
215+
lhs_ref: &ArrayRef,
216+
rhs_ref: &ArrayRef,
217+
len: usize,
218+
_ctx: &mut ExecutionCtx,
219+
) -> VortexResult<ArrayRef> {
220+
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
221+
222+
let (normalized_l, _) = extract_l2_denorm_children(lhs_ref);
223+
let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
224+
225+
// Dot product of already-normalized children IS the cosine similarity.
226+
let dot =
227+
InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?.into_array();
228+
229+
if !matches!(validity, Validity::NonNullable) {
230+
// Masking always changes the nullability to nullable.
231+
dot.mask(validity.to_array(len))
232+
} else {
233+
Ok(dot)
234+
}
235+
}
236+
237+
/// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.
238+
///
239+
/// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
240+
fn execute_one_denorm(
241+
&self,
242+
options: &ApproxOptions,
243+
denorm_ref: &ArrayRef,
244+
plain_ref: &ArrayRef,
245+
len: usize,
246+
ctx: &mut ExecutionCtx,
247+
) -> VortexResult<ArrayRef> {
248+
let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
249+
250+
let (normalized, _) = extract_l2_denorm_children(denorm_ref);
251+
252+
let dot_arr = InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)?;
253+
let norm_arr = L2Norm::try_new_array(options, plain_ref.clone(), len)?;
254+
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
255+
let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;
256+
257+
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
258+
// TODO(connor): This can be written in a more SIMD-friendly manner.
259+
match_each_float_ptype!(dot.ptype(), |T| {
260+
let dots = dot.as_slice::<T>();
261+
let norms = plain_norm.as_slice::<T>();
262+
let buffer: Buffer<T> = (0..len)
263+
.map(|i| {
264+
if norms[i] == T::zero() {
265+
T::zero()
266+
} else {
267+
dots[i] / norms[i]
268+
}
269+
})
270+
.collect();
271+
272+
// SAFETY: The buffer length equals `len`, which matches the source validity length.
273+
Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
274+
})
275+
}
276+
}
277+
194278
#[cfg(test)]
195279
mod tests {
196280
use std::sync::LazyLock;
@@ -210,6 +294,7 @@ mod tests {
210294

211295
use crate::scalar_fns::ApproxOptions;
212296
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
297+
use crate::scalar_fns::l2_denorm::L2Denorm;
213298
use crate::utils::test_helpers::assert_close;
214299
use crate::utils::test_helpers::constant_tensor_array;
215300
use crate::utils::test_helpers::constant_vector_array;
@@ -403,4 +488,99 @@ mod tests {
403488
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
404489
Ok(())
405490
}
491+
492+
/// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
493+
fn l2_denorm_array(
494+
shape: &[usize],
495+
normalized_elements: &[f64],
496+
norms: &[f64],
497+
) -> VortexResult<ArrayRef> {
498+
let len = norms.len();
499+
let normalized = tensor_array(shape, normalized_elements)?;
500+
let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
501+
let mut ctx = SESSION.create_execution_ctx();
502+
Ok(
503+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?
504+
.into_array(),
505+
)
506+
}
507+
508+
#[test]
509+
fn both_denorm_self_similarity() -> VortexResult<()> {
510+
// [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
511+
// [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
512+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
513+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
514+
515+
// Self-similarity should always be 1.0.
516+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]);
517+
Ok(())
518+
}
519+
520+
#[test]
521+
fn both_denorm_orthogonal() -> VortexResult<()> {
522+
// [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
523+
// [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
524+
let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
525+
let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;
526+
527+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
528+
Ok(())
529+
}
530+
531+
#[test]
532+
fn both_denorm_zero_norm() -> VortexResult<()> {
533+
// Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
534+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
535+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
536+
537+
// Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0.
538+
assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
539+
Ok(())
540+
}
541+
542+
#[test]
543+
fn one_side_denorm_lhs() -> VortexResult<()> {
544+
// LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
545+
// RHS is plain [3.0, 4.0].
546+
// cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
547+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
548+
let rhs = tensor_array(&[2], &[3.0, 4.0])?;
549+
550+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
551+
Ok(())
552+
}
553+
554+
#[test]
555+
fn one_side_denorm_rhs() -> VortexResult<()> {
556+
// LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
557+
// cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
558+
let lhs = tensor_array(&[2], &[1.0, 0.0])?;
559+
let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
560+
561+
assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]);
562+
Ok(())
563+
}
564+
565+
#[test]
566+
fn both_denorm_null_norms() -> VortexResult<()> {
567+
// Row 0: valid, row 1: null (via nullable norms on rhs).
568+
let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
569+
570+
let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
571+
let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
572+
let mut ctx = SESSION.create_execution_ctx();
573+
let rhs =
574+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_r, norms_r, 2, &mut ctx)?
575+
.into_array();
576+
577+
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
578+
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
579+
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
580+
581+
assert!(prim.is_valid(0)?);
582+
assert!(!prim.is_valid(1)?);
583+
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
584+
Ok(())
585+
}
406586
}

0 commit comments

Comments
 (0)