Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

Expand Down
208 changes: 194 additions & 14 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use num_traits::Zero;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::ExactScalarFn;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::expr::Expression;
Expand All @@ -23,13 +24,16 @@ use vortex_array::scalar_fn::ExecutionArgs;
use vortex_array::scalar_fn::ScalarFn;
use vortex_array::scalar_fn::ScalarFnId;
use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;

use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::extract_l2_denorm_children;
use crate::utils::validate_tensor_float_input;

/// Cosine similarity between two columns.
Expand Down Expand Up @@ -126,35 +130,47 @@ impl ScalarFnVTable for CosineSimilarity {
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let lhs = args.get(0)?.execute::<ExtensionArray>(ctx)?;
let rhs = args.get(1)?.execute::<ExtensionArray>(ctx)?;
let mut lhs_ref = args.get(0)?;
let mut rhs_ref = args.get(1)?;
let len = args.row_count();

// Compute combined validity.
let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?;
// Check whether any of our children have already been normalized.
{
let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();

if lhs_is_denorm && rhs_is_denorm {
return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
} else if lhs_is_denorm || rhs_is_denorm {
if rhs_is_denorm {
(lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
}
return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
}
}

let lhs = lhs.into_array();
let rhs = rhs.into_array();
// Compute combined validity.
let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;

// Compute inner product and norms as columnar operations, and propagate the options.
let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?;
let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?;
let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?;
let norm_lhs_arr = L2Norm::try_new_array(options, lhs_ref.clone(), len)?;
let norm_rhs_arr = L2Norm::try_new_array(options, rhs_ref.clone(), len)?;
let dot_arr = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?;

// Execute to get PrimitiveArrays.
// Execute to get the inner product and norms of the arrays. We only fully decompress
// because we need to perform special logic (guard against 0) during division.
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?;
let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?;

// Divide element-wise, guarding against zero norms.
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
// TODO(connor): This can be written in a more SIMD-friendly manner.
match_each_float_ptype!(dot.ptype(), |T| {
let dots = dot.as_slice::<T>();
let norms_l = norm_l.as_slice::<T>();
let norms_r = norm_r.as_slice::<T>();
let buffer: Buffer<T> = (0..len)
.map(|i| {
// TODO(connor): Consider expressing this as a binary multiply kernel, and decide
// how to handle the product of norms overflowing to infinity for large floats.
let denom = norms_l[i] * norms_r[i];

if denom == T::zero() {
Expand Down Expand Up @@ -191,6 +207,74 @@ impl ScalarFnVTable for CosineSimilarity {
}
}

impl CosineSimilarity {
    /// Fast path when both inputs are `L2Denorm` arrays.
    ///
    /// Both children are already unit-length, so the norms in the cosine similarity
    /// denominator cancel out and the result is simply the inner product of the
    /// normalized children: `cosine_similarity = dot(n_l, n_r)`.
    fn execute_both_denorm(
        &self,
        options: &ApproxOptions,
        lhs_ref: &ArrayRef,
        rhs_ref: &ArrayRef,
        len: usize,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        // The result is null wherever either input is null.
        let combined_validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;

        let (unit_lhs, _) = extract_l2_denorm_children(lhs_ref);
        let (unit_rhs, _) = extract_l2_denorm_children(rhs_ref);

        // The dot product of the pre-normalized children IS the cosine similarity.
        let similarity =
            InnerProduct::try_new_array(options, unit_lhs, unit_rhs, len)?.into_array();

        match combined_validity {
            Validity::NonNullable => Ok(similarity),
            // Masking always changes the nullability to nullable, so only apply it
            // when the combined validity is actually nullable.
            validity => similarity.mask(validity.to_array(len)),
        }
    }

    /// Fast path when exactly one input is an `L2Denorm` array.
    ///
    /// With `n` the pre-normalized child of the denorm side and `b` the plain side,
    /// the denorm side's norm cancels: `cosine_similarity = dot(n, b) / ||b||`.
    ///
    /// The caller must pass the denorm array as `denorm_ref` and the plain array as
    /// `plain_ref`.
    fn execute_one_denorm(
        &self,
        options: &ApproxOptions,
        denorm_ref: &ArrayRef,
        plain_ref: &ArrayRef,
        len: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        // The result is null wherever either input is null.
        let combined_validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;

        let (unit_child, _) = extract_l2_denorm_children(denorm_ref);

        // Build the dot product and the plain side's norm as columnar operations,
        // then fully decompress both so we can guard the division against zero norms.
        let dot_arr = InnerProduct::try_new_array(options, unit_child, plain_ref.clone(), len)?;
        let norm_arr = L2Norm::try_new_array(options, plain_ref.clone(), len)?;
        let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
        let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;

        // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
        // TODO(connor): This can be written in a more SIMD-friendly manner.
        match_each_float_ptype!(dot.ptype(), |T| {
            let dot_vals = dot.as_slice::<T>();
            let norm_vals = plain_norm.as_slice::<T>();
            let quotients: Buffer<T> = (0..len)
                .map(|row| {
                    let norm = norm_vals[row];
                    // A zero norm would divide by zero; define the similarity as 0.
                    if norm == T::zero() {
                        T::zero()
                    } else {
                        dot_vals[row] / norm
                    }
                })
                .collect();

            // SAFETY: The buffer length equals `len`, which matches the source validity length.
            Ok(unsafe { PrimitiveArray::new_unchecked(quotients, combined_validity) }.into_array())
        })
    }
}

#[cfg(test)]
mod tests {
use std::sync::LazyLock;
Expand All @@ -210,6 +294,7 @@ mod tests {

use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::constant_tensor_array;
use crate::utils::test_helpers::constant_vector_array;
Expand Down Expand Up @@ -403,4 +488,99 @@ mod tests {
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
Ok(())
}

/// Builds an `L2Denorm` scalar-function array from pre-normalized elements and
/// their original norms, one norm per row.
fn l2_denorm_array(
    shape: &[usize],
    normalized_elements: &[f64],
    norms: &[f64],
) -> VortexResult<ArrayRef> {
    // One norm per row determines the row count.
    let row_count = norms.len();
    let normalized_arr = tensor_array(shape, normalized_elements)?;
    let norms_arr = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
    let mut ctx = SESSION.create_execution_ctx();
    let denorm = L2Denorm::try_new_array(
        &ApproxOptions::Exact,
        normalized_arr,
        norms_arr,
        row_count,
        &mut ctx,
    )?;
    Ok(denorm.into_array())
}

#[test]
fn both_denorm_self_similarity() -> VortexResult<()> {
    // Row 0: [3.0, 4.0] has norm 5.0 and normalizes to [0.6, 0.8].
    // Row 1: [1.0, 0.0] has norm 1.0 and normalizes to [1.0, 0.0].
    let elements = [0.6, 0.8, 1.0, 0.0];
    let norms = [5.0, 1.0];
    let lhs = l2_denorm_array(&[2], &elements, &norms)?;
    let rhs = l2_denorm_array(&[2], &elements, &norms)?;

    // The cosine similarity of a vector with itself is always 1.0.
    assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]);
    Ok(())
}

#[test]
fn both_denorm_orthogonal() -> VortexResult<()> {
    // [3.0, 0.0] has norm 3.0 and normalizes to [1.0, 0.0].
    // [0.0, 4.0] has norm 4.0 and normalizes to [0.0, 1.0].
    let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
    let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;

    // Perpendicular vectors have zero cosine similarity.
    assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
    Ok(())
}

#[test]
fn both_denorm_zero_norm() -> VortexResult<()> {
    // Row 1 on the left is the zero vector: normalized [0.0, 0.0] with norm 0.0.
    let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
    let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

    // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0; row 1: dot([0, 0], [1, 0]) = 0.0.
    assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
    Ok(())
}

#[test]
fn one_side_denorm_lhs() -> VortexResult<()> {
    // LHS is L2Denorm([0.6, 0.8], 5.0), i.e. the vector [3.0, 4.0].
    // RHS is the same vector stored as a plain tensor, so the similarity is 1.0.
    let denorm_lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
    let plain_rhs = tensor_array(&[2], &[3.0, 4.0])?;

    assert_close(&eval_cosine_similarity(denorm_lhs, plain_rhs, 1)?, &[1.0]);
    Ok(())
}

#[test]
fn one_side_denorm_rhs() -> VortexResult<()> {
    // LHS is the plain vector [1.0, 0.0]; RHS is L2Denorm([0.6, 0.8], 5.0),
    // i.e. [3.0, 4.0]. cosine_similarity = 3.0 / (1.0 * 5.0) = 0.6.
    let plain_lhs = tensor_array(&[2], &[1.0, 0.0])?;
    let denorm_rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;

    assert_close(&eval_cosine_similarity(plain_lhs, denorm_rhs, 1)?, &[0.6]);
    Ok(())
}

#[test]
fn both_denorm_null_norms() -> VortexResult<()> {
    // Row 0 is valid on both sides; row 1 is nulled via a nullable norm on the RHS.
    let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;

    let mut ctx = SESSION.create_execution_ctx();
    let rhs_normalized = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
    let rhs_norms = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
    let rhs =
        L2Denorm::try_new_array(&ApproxOptions::Exact, rhs_normalized, rhs_norms, 2, &mut ctx)?
            .into_array();

    let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
    let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
    let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;

    // Row 0 is a valid self-similarity of 1.0; row 1 must propagate the null.
    assert!(prim.is_valid(0)?);
    assert!(!prim.is_valid(1)?);
    assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
    Ok(())
}
}
Loading
Loading