Skip to content

Commit b9c47cf

Browse files
authored
L2 Denorm expression (#7329)
## Summary Adds a new L2 denormalization expression. This is essentially just a broadcast multiplication expression with the additional constraint that the vector array must be normalized. ## Testing Basic testing. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 0b52e48 commit b9c47cf

File tree

7 files changed

+577
-32
lines changed

7 files changed

+577
-32
lines changed

vortex-tensor/public-api.lock

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,42 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::return_dtype(&sel
430430

431431
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
432432

433+
pub mod vortex_tensor::scalar_fns::l2_denorm
434+
435+
pub struct vortex_tensor::scalar_fns::l2_denorm::L2Denorm
436+
437+
impl vortex_tensor::scalar_fns::l2_denorm::L2Denorm
438+
439+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn<vortex_tensor::scalar_fns::l2_denorm::L2Denorm>
440+
441+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
442+
443+
impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm
444+
445+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_tensor::scalar_fns::l2_denorm::L2Denorm
446+
447+
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm
448+
449+
pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_tensor::scalar_fns::ApproxOptions
450+
451+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
452+
453+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
454+
455+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
456+
457+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
458+
459+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::id(&self) -> vortex_array::scalar_fn::ScalarFnId
460+
461+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::is_fallible(&self, _options: &Self::Options) -> bool
462+
463+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::is_null_sensitive(&self, _options: &Self::Options) -> bool
464+
465+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>
466+
467+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
468+
433469
pub mod vortex_tensor::scalar_fns::l2_norm
434470

435471
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm

vortex-tensor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::encodings::turboquant::TurboQuant;
1414
use crate::fixed_shape::FixedShapeTensor;
1515
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
1616
use crate::scalar_fns::inner_product::InnerProduct;
17+
use crate::scalar_fns::l2_denorm::L2Denorm;
1718
use crate::scalar_fns::l2_norm::L2Norm;
1819
use crate::vector::Vector;
1920

@@ -36,5 +37,6 @@ pub fn initialize(session: &VortexSession) {
3637

3738
session.scalar_fns().register(CosineSimilarity);
3839
session.scalar_fns().register(InnerProduct);
40+
session.scalar_fns().register(L2Denorm);
3941
session.scalar_fns().register(L2Norm);
4042
}

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@ use vortex_array::scalar_fn::ScalarFnVTable;
2626
use vortex_buffer::Buffer;
2727
use vortex_error::VortexResult;
2828
use vortex_error::vortex_ensure;
29-
use vortex_error::vortex_err;
3029

31-
use crate::matcher::AnyTensor;
3230
use crate::scalar_fns::ApproxOptions;
3331
use crate::scalar_fns::inner_product::InnerProduct;
3432
use crate::scalar_fns::l2_norm::L2Norm;
33+
use crate::utils::validate_tensor_float_input;
3534

3635
/// Cosine similarity between two columns.
3736
///
@@ -114,20 +113,8 @@ impl ScalarFnVTable for CosineSimilarity {
114113
);
115114

116115
// We don't need to look at rhs anymore since we know lhs and rhs are equal.
117-
118-
// Both inputs must be tensor-like extension types.
119-
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
120-
vortex_err!("CosineSimilarity lhs must be an extension type, got {lhs}")
121-
})?;
122-
123-
let tensor_match = lhs_ext.metadata_opt::<AnyTensor>().ok_or_else(|| {
124-
vortex_err!("CosineSimilarity inputs must be an `AnyTensor`, got {lhs}")
125-
})?;
116+
let tensor_match = validate_tensor_float_input(lhs)?;
126117
let ptype = tensor_match.element_ptype();
127-
vortex_ensure!(
128-
ptype.is_float(),
129-
"CosineSimilarity element dtype must be a float primitive, got {ptype}"
130-
);
131118

132119
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
133120
Ok(DType::Primitive(ptype, nullability))

0 commit comments

Comments
 (0)