Skip to content

Commit 8efe1dc

Browse files
authored
Vector Extension Type (#6964)
## Summary Tracking Issue: #6865 Adds a `Vector` extension type and a new `L2Norm` expression. Additionally adds a `AnyTensor` type that can be matched on for any kind of tensor we want. Right now the code assumes that everything is built on top of `FixedSizeList`, but in the future that might change. Additionally make some touchups to the `vortex-tensor` crate in general. ## API Changes The new `Vector` and `L2Norm` types. ## Testing Some basic tests. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 683ba3a commit 8efe1dc

10 files changed

Lines changed: 844 additions & 179 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,34 @@ pub fn vortex_tensor::fixed_shape::FixedShapeTensorMetadata::hash<__H: core::has
9090

9191
impl core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMetadata
9292

93+
pub mod vortex_tensor::matcher
94+
95+
pub enum vortex_tensor::matcher::TensorMatch<'a>
96+
97+
pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(&'a vortex_tensor::fixed_shape::FixedShapeTensorMetadata)
98+
99+
pub vortex_tensor::matcher::TensorMatch::Vector
100+
101+
impl<'a> core::cmp::Eq for vortex_tensor::matcher::TensorMatch<'a>
102+
103+
impl<'a> core::cmp::PartialEq for vortex_tensor::matcher::TensorMatch<'a>
104+
105+
pub fn vortex_tensor::matcher::TensorMatch<'a>::eq(&self, other: &vortex_tensor::matcher::TensorMatch<'a>) -> bool
106+
107+
impl<'a> core::fmt::Debug for vortex_tensor::matcher::TensorMatch<'a>
108+
109+
pub fn vortex_tensor::matcher::TensorMatch<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
110+
111+
impl<'a> core::marker::StructuralPartialEq for vortex_tensor::matcher::TensorMatch<'a>
112+
113+
pub struct vortex_tensor::matcher::AnyTensor
114+
115+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::matcher::AnyTensor
116+
117+
pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::TensorMatch<'a>
118+
119+
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(item: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
120+
93121
pub mod vortex_tensor::scalar_fns
94122

95123
pub mod vortex_tensor::scalar_fns::cosine_similarity
@@ -121,3 +149,77 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_s
121149
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>
122150

123151
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
152+
153+
pub mod vortex_tensor::scalar_fns::l2_norm
154+
155+
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
156+
157+
impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm
158+
159+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm
160+
161+
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm
162+
163+
pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
164+
165+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
166+
167+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
168+
169+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
170+
171+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
172+
173+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::id(&self) -> vortex_array::scalar_fn::ScalarFnId
174+
175+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_fallible(&self, _options: &Self::Options) -> bool
176+
177+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::is_null_sensitive(&self, _options: &Self::Options) -> bool
178+
179+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>
180+
181+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
182+
183+
pub mod vortex_tensor::vector
184+
185+
pub struct vortex_tensor::vector::Vector
186+
187+
impl core::clone::Clone for vortex_tensor::vector::Vector
188+
189+
pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector
190+
191+
impl core::cmp::Eq for vortex_tensor::vector::Vector
192+
193+
impl core::cmp::PartialEq for vortex_tensor::vector::Vector
194+
195+
pub fn vortex_tensor::vector::Vector::eq(&self, other: &vortex_tensor::vector::Vector) -> bool
196+
197+
impl core::default::Default for vortex_tensor::vector::Vector
198+
199+
pub fn vortex_tensor::vector::Vector::default() -> vortex_tensor::vector::Vector
200+
201+
impl core::fmt::Debug for vortex_tensor::vector::Vector
202+
203+
pub fn vortex_tensor::vector::Vector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
204+
205+
impl core::hash::Hash for vortex_tensor::vector::Vector
206+
207+
pub fn vortex_tensor::vector::Vector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
208+
209+
impl core::marker::StructuralPartialEq for vortex_tensor::vector::Vector
210+
211+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::vector::Vector
212+
213+
pub type vortex_tensor::vector::Vector::Metadata = vortex_array::extension::EmptyMetadata
214+
215+
pub type vortex_tensor::vector::Vector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
216+
217+
pub fn vortex_tensor::vector::Vector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
218+
219+
pub fn vortex_tensor::vector::Vector::id(&self) -> vortex_array::dtype::extension::ExtId
220+
221+
pub fn vortex_tensor::vector::Vector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
222+
223+
pub fn vortex_tensor::vector::Vector::unpack_native<'a>(&self, _ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
224+
225+
pub fn vortex_tensor::vector::Vector::validate_dtype(&self, ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>

vortex-tensor/src/fixed_shape/vtable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ impl ExtVTable for FixedShapeTensor {
2222
type NativeValue<'a> = &'a ScalarValue;
2323

2424
fn id(&self) -> ExtId {
25-
ExtId::new_ref("vortex.fixed_shape_tensor")
25+
ExtId::new_ref("vortex.tensor.fixed_shape_tensor")
2626
}
2727

2828
fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {

vortex-tensor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,7 @@
66
//! similarity.
77
88
pub mod fixed_shape;
9+
pub mod vector;
910

11+
pub mod matcher;
1012
pub mod scalar_fns;

vortex-tensor/src/matcher.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Matcher for tensor-like extension types.
5+
6+
use vortex::dtype::extension::ExtDTypeRef;
7+
use vortex::dtype::extension::Matcher;
8+
9+
use crate::fixed_shape::FixedShapeTensor;
10+
use crate::fixed_shape::FixedShapeTensorMetadata;
11+
use crate::vector::Vector;
12+
13+
/// Matcher for any tensor-like extension type.
14+
///
15+
/// Currently the different kinds of tensors that are available are:
16+
///
17+
/// - `FixedShapeTensor`
18+
/// - `Vector`
19+
pub struct AnyTensor;
20+
21+
/// The matched variant of a tensor-like extension type.
22+
#[derive(Debug, PartialEq, Eq)]
23+
pub enum TensorMatch<'a> {
24+
/// A [`FixedShapeTensor`] extension type.
25+
FixedShapeTensor(&'a FixedShapeTensorMetadata),
26+
/// A [`Vector`] extension type.
27+
Vector,
28+
}
29+
30+
impl Matcher for AnyTensor {
31+
type Match<'a> = TensorMatch<'a>;
32+
33+
fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
34+
if let Some(metadata) = item.metadata_opt::<FixedShapeTensor>() {
35+
return Some(TensorMatch::FixedShapeTensor(metadata));
36+
}
37+
if item.metadata_opt::<Vector>().is_some() {
38+
return Some(TensorMatch::Vector);
39+
}
40+
None
41+
}
42+
}

0 commit comments

Comments
 (0)