Skip to content

Commit 22b1e17

Browse files
committed
change tensor utils names
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent a8351a9 commit 22b1e17

4 files changed

Lines changed: 12 additions & 12 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::matcher::AnyTensor;
3232
use crate::scalar_fns::ApproxOptions;
3333
use crate::scalar_fns::inner_product::InnerProduct;
3434
use crate::scalar_fns::l2_norm::L2Norm;
35-
use crate::utils::extension_element_ptype;
35+
use crate::utils::tensor_element_ptype;
3636

3737
/// Cosine similarity between two columns.
3838
///
@@ -126,7 +126,7 @@ impl ScalarFnVTable for CosineSimilarity {
126126
"CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
127127
);
128128

129-
let ptype = extension_element_ptype(lhs_ext)?;
129+
let ptype = tensor_element_ptype(lhs_ext)?;
130130
vortex_ensure!(
131131
ptype.is_float(),
132132
"CosineSimilarity element dtype must be a float primitive, got {ptype}"

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ use vortex_error::vortex_err;
3131

3232
use crate::matcher::AnyTensor;
3333
use crate::scalar_fns::ApproxOptions;
34-
use crate::utils::extension_element_ptype;
35-
use crate::utils::extension_list_size;
3634
use crate::utils::extract_flat_elements;
35+
use crate::utils::tensor_element_ptype;
36+
use crate::utils::tensor_list_size;
3737

3838
/// Inner product (dot product) between two columns.
3939
///
@@ -125,7 +125,7 @@ impl ScalarFnVTable for InnerProduct {
125125
"InnerProduct inputs must be an `AnyTensor`, got {lhs}"
126126
);
127127

128-
let ptype = extension_element_ptype(lhs_ext)?;
128+
let ptype = tensor_element_ptype(lhs_ext)?;
129129
vortex_ensure!(
130130
ptype.is_float(),
131131
"InnerProduct element dtype must be a float primitive, got {ptype}"
@@ -153,7 +153,7 @@ impl ScalarFnVTable for InnerProduct {
153153
// Get list size from the dtype. Both sides have the same dtype (validated by
154154
// `return_dtype`).
155155
let ext = lhs.dtype().as_extension();
156-
let list_size = extension_list_size(ext)? as usize;
156+
let list_size = tensor_list_size(ext)? as usize;
157157

158158
// Extract the storage array from each extension input. We pass the storage (FSL) rather
159159
// than the extension array to avoid canonicalizing the extension wrapper.

vortex-tensor/src/scalar_fns/l2_norm.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ use vortex_error::vortex_err;
3030

3131
use crate::matcher::AnyTensor;
3232
use crate::scalar_fns::ApproxOptions;
33-
use crate::utils::extension_element_ptype;
34-
use crate::utils::extension_list_size;
3533
use crate::utils::extract_flat_elements;
34+
use crate::utils::tensor_element_ptype;
35+
use crate::utils::tensor_list_size;
3636

3737
/// L2 norm (Euclidean norm) of a tensor or vector column.
3838
///
@@ -107,7 +107,7 @@ impl ScalarFnVTable for L2Norm {
107107
"L2Norm input must be an `AnyTensor`, got {input_dtype}"
108108
);
109109

110-
let ptype = extension_element_ptype(ext)?;
110+
let ptype = tensor_element_ptype(ext)?;
111111
vortex_ensure!(
112112
ptype.is_float(),
113113
"L2Norm element dtype must be a float primitive, got {ptype}"
@@ -130,7 +130,7 @@ impl ScalarFnVTable for L2Norm {
130130

131131
// Get list size (dimensions) from the dtype (validated by `return_dtype`).
132132
let ext = input.dtype().as_extension();
133-
let list_size = extension_list_size(ext)? as usize;
133+
let list_size = tensor_list_size(ext)? as usize;
134134

135135
let storage = input.data().storage_array();
136136
let flat = extract_flat_elements(storage, list_size, ctx)?;

vortex-tensor/src/utils.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use vortex_error::vortex_err;
2020
/// Extracts the list size from a tensor-like extension dtype.
2121
///
2222
/// The storage dtype must be a `FixedSizeList`.
23-
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
23+
pub fn tensor_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
2424
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
2525
vortex_bail!(
2626
"expected FixedSizeList storage dtype, got {}",
@@ -34,7 +34,7 @@ pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
3434
/// Extracts the float element [`PType`] from a tensor-like extension dtype.
3535
///
3636
/// The storage dtype must be a `FixedSizeList` of non-nullable primitives.
37-
pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
37+
pub fn tensor_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
3838
let element_dtype = ext
3939
.storage_dtype()
4040
.as_fixed_size_list_element_opt()

0 commit comments

Comments
 (0)