Skip to content

Commit e415157

Browse files
authored
Change tensor utils names (#7284)
## Summary It doesn't really make a lot of sense to call these `extension_` utils since really it is specific to tensor types (fixed shape tensor, vectors, etc). Am open to calling these `fixed_tensor_` instead as well. ## Testing N/A Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 1c2348a commit e415157

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::matcher::AnyTensor;
3434
use crate::scalar_fns::ApproxOptions;
3535
use crate::scalar_fns::inner_product::InnerProduct;
3636
use crate::scalar_fns::l2_norm::L2Norm;
37-
use crate::utils::extension_element_ptype;
37+
use crate::utils::tensor_element_ptype;
3838

3939
/// Cosine similarity between two columns.
4040
///
@@ -128,7 +128,7 @@ impl ScalarFnVTable for CosineSimilarity {
128128
"CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
129129
);
130130

131-
let ptype = extension_element_ptype(lhs_ext)?;
131+
let ptype = tensor_element_ptype(lhs_ext)?;
132132
vortex_ensure!(
133133
ptype.is_float(),
134134
"CosineSimilarity element dtype must be a float primitive, got {ptype}"

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ use crate::encodings::turboquant::TurboQuant;
3333
use crate::encodings::turboquant::compute::cosine_similarity;
3434
use crate::matcher::AnyTensor;
3535
use crate::scalar_fns::ApproxOptions;
36-
use crate::utils::extension_element_ptype;
37-
use crate::utils::extension_list_size;
3836
use crate::utils::extract_flat_elements;
37+
use crate::utils::tensor_element_ptype;
38+
use crate::utils::tensor_list_size;
3939

4040
/// Inner product (dot product) between two columns.
4141
///
@@ -127,7 +127,7 @@ impl ScalarFnVTable for InnerProduct {
127127
"InnerProduct inputs must be an `AnyTensor`, got {lhs}"
128128
);
129129

130-
let ptype = extension_element_ptype(lhs_ext)?;
130+
let ptype = tensor_element_ptype(lhs_ext)?;
131131
vortex_ensure!(
132132
ptype.is_float(),
133133
"InnerProduct element dtype must be a float primitive, got {ptype}"
@@ -168,7 +168,7 @@ impl ScalarFnVTable for InnerProduct {
168168
// Get list size from the dtype. Both sides have the same dtype (validated by
169169
// `return_dtype`).
170170
let ext = lhs.dtype().as_extension();
171-
let list_size = extension_list_size(ext)? as usize;
171+
let list_size = tensor_list_size(ext)? as usize;
172172

173173
// Extract the storage array from each extension input. We pass the storage (FSL) rather
174174
// than the extension array to avoid canonicalizing the extension wrapper.

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ use vortex_error::vortex_err;
3232
use crate::encodings::turboquant::TurboQuant;
3333
use crate::matcher::AnyTensor;
3434
use crate::scalar_fns::ApproxOptions;
35-
use crate::utils::extension_element_ptype;
36-
use crate::utils::extension_list_size;
3735
use crate::utils::extract_flat_elements;
36+
use crate::utils::tensor_element_ptype;
37+
use crate::utils::tensor_list_size;
3838

3939
/// L2 norm (Euclidean norm) of a tensor or vector column.
4040
///
@@ -109,7 +109,7 @@ impl ScalarFnVTable for L2Norm {
109109
"L2Norm input must be an `AnyTensor`, got {input_dtype}"
110110
);
111111

112-
let ptype = extension_element_ptype(ext)?;
112+
let ptype = tensor_element_ptype(ext)?;
113113
vortex_ensure!(
114114
ptype.is_float(),
115115
"L2Norm element dtype must be a float primitive, got {ptype}"
@@ -144,7 +144,7 @@ impl ScalarFnVTable for L2Norm {
144144

145145
// Get element ptype and list size from the dtype (validated by `return_dtype`).
146146
let ext = input.dtype().as_extension();
147-
let list_size = extension_list_size(ext)? as usize;
147+
let list_size = tensor_list_size(ext)? as usize;
148148

149149
let storage = input.data().storage_array();
150150
let flat = extract_flat_elements(storage, list_size, ctx)?;

vortex-tensor/src/utils.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use vortex_error::vortex_err;
2020
/// Extracts the list size from a tensor-like extension dtype.
2121
///
2222
/// The storage dtype must be a `FixedSizeList`.
23-
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
23+
pub fn tensor_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
2424
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
2525
vortex_bail!(
2626
"expected FixedSizeList storage dtype, got {}",
@@ -34,7 +34,7 @@ pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
3434
/// Extracts the float element [`PType`] from a tensor-like extension dtype.
3535
///
3636
/// The storage dtype must be a `FixedSizeList` of non-nullable primitives.
37-
pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
37+
pub fn tensor_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
3838
let element_dtype = ext
3939
.storage_dtype()
4040
.as_fixed_size_list_element_opt()

0 commit comments

Comments
 (0)