Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions vortex-tensor/src/encodings/turboquant/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use vortex_error::vortex_ensure_eq;

use crate::encodings::turboquant::array::slots::Slot;
use crate::encodings::turboquant::vtable::TurboQuant;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// TurboQuant array data.
///
Expand Down Expand Up @@ -117,7 +117,7 @@ impl TurboQuantData {

let dimension = dtype
.as_extension_opt()
.and_then(|ext| extension_list_size(ext).ok())
.and_then(|ext| tensor_list_size(ext).ok())
.vortex_expect("dtype must be a Vector extension type with FixedSizeList storage");

let bit_width = if centroids.is_empty() {
Expand Down Expand Up @@ -154,7 +154,7 @@ impl TurboQuantData {
rotation_signs: &ArrayRef,
) -> VortexResult<()> {
let ext = TurboQuant::validate_dtype(dtype)?;
let dimension = extension_list_size(ext)?;
let dimension = tensor_list_size(ext)?;
let padded_dim = dimension.next_power_of_two();

// Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
Expand Down Expand Up @@ -209,7 +209,7 @@ impl TurboQuantData {

// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
// Norms carry the validity of the entire TurboQuant array.
let element_ptype = extension_element_ptype(ext)?;
let element_ptype = tensor_element_ptype(ext)?;
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
vortex_ensure_eq!(
*norms.dtype(),
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/encodings/turboquant/array/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use vortex_error::VortexResult;
use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::TurboQuantConfig;
use crate::encodings::turboquant::turboquant_encode;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// TurboQuant compression scheme for [`Vector`] extension types.
///
Expand Down Expand Up @@ -59,8 +59,8 @@ impl Scheme for TurboQuantScheme {
let len = data.array().len();

let ext = TurboQuant::validate_dtype(dtype)?;
let element_ptype = extension_element_ptype(ext)?;
let dimension = extension_list_size(ext)?;
let element_ptype = tensor_element_ptype(ext)?;
let dimension = tensor_list_size(ext)?;

Ok(estimate_compression_ratio(
element_ptype.bit_width(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use vortex_error::vortex_ensure_eq;

use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::array::float_from_f32;
use crate::utils::extension_element_ptype;
use crate::utils::tensor_element_ptype;

/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
///
Expand Down Expand Up @@ -107,7 +107,7 @@ pub fn cosine_similarity_quantized_column(
"TurboQuant quantized dot product requires matching dimensions",
);

let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;

Expand Down Expand Up @@ -145,7 +145,7 @@ pub fn dot_product_quantized_column(
"TurboQuant quantized dot product requires matching dimensions",
);

let element_ptype = extension_element_ptype(lhs.dtype().as_extension())?;
let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
let num_rows = lhs.norms().len();
Expand Down
4 changes: 2 additions & 2 deletions vortex-tensor/src/encodings/turboquant/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use vortex_error::VortexResult;
use crate::encodings::turboquant::TurboQuant;
use crate::encodings::turboquant::array::float_from_f32;
use crate::encodings::turboquant::array::rotation::RotationMatrix;
use crate::utils::extension_element_ptype;
use crate::utils::tensor_element_ptype;

/// Decompress a `TurboQuantArray` into a [`Vector`] extension array.
///
Expand All @@ -38,7 +38,7 @@ pub fn execute_decompress(
let padded_dim = array.padded_dim() as usize;
let num_rows = array.norms().len();
let ext_dtype = array.dtype().as_extension().clone();
let element_ptype = extension_element_ptype(&ext_dtype)?;
let element_ptype = tensor_element_ptype(&ext_dtype)?;

if num_rows == 0 {
let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability());
Expand Down
12 changes: 6 additions & 6 deletions vortex-tensor/src/encodings/turboquant/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ use crate::encodings::turboquant::array::slots::Slot;
use crate::encodings::turboquant::compute::rules::PARENT_KERNELS;
use crate::encodings::turboquant::compute::rules::RULES;
use crate::encodings::turboquant::decompress::execute_decompress;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;
use crate::vector::Vector;

/// Encoding marker type for TurboQuant.
Expand All @@ -66,7 +66,7 @@ impl TurboQuant {
vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
})?;

let dimension = extension_list_size(ext)?;
let dimension = tensor_list_size(ext)?;
vortex_ensure!(
dimension >= Self::MIN_DIMENSION,
"TurboQuant requires dimension >= {}, got {dimension}",
Expand Down Expand Up @@ -113,7 +113,7 @@ impl VTable for TurboQuant {
vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}")
})?;

let dimension = extension_list_size(ext)?;
let dimension = tensor_list_size(ext)?;
vortex_ensure!(
dimension >= Self::MIN_DIMENSION,
"TurboQuant requires dimension >= {}, got {dimension}",
Expand Down Expand Up @@ -208,8 +208,8 @@ impl VTable for TurboQuant {

// Validate and derive dimension and element ptype from the Vector extension dtype.
let ext = TurboQuant::validate_dtype(dtype)?;
let dimension = extension_list_size(ext)?;
let element_ptype = extension_element_ptype(ext)?;
let dimension = tensor_list_size(ext)?;
let element_ptype = tensor_element_ptype(ext)?;

let padded_dim = dimension.next_power_of_two();

Expand Down
2 changes: 1 addition & 1 deletion vortex-tensor/src/scalar_fns/l2_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl ScalarFnVTable for L2Norm {
// (e.g., if the input extension has f64 elements).
if let Some(tq) = input_ref.as_opt::<TurboQuant>() {
let ext = input_ref.dtype().as_extension();
let target_ptype = extension_element_ptype(ext)?;
let target_ptype = tensor_element_ptype(ext)?;
let norms: PrimitiveArray = tq.norms().clone().execute(ctx)?;
let target_dtype = DType::Primitive(target_ptype, input_ref.dtype().nullability());
return norms.into_array().cast(target_dtype);
Expand Down
Loading