Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&se

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity

pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions
pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity

Expand Down Expand Up @@ -162,7 +162,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor

impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm

pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity

Expand All @@ -182,6 +182,40 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options:

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>

pub enum vortex_tensor::scalar_fns::ApproxOptions

pub vortex_tensor::scalar_fns::ApproxOptions::Approximate

pub vortex_tensor::scalar_fns::ApproxOptions::Exact

impl core::clone::Clone for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::clone(&self) -> vortex_tensor::scalar_fns::ApproxOptions

impl core::cmp::Eq for vortex_tensor::scalar_fns::ApproxOptions

impl core::cmp::PartialEq for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::eq(&self, other: &vortex_tensor::scalar_fns::ApproxOptions) -> bool

impl core::default::Default for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::default() -> vortex_tensor::scalar_fns::ApproxOptions

impl core::fmt::Debug for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::fmt::Display for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl core::hash::Hash for vortex_tensor::scalar_fns::ApproxOptions

pub fn vortex_tensor::scalar_fns::ApproxOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H)

impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::ApproxOptions

pub mod vortex_tensor::vector

pub struct vortex_tensor::vector::Vector
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ use vortex::error::vortex_err;
use vortex::expr::Expression;
use vortex::scalar_fn::Arity;
use vortex::scalar_fn::ChildName;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ExecutionArgs;
use vortex::scalar_fn::ScalarFnId;
use vortex::scalar_fn::ScalarFnVTable;

use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
Expand All @@ -48,7 +48,7 @@ use crate::utils::extract_flat_elements;
pub struct CosineSimilarity;

impl ScalarFnVTable for CosineSimilarity {
type Options = EmptyOptions;
type Options = ApproxOptions;

fn id(&self) -> ScalarFnId {
ScalarFnId::new_ref("vortex.tensor.cosine_similarity")
Expand Down Expand Up @@ -192,9 +192,9 @@ mod tests {
use vortex::array::ToCanonical;
use vortex::array::arrays::ScalarFnArray;
use vortex::error::VortexResult;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ScalarFn;

use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::constant_tensor_array;
Expand All @@ -204,7 +204,7 @@ mod tests {

/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased();
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
let prim = result.to_primitive();
Ok(prim.as_slice::<f64>().to_vec())
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/scalar_fns/l2_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ use vortex::error::vortex_err;
use vortex::expr::Expression;
use vortex::scalar_fn::Arity;
use vortex::scalar_fn::ChildName;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ExecutionArgs;
use vortex::scalar_fn::ScalarFnId;
use vortex::scalar_fn::ScalarFnVTable;

use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
Expand All @@ -43,7 +43,7 @@ use crate::utils::extract_flat_elements;
pub struct L2Norm;

impl ScalarFnVTable for L2Norm {
type Options = EmptyOptions;
type Options = ApproxOptions;

fn id(&self) -> ScalarFnId {
ScalarFnId::new_ref("vortex.tensor.l2_norm")
Expand Down Expand Up @@ -159,17 +159,17 @@ mod tests {
use vortex::array::ToCanonical;
use vortex::array::arrays::ScalarFnArray;
use vortex::error::VortexResult;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ScalarFn;

use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;

/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
let scalar_fn = ScalarFn::new(L2Norm, EmptyOptions).erased();
let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?;
let prim = result.to_primitive();
Ok(prim.as_slice::<f64>().to_vec())
Expand Down
19 changes: 19 additions & 0 deletions vortex-tensor/src/scalar_fns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,24 @@

//! Scalar function expressions defined on tensor and tensor-like extension types.

use std::fmt;

pub mod cosine_similarity;
pub mod l2_norm;

/// Options for tensor-related expressions that might have error.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum ApproxOptions {
#[default]
Exact,
Approximate,
}

impl fmt::Display for ApproxOptions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Exact => write!(f, "Exact"),
Self::Approximate => write!(f, "Approximate"),
}
}
}
Loading