Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,42 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::return_dtype(&sel

pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>

pub mod vortex_tensor::scalar_fns::l2_denorm

pub struct vortex_tensor::scalar_fns::l2_denorm::L2Denorm

impl vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn<vortex_tensor::scalar_fns::l2_denorm::L2Denorm>

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>

impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_tensor::scalar_fns::l2_denorm::L2Denorm

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm

pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::id(&self) -> vortex_array::scalar_fn::ScalarFnId

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::is_fallible(&self, _options: &Self::Options) -> bool

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::is_null_sensitive(&self, _options: &Self::Options) -> bool

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>

pub mod vortex_tensor::scalar_fns::l2_norm

pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
Expand Down
2 changes: 2 additions & 0 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::encodings::turboquant::TurboQuant;
use crate::fixed_shape::FixedShapeTensor;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_denorm::L2Denorm;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::vector::Vector;

Expand All @@ -36,5 +37,6 @@ pub fn initialize(session: &VortexSession) {

session.scalar_fns().register(CosineSimilarity);
session.scalar_fns().register(InnerProduct);
session.scalar_fns().register(L2Denorm);
session.scalar_fns().register(L2Norm);
}
17 changes: 2 additions & 15 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ use vortex_array::scalar_fn::ScalarFnVTable;
use vortex_buffer::Buffer;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;

use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::validate_tensor_float_input;

/// Cosine similarity between two columns.
///
Expand Down Expand Up @@ -114,20 +113,8 @@ impl ScalarFnVTable for CosineSimilarity {
);

// We don't need to look at rhs anymore since we know lhs and rhs are equal.

// Both inputs must be tensor-like extension types.
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
vortex_err!("CosineSimilarity lhs must be an extension type, got {lhs}")
})?;

let tensor_match = lhs_ext.metadata_opt::<AnyTensor>().ok_or_else(|| {
vortex_err!("CosineSimilarity inputs must be an `AnyTensor`, got {lhs}")
})?;
let tensor_match = validate_tensor_float_input(lhs)?;
let ptype = tensor_match.element_ptype();
vortex_ensure!(
ptype.is_float(),
"CosineSimilarity element dtype must be a float primitive, got {ptype}"
);

let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
Ok(DType::Primitive(ptype, nullability))
Expand Down
Loading
Loading