Skip to content

Commit c5d2745

Browse files
authored
Approximate expressions for tensor types (#7226)
## Summary Tracking Issue: #6865 We will want to support approximate cosine similarity from turboquant (and likely many other kinds of expressions on tensors), so it seems kind of pointless to have a completely different expression for this. ## Testing N/A since for now we don't look at the options at all. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent df84cee commit c5d2745

File tree

4 files changed

+63
-10
lines changed

4 files changed

+63
-10
lines changed

vortex-tensor/public-api.lock

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&se
132132

133133
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
134134

135-
pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions
135+
pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_tensor::scalar_fns::ApproxOptions
136136

137137
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
138138

@@ -162,7 +162,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor
162162

163163
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm
164164

165-
pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
165+
pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_tensor::scalar_fns::ApproxOptions
166166

167167
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
168168

@@ -182,6 +182,40 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options:
182182

183183
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
184184

185+
pub enum vortex_tensor::scalar_fns::ApproxOptions
186+
187+
pub vortex_tensor::scalar_fns::ApproxOptions::Approximate
188+
189+
pub vortex_tensor::scalar_fns::ApproxOptions::Exact
190+
191+
impl core::clone::Clone for vortex_tensor::scalar_fns::ApproxOptions
192+
193+
pub fn vortex_tensor::scalar_fns::ApproxOptions::clone(&self) -> vortex_tensor::scalar_fns::ApproxOptions
194+
195+
impl core::cmp::Eq for vortex_tensor::scalar_fns::ApproxOptions
196+
197+
impl core::cmp::PartialEq for vortex_tensor::scalar_fns::ApproxOptions
198+
199+
pub fn vortex_tensor::scalar_fns::ApproxOptions::eq(&self, other: &vortex_tensor::scalar_fns::ApproxOptions) -> bool
200+
201+
impl core::default::Default for vortex_tensor::scalar_fns::ApproxOptions
202+
203+
pub fn vortex_tensor::scalar_fns::ApproxOptions::default() -> vortex_tensor::scalar_fns::ApproxOptions
204+
205+
impl core::fmt::Debug for vortex_tensor::scalar_fns::ApproxOptions
206+
207+
pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
208+
209+
impl core::fmt::Display for vortex_tensor::scalar_fns::ApproxOptions
210+
211+
pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
212+
213+
impl core::hash::Hash for vortex_tensor::scalar_fns::ApproxOptions
214+
215+
pub fn vortex_tensor::scalar_fns::ApproxOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
216+
217+
impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::ApproxOptions
218+
185219
pub mod vortex_tensor::vector
186220

187221
pub struct vortex_tensor::vector::Vector

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ use vortex::error::vortex_err;
2222
use vortex::expr::Expression;
2323
use vortex::scalar_fn::Arity;
2424
use vortex::scalar_fn::ChildName;
25-
use vortex::scalar_fn::EmptyOptions;
2625
use vortex::scalar_fn::ExecutionArgs;
2726
use vortex::scalar_fn::ScalarFnId;
2827
use vortex::scalar_fn::ScalarFnVTable;
2928

3029
use crate::matcher::AnyTensor;
30+
use crate::scalar_fns::ApproxOptions;
3131
use crate::utils::extension_element_ptype;
3232
use crate::utils::extension_list_size;
3333
use crate::utils::extension_storage;
@@ -48,7 +48,7 @@ use crate::utils::extract_flat_elements;
4848
pub struct CosineSimilarity;
4949

5050
impl ScalarFnVTable for CosineSimilarity {
51-
type Options = EmptyOptions;
51+
type Options = ApproxOptions;
5252

5353
fn id(&self) -> ScalarFnId {
5454
ScalarFnId::new_ref("vortex.tensor.cosine_similarity")
@@ -192,9 +192,9 @@ mod tests {
192192
use vortex::array::ToCanonical;
193193
use vortex::array::arrays::ScalarFnArray;
194194
use vortex::error::VortexResult;
195-
use vortex::scalar_fn::EmptyOptions;
196195
use vortex::scalar_fn::ScalarFn;
197196

197+
use crate::scalar_fns::ApproxOptions;
198198
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
199199
use crate::utils::test_helpers::assert_close;
200200
use crate::utils::test_helpers::constant_tensor_array;
@@ -204,7 +204,7 @@ mod tests {
204204

205205
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
206206
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
207-
let scalar_fn = ScalarFn::new(CosineSimilarity, EmptyOptions).erased();
207+
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
208208
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?;
209209
let prim = result.to_primitive();
210210
Ok(prim.as_slice::<f64>().to_vec())

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ use vortex::error::vortex_err;
2222
use vortex::expr::Expression;
2323
use vortex::scalar_fn::Arity;
2424
use vortex::scalar_fn::ChildName;
25-
use vortex::scalar_fn::EmptyOptions;
2625
use vortex::scalar_fn::ExecutionArgs;
2726
use vortex::scalar_fn::ScalarFnId;
2827
use vortex::scalar_fn::ScalarFnVTable;
2928

3029
use crate::matcher::AnyTensor;
30+
use crate::scalar_fns::ApproxOptions;
3131
use crate::utils::extension_element_ptype;
3232
use crate::utils::extension_list_size;
3333
use crate::utils::extension_storage;
@@ -43,7 +43,7 @@ use crate::utils::extract_flat_elements;
4343
pub struct L2Norm;
4444

4545
impl ScalarFnVTable for L2Norm {
46-
type Options = EmptyOptions;
46+
type Options = ApproxOptions;
4747

4848
fn id(&self) -> ScalarFnId {
4949
ScalarFnId::new_ref("vortex.tensor.l2_norm")
@@ -159,17 +159,17 @@ mod tests {
159159
use vortex::array::ToCanonical;
160160
use vortex::array::arrays::ScalarFnArray;
161161
use vortex::error::VortexResult;
162-
use vortex::scalar_fn::EmptyOptions;
163162
use vortex::scalar_fn::ScalarFn;
164163

164+
use crate::scalar_fns::ApproxOptions;
165165
use crate::scalar_fns::l2_norm::L2Norm;
166166
use crate::utils::test_helpers::assert_close;
167167
use crate::utils::test_helpers::tensor_array;
168168
use crate::utils::test_helpers::vector_array;
169169

170170
/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
171171
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
172-
let scalar_fn = ScalarFn::new(L2Norm, EmptyOptions).erased();
172+
let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased();
173173
let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?;
174174
let prim = result.to_primitive();
175175
Ok(prim.as_slice::<f64>().to_vec())

vortex-tensor/src/scalar_fns/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,24 @@
33

44
//! Scalar function expressions defined on tensor and tensor-like extension types.
55
6+
use std::fmt;
7+
68
pub mod cosine_similarity;
79
pub mod l2_norm;
10+
11+
/// Options for tensor-related expressions that might have error.
12+
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
13+
pub enum ApproxOptions {
14+
#[default]
15+
Exact,
16+
Approximate,
17+
}
18+
19+
impl fmt::Display for ApproxOptions {
20+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21+
match self {
22+
Self::Exact => write!(f, "Exact"),
23+
Self::Approximate => write!(f, "Approximate"),
24+
}
25+
}
26+
}

0 commit comments

Comments
 (0)