Skip to content

Commit 4cbfb33

Browse files
authored
Vortex Fixed-Shape Tensor (#6812)
## Summary Adds an experimental fixed-shape tensor extension type in a new `vortex-tensor` crate. See https://vortex-data.github.io/rfcs/rfc/0024.html for info about the design of this tensor type. Additionally adds a `CosineSimilarity` expression that takes 2 tensor arrays and computes the cosine similarity of tensors in the arrays (resulting in a `PrimitiveArray`). ## Testing Adds some very basic tests for cosine similarity and tensor metadata operations. ## Future Work I think this was a good way to see if our `ExtVTable` is not completely wrong, but at the same time this tells us nothing about what we might want to add for extension arrays on the `ExtVTable` because, as long as the storage `DType` is correct, any storage array is valid. The more interesting expressions have not been implemented here. Those would include: - Cast - Index / Slice (and lazily get back another tensor array, potentially non-contiguous) - Maybe others? Additional work includes exporting to Arrow, NumPy, and PyTorch. Arrow will require a cheap translation from logical to physical shape, but other than that those conversions should be easy. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 8eee959 commit 4cbfb33

10 files changed

Lines changed: 1231 additions & 0 deletions

File tree

Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ members = [
1111
"vortex-io",
1212
"vortex-proto",
1313
"vortex-array",
14+
"vortex-tensor",
1415
"vortex-btrblocks",
1516
"vortex-layout",
1617
"vortex-scan",
@@ -271,6 +272,7 @@ vortex-scan = { version = "0.1.0", path = "./vortex-scan", default-features = fa
271272
vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-features = false }
272273
vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false }
273274
vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false }
275+
vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false }
274276
vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false }
275277
vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false }
276278
vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false }

vortex-tensor/Cargo.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[package]
2+
name = "vortex-tensor"
3+
authors = { workspace = true }
4+
categories = { workspace = true }
5+
description = "Vortex tensor extension type"
6+
edition = { workspace = true }
7+
homepage = { workspace = true }
8+
include = { workspace = true }
9+
keywords = { workspace = true }
10+
license = { workspace = true }
11+
readme = { workspace = true }
12+
repository = { workspace = true }
13+
rust-version = { workspace = true }
14+
version = { workspace = true }
15+
16+
[lints]
17+
workspace = true
18+
19+
[dependencies]
20+
vortex = { workspace = true }
21+
22+
itertools = { workspace = true }
23+
num-traits = { workspace = true }
24+
prost = { workspace = true }
25+
26+
[dev-dependencies]
27+
rstest = { workspace = true }

