Skip to content

Commit 5e93e8e

Browse files
authored
Move stuff around vortex-tensor (#7225)
## Summary In preparation for adding vector encodings, cleans up the tensor crate. ## Testing N/A --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 10c3fc6 commit 5e93e8e

7 files changed

Lines changed: 82 additions & 36 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
pub mod vortex_tensor
22

3+
pub mod vortex_tensor::encodings
4+
35
pub mod vortex_tensor::fixed_shape
46

57
pub struct vortex_tensor::fixed_shape::FixedShapeTensor
@@ -136,7 +138,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se
136138

137139
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
138140

139-
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
141+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
140142

141143
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
142144

@@ -166,7 +168,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self:
166168

167169
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
168170

169-
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
171+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
170172

171173
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
172174

vortex-tensor/src/encodings/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Encodings for the different tensor types.
5+
6+
// TODO(connor):
7+
// pub mod norm; // Unit-normalized vectors.
8+
// pub mod spherical; // Spherical transform on unit-normalized vectors.
9+
10+
// TODO(will):
11+
// pub mod turboquant;

vortex-tensor/src/lib.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
//! including unit vectors, spherical coordinates, and similarity measures such as cosine
66
//! similarity.
77
8+
pub mod matcher;
9+
pub mod scalar_fns;
10+
811
pub mod fixed_shape;
912
pub mod vector;
1013

11-
pub mod matcher;
12-
pub mod scalar_fns;
14+
pub mod encodings;
15+
16+
mod utils;

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
2828
use vortex::scalar_fn::ScalarFnVTable;
2929

3030
use crate::matcher::AnyTensor;
31-
use crate::scalar_fns::utils::extension_element_ptype;
32-
use crate::scalar_fns::utils::extension_list_size;
33-
use crate::scalar_fns::utils::extension_storage;
34-
use crate::scalar_fns::utils::extract_flat_elements;
31+
use crate::utils::extension_element_ptype;
32+
use crate::utils::extension_list_size;
33+
use crate::utils::extension_storage;
34+
use crate::utils::extract_flat_elements;
3535

3636
/// Cosine similarity between two columns.
3737
///
@@ -115,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity {
115115
&self,
116116
_options: &Self::Options,
117117
args: &dyn ExecutionArgs,
118-
_ctx: &mut ExecutionCtx,
118+
ctx: &mut ExecutionCtx,
119119
) -> VortexResult<ArrayRef> {
120120
let lhs = args.get(0)?;
121121
let rhs = args.get(1)?;
@@ -128,15 +128,15 @@ impl ScalarFnVTable for CosineSimilarity {
128128
lhs.dtype()
129129
)
130130
})?;
131-
let list_size = extension_list_size(ext)?;
131+
let list_size = extension_list_size(ext)? as usize;
132132

133133
// Extract the storage array from each extension input. We pass the storage (FSL) rather
134134
// than the extension array to avoid canonicalizing the extension wrapper.
135135
let lhs_storage = extension_storage(&lhs)?;
136136
let rhs_storage = extension_storage(&rhs)?;
137137

138-
let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?;
139-
let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?;
138+
let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?;
139+
let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?;
140140

