diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs
index 289271d15f1..5006bb061f3 100644
--- a/vortex-tensor/src/encodings/turboquant/array/data.rs
+++ b/vortex-tensor/src/encodings/turboquant/array/data.rs
@@ -14,8 +14,8 @@ use vortex_error::vortex_ensure_eq;
 
 use crate::encodings::turboquant::array::slots::Slot;
 use crate::encodings::turboquant::vtable::TurboQuant;
-use crate::utils::extension_element_ptype;
-use crate::utils::extension_list_size;
+use crate::utils::tensor_element_ptype;
+use crate::utils::tensor_list_size;
 
 /// TurboQuant array data.
 ///
@@ -117,7 +117,7 @@ impl TurboQuantData {
 
         let dimension = dtype
             .as_extension_opt()
-            .and_then(|ext| extension_list_size(ext).ok())
+            .and_then(|ext| tensor_list_size(ext).ok())
             .vortex_expect("dtype must be a Vector extension type with FixedSizeList storage");
 
         let bit_width = if centroids.is_empty() {
@@ -154,7 +154,7 @@ impl TurboQuantData {
         rotation_signs: &ArrayRef,
     ) -> VortexResult<()> {
         let ext = TurboQuant::validate_dtype(dtype)?;
-        let dimension = extension_list_size(ext)?;
+        let dimension = tensor_list_size(ext)?;
         let padded_dim = dimension.next_power_of_two();
 
         // Codes must be a non-nullable FixedSizeList with list_size == padded_dim.
@@ -209,7 +209,7 @@ impl TurboQuantData {
 
         // Norms dtype must match the element ptype of the Vector, with the parent's nullability.
        // Norms carry the validity of the entire TurboQuant array.
-        let element_ptype = extension_element_ptype(ext)?;
+        let element_ptype = tensor_element_ptype(ext)?;
         let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
         vortex_ensure_eq!(
             *norms.dtype(),
diff --git a/vortex-tensor/src/encodings/turboquant/array/scheme.rs b/vortex-tensor/src/encodings/turboquant/array/scheme.rs
index de7a6a85302..b9aee1b68c5 100644
--- a/vortex-tensor/src/encodings/turboquant/array/scheme.rs
+++ b/vortex-tensor/src/encodings/turboquant/array/scheme.rs
@@ -14,8 +14,8 @@ use vortex_error::VortexResult;
 use crate::encodings::turboquant::TurboQuant;
 use crate::encodings::turboquant::TurboQuantConfig;
 use crate::encodings::turboquant::turboquant_encode;
-use crate::utils::extension_element_ptype;
-use crate::utils::extension_list_size;
+use crate::utils::tensor_element_ptype;
+use crate::utils::tensor_list_size;
 
 /// TurboQuant compression scheme for [`Vector`] extension types.
 ///
@@ -59,8 +59,8 @@ impl Scheme for TurboQuantScheme {
         let len = data.array().len();
 
         let ext = TurboQuant::validate_dtype(dtype)?;
-        let element_ptype = extension_element_ptype(ext)?;
-        let dimension = extension_list_size(ext)?;
+        let element_ptype = tensor_element_ptype(ext)?;
+        let dimension = tensor_list_size(ext)?;
 
         Ok(estimate_compression_ratio(
             element_ptype.bit_width(),
diff --git a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs
index e9bcd17ce96..87aa37562cd 100644
--- a/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs
+++ b/vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs
@@ -43,7 +43,7 @@ use vortex_error::vortex_ensure_eq;
 
 use crate::encodings::turboquant::TurboQuant;
 use crate::encodings::turboquant::array::float_from_f32;
-use crate::utils::extension_element_ptype;
+use crate::utils::tensor_element_ptype;
 
 /// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
 ///
@@ -107,7 +107,7 @@ pub fn cosine_similarity_quantized_column(
         "TurboQuant quantized dot product requires matching dimensions",
     );
 
-    let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
+    let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
 
     let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
     let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
@@ -145,7 +145,7 @@ pub fn dot_product_quantized_column(
         "TurboQuant quantized dot product requires matching dimensions",
     );
 
-    let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
+    let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
     let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
     let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
     let num_rows = lhs.norms().len();
diff --git a/vortex-tensor/src/encodings/turboquant/decompress.rs b/vortex-tensor/src/encodings/turboquant/decompress.rs
index b4b9fbaccfd..467662184ff 100644
--- a/vortex-tensor/src/encodings/turboquant/decompress.rs
+++ b/vortex-tensor/src/encodings/turboquant/decompress.rs
@@ -22,7 +22,7 @@ use vortex_error::VortexResult;
 use crate::encodings::turboquant::TurboQuant;
 use crate::encodings::turboquant::array::float_from_f32;
 use crate::encodings::turboquant::array::rotation::RotationMatrix;
-use crate::utils::extension_element_ptype;
+use crate::utils::tensor_element_ptype;
 
 /// Decompress a `TurboQuantArray` into a [`Vector`] extension array.
 ///
@@ -38,7 +38,7 @@ pub fn execute_decompress(
     let padded_dim = array.padded_dim() as usize;
     let num_rows = array.norms().len();
     let ext_dtype = array.dtype().as_extension().clone();
-    let element_ptype = extension_element_ptype(&ext_dtype)?;
+    let element_ptype = tensor_element_ptype(&ext_dtype)?;
 
     if num_rows == 0 {
         let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability());
diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs
index 1510132863c..39915b320f3 100644
--- a/vortex-tensor/src/encodings/turboquant/vtable.rs
+++ b/vortex-tensor/src/encodings/turboquant/vtable.rs
@@ -39,8 +39,8 @@ use crate::encodings::turboquant::array::slots::Slot;
 use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
 use crate::encodings::turboquant::compute::rules::RULES;
 use crate::encodings::turboquant::decompress::execute_decompress;
-use crate::utils::extension_element_ptype;
-use crate::utils::extension_list_size;
+use crate::utils::tensor_element_ptype;
+use crate::utils::tensor_list_size;
 use crate::vector::Vector;
 
 /// Encoding marker type for TurboQuant.
@@ -66,7 +66,7 @@ impl TurboQuant {
             vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
         })?;
 
-        let dimension = extension_list_size(ext)?;
+        let dimension = tensor_list_size(ext)?;
         vortex_ensure!(
             dimension >= Self::MIN_DIMENSION,
             "TurboQuant requires dimension >= {}, got {dimension}",
@@ -113,7 +113,7 @@ impl VTable for TurboQuant {
             vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
         })?;
 
-        let dimension = extension_list_size(ext)?;
+        let dimension = tensor_list_size(ext)?;
         vortex_ensure!(
             dimension >= Self::MIN_DIMENSION,
             "TurboQuant requires dimension >= {}, got {dimension}",
@@ -208,8 +208,8 @@ impl VTable for TurboQuant {
 
         // Validate and derive dimension and element ptype from the Vector extension dtype.
         let ext = TurboQuant::validate_dtype(dtype)?;
-        let dimension = extension_list_size(ext)?;
-        let element_ptype = extension_element_ptype(ext)?;
+        let dimension = tensor_list_size(ext)?;
+        let element_ptype = tensor_element_ptype(ext)?;
 
         let padded_dim = dimension.next_power_of_two();
 
diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs
index beadbe9285c..dfb650fff42 100644
--- a/vortex-tensor/src/scalar_fns/l2_norm.rs
+++ b/vortex-tensor/src/scalar_fns/l2_norm.rs
@@ -133,7 +133,7 @@ impl ScalarFnVTable for L2Norm {
         // (e.g., if the input extension has f64 elements).
         if let Some(tq) = input_ref.as_opt::<TurboQuant>() {
             let ext = input_ref.dtype().as_extension();
-            let target_ptype = extension_element_ptype(ext)?;
+            let target_ptype = tensor_element_ptype(ext)?;
             let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?;
             let target_dtype = DType::Primitive(target_ptype, input_ref.dtype().nullability());
             return norms.into_array().cast(target_dtype);