Skip to content

Commit fbfa072

Browse files
authored
Optimize math expressions with Constant children (#7394)
## Summary Tracking issue: #7297 Adds some constant array optimizations to the tensor crate expressions `L2Norm`, `L2Denorm`, and `CosineSimilarity`. The remaining expressions `InnerProduct` and `SorfTransform` are a bit more complicated and deserve their own PR. ## Testing Adds more tests for these optimizations. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d4e7dca commit fbfa072

4 files changed

Lines changed: 407 additions & 17 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use vortex_error::vortex_ensure;
3232

3333
use crate::scalar_fns::inner_product::InnerProduct;
3434
use crate::scalar_fns::l2_denorm::L2Denorm;
35+
use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm;
3536
use crate::scalar_fns::l2_norm::L2Norm;
3637
use crate::utils::extract_l2_denorm_children;
3738
use crate::utils::validate_tensor_float_input;
@@ -133,6 +134,20 @@ impl ScalarFnVTable for CosineSimilarity {
133134
let mut rhs_ref = args.get(1)?;
134135
let len = args.row_count();
135136

137+
// If either side is a constant tensor-like extension array, eagerly normalize the single
138+
// stored row and re-wrap it as an `L2Denorm` whose children are both [`ConstantArray`]s.
139+
// The L2Denorm fast path below then picks it up.
140+
if let Some(lhs_constant) =
141+
try_build_constant_l2_denorm(&lhs_ref, len, ctx)?.map(|sfn| sfn.into_array())
142+
{
143+
lhs_ref = lhs_constant;
144+
}
145+
if let Some(rhs_constant) =
146+
try_build_constant_l2_denorm(&rhs_ref, len, ctx)?.map(|sfn| sfn.into_array())
147+
{
148+
rhs_ref = rhs_constant;
149+
}
150+
136151
// Check if any of our children have be already normalized.
137152
{
138153
let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
@@ -249,8 +264,9 @@ impl CosineSimilarity {
249264
let (normalized, _) = extract_l2_denorm_children(denorm_ref);
250265

251266
let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?;
252-
let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?;
253267
let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
268+
269+
let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?;
254270
let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;
255271

256272
// TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
@@ -575,4 +591,106 @@ mod tests {
575591
assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
576592
Ok(())
577593
}
594+
595+
#[test]
596+
fn constant_lhs_matches_plain_tensor() -> VortexResult<()> {
597+
// The constant query `[1, 2, 2]` has norm 3, so its normalized form is `[1/3, 2/3, 2/3]`.
598+
// Expected cosine similarity against each row is `dot([1, 2, 2], row) / (3 * ||row||)`.
599+
let lhs = constant_tensor_array(&[3], &[1.0, 2.0, 2.0], 4)?;
600+
let rhs = tensor_array(
601+
&[3],
602+
&[
603+
1.0, 0.0, 0.0, // dot=1, ||rhs||=1, expected=1/3
604+
1.0, 2.0, 2.0, // dot=9, ||rhs||=3, expected=1
605+
0.0, 0.0, 1.0, // dot=2, ||rhs||=1, expected=2/3
606+
2.0, 1.0, 2.0, // dot=8, ||rhs||=3, expected=8/9
607+
],
608+
)?;
609+
assert_close(
610+
&eval_cosine_similarity(lhs, rhs, 4)?,
611+
&[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0],
612+
);
613+
Ok(())
614+
}
615+
616+
#[test]
617+
fn constant_rhs_matches_plain_tensor() -> VortexResult<()> {
618+
// Mirror of `constant_lhs_matches_plain_tensor` with the constant on the right.
619+
let lhs = tensor_array(
620+
&[3],
621+
&[
622+
1.0, 0.0, 0.0, //
623+
1.0, 2.0, 2.0, //
624+
0.0, 0.0, 1.0, //
625+
2.0, 1.0, 2.0, //
626+
],
627+
)?;
628+
let rhs = constant_tensor_array(&[3], &[1.0, 2.0, 2.0], 4)?;
629+
assert_close(
630+
&eval_cosine_similarity(lhs, rhs, 4)?,
631+
&[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0],
632+
);
633+
Ok(())
634+
}
635+
636+
#[test]
637+
fn both_constant_tensors() -> VortexResult<()> {
638+
// `[1, 0, 0]` vs `[1, 1, 0]`. dot=1, ||lhs||=1, ||rhs||=sqrt(2), expected=1/sqrt(2).
639+
let lhs = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 3)?;
640+
let rhs = constant_tensor_array(&[3], &[1.0, 1.0, 0.0], 3)?;
641+
let expected = 1.0 / 2.0_f64.sqrt();
642+
assert_close(
643+
&eval_cosine_similarity(lhs, rhs, 3)?,
644+
&[expected, expected, expected],
645+
);
646+
Ok(())
647+
}
648+
649+
#[test]
650+
fn constant_zero_norm_query() -> VortexResult<()> {
651+
// A zero-norm constant query must produce `0.0` for every row via the zero-norm guard in
652+
// `execute_one_denorm` and `execute_both_denorm`.
653+
let lhs = constant_tensor_array(&[3], &[0.0, 0.0, 0.0], 3)?;
654+
let rhs = tensor_array(
655+
&[3],
656+
&[
657+
1.0, 2.0, 3.0, //
658+
4.0, 5.0, 6.0, //
659+
7.0, 8.0, 9.0, //
660+
],
661+
)?;
662+
assert_close(&eval_cosine_similarity(lhs, rhs, 3)?, &[0.0, 0.0, 0.0]);
663+
Ok(())
664+
}
665+
666+
#[test]
667+
fn constant_self_similarity_nonunit() -> VortexResult<()> {
668+
// A non-unit constant query compared to itself must produce `1.0`. This exercises the
669+
// helper's division: after normalization, both sides must be exactly unit so the
670+
// L2Denorm fast path's inner product yields 1.
671+
let lhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?;
672+
let rhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?;
673+
assert_close(&eval_cosine_similarity(lhs, rhs, 5)?, &[1.0; 5]);
674+
Ok(())
675+
}
676+
677+
#[test]
678+
fn vector_constant_matches_plain() -> VortexResult<()> {
679+
// Exercise the `Vector` extension variant through the new pre-pass.
680+
let lhs = constant_vector_array(&[1.0, 2.0, 2.0], 4)?;
681+
let rhs = vector_array(
682+
3,
683+
&[
684+
1.0, 0.0, 0.0, //
685+
1.0, 2.0, 2.0, //
686+
0.0, 0.0, 1.0, //
687+
2.0, 1.0, 2.0, //
688+
],
689+
)?;
690+
assert_close(
691+
&eval_cosine_similarity(lhs, rhs, 4)?,
692+
&[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0],
693+
);
694+
Ok(())
695+
}
578696
}

0 commit comments

Comments
 (0)