vortex-tensor/public-api.lock

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
pub mod vortex_tensor
2+
3+
pub mod vortex_tensor::scalar_fns
4+
5+
pub mod vortex_tensor::scalar_fns::cosine_similarity
6+
7+
pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
8+
9+
impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
10+
11+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
12+
13+
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
14+
15+
pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions
16+
17+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity
18+
19+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
20+
21+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
22+
23+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
24+
25+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::id(&self) -> vortex_array::scalar_fn::ScalarFnId
26+
27+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_fallible(&self, _options: &Self::Options) -> bool
28+
29+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::is_null_sensitive(&self, _options: &Self::Options) -> bool
30+
31+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult<vortex_array::dtype::DType>
32+
33+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::expression::Expression>>
34+
35+
pub struct vortex_tensor::FixedShapeTensor
36+
37+
impl core::clone::Clone for vortex_tensor::FixedShapeTensor
38+
39+
pub fn vortex_tensor::FixedShapeTensor::clone(&self) -> vortex_tensor::FixedShapeTensor
40+
41+
impl core::cmp::Eq for vortex_tensor::FixedShapeTensor
42+
43+
impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensor
44+
45+
pub fn vortex_tensor::FixedShapeTensor::eq(&self, other: &vortex_tensor::FixedShapeTensor) -> bool
46+
47+
impl core::default::Default for vortex_tensor::FixedShapeTensor
48+
49+
pub fn vortex_tensor::FixedShapeTensor::default() -> vortex_tensor::FixedShapeTensor
50+
51+
impl core::fmt::Debug for vortex_tensor::FixedShapeTensor
52+
53+
pub fn vortex_tensor::FixedShapeTensor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
54+
55+
impl core::hash::Hash for vortex_tensor::FixedShapeTensor
56+
57+
pub fn vortex_tensor::FixedShapeTensor::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
58+
59+
impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensor
60+
61+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::FixedShapeTensor
62+
63+
pub type vortex_tensor::FixedShapeTensor::Metadata = vortex_tensor::FixedShapeTensorMetadata
64+
65+
pub type vortex_tensor::FixedShapeTensor::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
66+
67+
pub fn vortex_tensor::FixedShapeTensor::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
68+
69+
pub fn vortex_tensor::FixedShapeTensor::id(&self) -> vortex_array::dtype::extension::ExtId
70+
71+
pub fn vortex_tensor::FixedShapeTensor::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
72+
73+
pub fn vortex_tensor::FixedShapeTensor::unpack_native<'a>(&self, _metadata: &'a Self::Metadata, _storage_dtype: &'a vortex_array::dtype::DType, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
74+
75+
pub fn vortex_tensor::FixedShapeTensor::validate_dtype(&self, metadata: &Self::Metadata, storage_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<()>
76+
77+
pub struct vortex_tensor::FixedShapeTensorMetadata
78+
79+
impl vortex_tensor::FixedShapeTensorMetadata
80+
81+
pub fn vortex_tensor::FixedShapeTensorMetadata::dim_names(&self) -> core::option::Option<&[alloc::string::String]>
82+
83+
pub fn vortex_tensor::FixedShapeTensorMetadata::logical_shape(&self) -> &[usize]
84+
85+
pub fn vortex_tensor::FixedShapeTensorMetadata::ndim(&self) -> usize
86+
87+
pub fn vortex_tensor::FixedShapeTensorMetadata::new(shape: alloc::vec::Vec<usize>) -> Self
88+
89+
pub fn vortex_tensor::FixedShapeTensorMetadata::permutation(&self) -> core::option::Option<&[usize]>
90+
91+
pub fn vortex_tensor::FixedShapeTensorMetadata::physical_shape(&self) -> impl core::iter::traits::iterator::Iterator<Item = usize> + '_
92+
93+
pub fn vortex_tensor::FixedShapeTensorMetadata::strides(&self) -> impl core::iter::traits::iterator::Iterator<Item = usize> + '_
94+
95+
pub fn vortex_tensor::FixedShapeTensorMetadata::with_dim_names(self, names: alloc::vec::Vec<alloc::string::String>) -> vortex_error::VortexResult<Self>
96+
97+
pub fn vortex_tensor::FixedShapeTensorMetadata::with_permutation(self, permutation: alloc::vec::Vec<usize>) -> vortex_error::VortexResult<Self>
98+
99+
impl core::clone::Clone for vortex_tensor::FixedShapeTensorMetadata
100+
101+
pub fn vortex_tensor::FixedShapeTensorMetadata::clone(&self) -> vortex_tensor::FixedShapeTensorMetadata
102+
103+
impl core::cmp::Eq for vortex_tensor::FixedShapeTensorMetadata
104+
105+
impl core::cmp::PartialEq for vortex_tensor::FixedShapeTensorMetadata
106+
107+
pub fn vortex_tensor::FixedShapeTensorMetadata::eq(&self, other: &vortex_tensor::FixedShapeTensorMetadata) -> bool
108+
109+
impl core::fmt::Debug for vortex_tensor::FixedShapeTensorMetadata
110+
111+
pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
112+
113+
impl core::fmt::Display for vortex_tensor::FixedShapeTensorMetadata
114+
115+
pub fn vortex_tensor::FixedShapeTensorMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
116+
117+
impl core::hash::Hash for vortex_tensor::FixedShapeTensorMetadata
118+
119+
pub fn vortex_tensor::FixedShapeTensorMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
120+
121+
impl core::marker::StructuralPartialEq for vortex_tensor::FixedShapeTensorMetadata

vortex-tensor/src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Tensor extension type.
5+
6+
mod metadata;
7+
pub use metadata::FixedShapeTensorMetadata;
8+
9+
mod proto;
10+
mod vtable;
11+
12+
pub mod scalar_fns;
13+
14+
/// The VTable for the Tensor extension type.
15+
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
16+
pub struct FixedShapeTensor;

0 commit comments

Comments
 (0)