141141
match_each_float_ptype!(lhs_flat.ptype(), |T| {
142142
let result: PrimitiveArray = (0..row_count)
@@ -196,11 +196,11 @@ mod tests {
196196
use vortex::scalar_fn::ScalarFn;
197197

198198
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
199-
use crate::scalar_fns::utils::test_helpers::assert_close;
200-
use crate::scalar_fns::utils::test_helpers::constant_tensor_array;
201-
use crate::scalar_fns::utils::test_helpers::constant_vector_array;
202-
use crate::scalar_fns::utils::test_helpers::tensor_array;
203-
use crate::scalar_fns::utils::test_helpers::vector_array;
199+
use crate::utils::test_helpers::assert_close;
200+
use crate::utils::test_helpers::constant_tensor_array;
201+
use crate::utils::test_helpers::constant_vector_array;
202+
use crate::utils::test_helpers::tensor_array;
203+
use crate::utils::test_helpers::vector_array;
204204

205205
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
206206
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
2828
use vortex::scalar_fn::ScalarFnVTable;
2929

3030
use crate::matcher::AnyTensor;
31-
use crate::scalar_fns::utils::extension_element_ptype;
32-
use crate::scalar_fns::utils::extension_list_size;
33-
use crate::scalar_fns::utils::extension_storage;
34-
use crate::scalar_fns::utils::extract_flat_elements;
31+
use crate::utils::extension_element_ptype;
32+
use crate::utils::extension_list_size;
33+
use crate::utils::extension_storage;
34+
use crate::utils::extract_flat_elements;
3535

3636
/// L2 norm (Euclidean norm) of a tensor or vector column.
3737
///
@@ -98,7 +98,7 @@ impl ScalarFnVTable for L2Norm {
9898
&self,
9999
_options: &Self::Options,
100100
args: &dyn ExecutionArgs,
101-
_ctx: &mut ExecutionCtx,
101+
ctx: &mut ExecutionCtx,
102102
) -> VortexResult<ArrayRef> {
103103
let input = args.get(0)?;
104104
let row_count = args.row_count();
@@ -110,10 +110,10 @@ impl ScalarFnVTable for L2Norm {
110110
input.dtype()
111111
)
112112
})?;
113-
let list_size = extension_list_size(ext)?;
113+
let list_size = extension_list_size(ext)? as usize;
114114

115115
let storage = extension_storage(&input)?;
116-
let flat = extract_flat_elements(&storage, list_size)?;
116+
let flat = extract_flat_elements(&storage, list_size, ctx)?;
117117

118118
match_each_float_ptype!(flat.ptype(), |T| {
119119
let result: PrimitiveArray = (0..row_count)
@@ -163,9 +163,9 @@ mod tests {
163163
use vortex::scalar_fn::ScalarFn;
164164

165165
use crate::scalar_fns::l2_norm::L2Norm;
166-
use crate::scalar_fns::utils::test_helpers::assert_close;
167-
use crate::scalar_fns::utils::test_helpers::tensor_array;
168-
use crate::scalar_fns::utils::test_helpers::vector_array;
166+
use crate::utils::test_helpers::assert_close;
167+
use crate::utils::test_helpers::tensor_array;
168+
use crate::utils::test_helpers::vector_array;
169169

170170
/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
171171
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult<Vec<f64>> {

vortex-tensor/src/scalar_fns/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@
55
66
pub mod cosine_similarity;
77
pub mod l2_norm;
8-
9-
mod utils;
Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex::array::ArrayRef;
5+
use vortex::array::ExecutionCtx;
56
use vortex::array::IntoArray;
67
use vortex::array::arrays::Constant;
78
use vortex::array::arrays::ConstantArray;
89
use vortex::array::arrays::Extension;
10+
use vortex::array::arrays::FixedSizeListArray;
911
use vortex::array::arrays::PrimitiveArray;
1012
use vortex::dtype::DType;
1113
use vortex::dtype::NativePType;
@@ -19,15 +21,15 @@ use vortex::error::vortex_err;
1921
/// Extracts the list size from a tensor-like extension dtype.
2022
///
2123
/// The storage dtype must be a `FixedSizeList`.
22-
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<usize> {
24+
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
2325
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
2426
vortex_bail!(
2527
"expected FixedSizeList storage dtype, got {}",
2628
ext.storage_dtype()
2729
);
2830
};
2931

30-
Ok(*list_size as usize)
32+
Ok(*list_size)
3133
}
3234

3335
/// Extracts the float element [`PType`] from a tensor-like extension dtype.
@@ -91,13 +93,17 @@ impl FlatElements {
9193
///
9294
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
9395
/// materialized to avoid expanding it to the full column length.
94-
pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult<FlatElements> {
96+
pub fn extract_flat_elements(
97+
storage: &ArrayRef,
98+
list_size: usize,
99+
ctx: &mut ExecutionCtx,
100+
) -> VortexResult<FlatElements> {
95101
if let Some(constant) = storage.as_opt::<Constant>() {
96102
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge
97103
// amount of data.
98104
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
99-
let fsl = single.to_canonical()?.into_fixed_size_list();
100-
let elems = fsl.elements().to_canonical()?.into_primitive();
105+
let fsl: FixedSizeListArray = single.execute(ctx)?;
106+
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
101107
return Ok(FlatElements {
102108
elems,
103109
stride: 0,
@@ -106,8 +112,8 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu
106112
}
107113

108114
// Otherwise we have to fully expand all of the data.
109-
let fsl = storage.to_canonical()?.into_fixed_size_list();
110-
let elems = fsl.elements().to_canonical()?.into_primitive();
115+
let fsl: FixedSizeListArray = storage.clone().execute(ctx)?;
116+
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
111117
Ok(FlatElements {
112118
elems,
113119
stride: list_size,
@@ -118,6 +124,7 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu
118124
#[cfg(test)]
119125
pub mod test_helpers {
120126
use vortex::array::ArrayRef;
127+
use vortex::array::ExecutionCtx;
121128
use vortex::array::IntoArray;
122129
use vortex::array::arrays::ConstantArray;
123130
use vortex::array::arrays::ExtensionArray;
@@ -128,9 +135,13 @@ pub mod test_helpers {
128135
use vortex::dtype::Nullability;
129136
use vortex::dtype::extension::ExtDType;
130137
use vortex::error::VortexResult;
138+
use vortex::error::vortex_err;
131139
use vortex::extension::EmptyMetadata;
132140
use vortex::scalar::Scalar;
133141

142+
use super::extension_list_size;
143+
use super::extension_storage;
144+
use super::extract_flat_elements;
134145
use crate::fixed_shape::FixedShapeTensor;
135146
use crate::fixed_shape::FixedShapeTensorMetadata;
136147
use crate::vector::Vector;
@@ -210,6 +221,26 @@ pub mod test_helpers {
210221
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
211222
}
212223

224+
#[expect(dead_code, reason = "TODO(connor): Use this!")]
225+
/// Extracts the f64 rows from a [`Vector`] extension array.
226+
///
227+
/// Returns a `Vec<Vec<f64>>` where each inner vec is one vector's elements.
228+
pub fn extract_vector_rows(
229+
array: &ArrayRef,
230+
ctx: &mut ExecutionCtx,
231+
) -> VortexResult<Vec<Vec<f64>>> {
232+
let ext = array
233+
.dtype()
234+
.as_extension_opt()
235+
.ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?;
236+
let list_size = extension_list_size(ext)? as usize;
237+
let storage = extension_storage(array)?;
238+
let flat = extract_flat_elements(&storage, list_size, ctx)?;
239+
Ok((0..array.len())
240+
.map(|i| flat.row::<f64>(i).to_vec())
241+
.collect())
242+
}
243+
213244
/// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
214245
/// value, with support for NaN (NaN == NaN is considered equal).
215246
#[track_caller]

0 commit comments

Comments
 (0)