diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8ae349c3da5..859ee6829e2 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -21546,6 +21546,8 @@ pub fn vortex_array::Array::new(ext_dtype: vort pub fn vortex_array::Array::try_new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_array::Array::try_new_from_vtable(vtable: V, metadata: ::Metadata, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult + impl vortex_array::Array pub fn vortex_array::Array::new(array: vortex_array::ArrayRef, mask: vortex_mask::Mask) -> Self diff --git a/vortex-array/src/arrays/extension/array.rs b/vortex-array/src/arrays/extension/array.rs index 4ff2ad1bcfb..0a79774a2b5 100644 --- a/vortex-array/src/arrays/extension/array.rs +++ b/vortex-array/src/arrays/extension/array.rs @@ -13,7 +13,9 @@ use crate::array::ArrayParts; use crate::array::TypedArrayRef; use crate::arrays::Extension; use crate::dtype::DType; +use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtDTypeRef; +use crate::dtype::extension::ExtVTable; /// The backing storage array for this extension array. pub(super) const STORAGE_SLOT: usize = 0; @@ -163,4 +165,17 @@ impl Array { ) }) } + + /// Creates a new [`ExtensionArray`](crate::arrays::ExtensionArray) from a vtable, metadata, and + /// a storage array. + pub fn try_new_from_vtable( + vtable: V, + metadata: V::Metadata, + storage_array: ArrayRef, + ) -> VortexResult { + let ext_dtype = + ExtDType::::try_with_vtable(vtable, metadata, storage_array.dtype().clone())? + .erased(); + Self::try_new(ext_dtype, storage_array) + } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 1b763f53635..d4a00fbdec4 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -14,8 +14,8 @@ use crate::scalar::ScalarValue; /// The public API for defining new extension types. /// -/// This is the non-object-safe trait that plugin authors implement to define a new extension -/// type. It specifies the type's identity, metadata, serialization, and validation. +/// This is the non-object-safe trait that plugin authors implement to define a new extension type. +/// It specifies the type's identity, metadata, serialization, and validation. pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Associated type containing the deserialized metadata for this extension type. type Metadata: 'static + Send + Sync + Clone + Debug + Display + Eq + Hash; @@ -39,11 +39,11 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Validate that the given storage type is compatible with this extension type. fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()>; - /// Can a value of `other` be implicitly widened into this type? - /// e.g. GeographyType might accept Point, LineString, etc. + /// Can a value of `other` be implicitly widened into this type? (e.g. GeographyType might + /// accept Point, LineString, etc.) /// - /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` — both - /// exist so that either side of the coercion can provide the logic. + /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to`. We have both + /// so that either side of the coercion can provide the logic. fn can_coerce_from(ext_dtype: &ExtDType, other: &DType) -> bool { let _ = (ext_dtype, other); false @@ -51,14 +51,15 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Can this type be implicitly widened into `other`? /// - /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` — both - /// exist so that either side of the coercion can provide the logic. + /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to`. We have both + /// so that either side of the coercion can provide the logic. fn can_coerce_to(ext_dtype: &ExtDType, other: &DType) -> bool { let _ = (ext_dtype, other); false } /// Given two types in a Uniform context, what is their least supertype? + /// /// Return None if no supertype exists. fn least_supertype(ext_dtype: &ExtDType, other: &DType) -> Option { let _ = (ext_dtype, other); @@ -69,7 +70,8 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Validate the given storage value is compatible with the extension type. /// - /// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the result. + /// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the + /// result. /// /// # Errors /// diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 96b95e1e91e..d233322a3bf 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -80,7 +80,7 @@ pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult -pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult @@ -142,7 +142,7 @@ impl vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_> pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::element_ptype(&self) -> vortex_array::dtype::ptype::PType -pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::list_size(&self) -> usize +pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::flat_list_size(&self) -> u32 pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::metadata(&self) -> &vortex_tensor::fixed_shape::FixedShapeTensorMetadata @@ -222,7 +222,7 @@ impl vortex_tensor::matcher::TensorMatch<'_> pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType -pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> usize +pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32 impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a> @@ -382,7 +382,7 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows(input: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> pub mod vortex_tensor::scalar_fns::l2_norm @@ -502,7 +502,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::child_name(&sel pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::id(&self) -> vortex_array::scalar_fn::ScalarFnId @@ -600,12 +600,8 @@ impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherM pub mod vortex_tensor::vector_search -pub fn vortex_tensor::vector_search::build_constant_query_vector>(query: &[T], num_rows: usize) -> vortex_error::VortexResult - pub fn vortex_tensor::vector_search::build_similarity_search_tree>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult -pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult - pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession) diff --git a/vortex-tensor/src/encodings/l2_denorm.rs b/vortex-tensor/src/encodings/l2_denorm.rs index 172191abf6e..d29b2e94daf 100644 --- a/vortex-tensor/src/encodings/l2_denorm.rs +++ b/vortex-tensor/src/encodings/l2_denorm.rs @@ -19,9 +19,8 @@ use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; pub struct L2DenormScheme; impl Scheme for L2DenormScheme { - // TODO(connor): FIX THIS!!! fn scheme_name(&self) -> &'static str { - "vortex.tensor.UNSTABLE.l2_denorm" + "vortex.tensor.l2_denorm" } fn matches(&self, canonical: &Canonical) -> bool { diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs index cd7b8b889ce..3111034e68f 100644 --- a/vortex-tensor/src/encodings/turboquant/centroids.rs +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -11,6 +11,7 @@ use std::sync::LazyLock; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_utils::aliases::dash_map::DashMap; @@ -27,16 +28,15 @@ const CONVERGENCE_EPSILON: f64 = 1e-12; /// Number of numerical integration points for computing conditional expectations. const INTEGRATION_POINTS: usize = 1000; -// TODO(connor): Maybe we should just store an `ArrayRef` here? /// Global centroid cache keyed by (dimension, bit_width). -static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); +static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); /// Get or compute cached centroids for the given dimension and bit width. /// /// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar /// quantization levels for the coordinate distribution after random rotation in /// `dimension`-dimensional space. -pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { +pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { vortex_ensure!( (1..=MAX_BIT_WIDTH).contains(&bit_width), "TurboQuant bit_width must be 1-{}, got {bit_width}", @@ -92,7 +92,7 @@ impl HalfIntExponent { /// The probability distribution function is: /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. -fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec { +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; @@ -288,7 +288,7 @@ mod tests { #[case(128, 4)] fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { let centroids = get_centroids(dim, bits)?; - for &val in ¢roids { + for &val in centroids.iter() { assert!( (-1.0..=1.0).contains(&val), "centroid out of [-1, 1]: {val}", diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index b0173bbe36c..40e24807091 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -15,16 +15,15 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -35,7 +34,8 @@ use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; -use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; +use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; @@ -64,129 +64,45 @@ impl Default for TurboQuantConfig { } } -/// Shared intermediate results from the quantization loop. -struct QuantizationResult { - centroids: Vec, - all_indices: BufferMut, - padded_dim: usize, -} - -/// Core quantization: rotate and quantize already-normalized rows. +/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) +/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized +/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer +/// [`L2Denorm`] wrapper. /// -/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null -/// vectors are not supported and must be zeroed out before reaching this function. The rotation -/// and centroid lookup happen in f32. -fn turboquant_quantize_core( - fsl: &FixedSizeListArray, - seed: u64, - bit_width: u8, - num_rounds: u8, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = - usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize"); - let num_rows = fsl.len(); - - let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; - let padded_dim = rotation.padded_dim(); - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - - let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - let f32_elements = cast_to_f32(elements_prim)?; - - let centroids = get_centroids(padded_dim_u32, bit_width)?; - let boundaries = compute_centroid_boundaries(¢roids); - - let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); - let mut padded = vec![0.0f32; padded_dim]; - let mut rotated = vec![0.0f32; padded_dim]; - - let f32_slice = f32_elements.as_slice(); - for row in 0..num_rows { - let x = &f32_slice[row * dimension..(row + 1) * dimension]; - - // Zero-pad to the next power of 2. - padded[..dimension].copy_from_slice(x); - padded[dimension..].fill(0.0); - - rotation.rotate(&padded, &mut rotated); - - for j in 0..padded_dim { - all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); - } - } - - Ok(QuantizationResult { - centroids, - all_indices, - padded_dim, - }) -} - -/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. +/// The returned array has the canonical TurboQuant shape: /// -/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes -/// indexes into the centroid codebook. The Dict can be independently sliced, taken, -/// or executed (dequantized) without knowledge of the rotation. -fn build_quantized_fsl( - num_rows: usize, - all_indices: BufferMut, - centroids: &[f32], - padded_dim: usize, -) -> VortexResult { - let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); - - let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); - centroids_buf.extend_from_slice(centroids); - let centroids_array = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); - - let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; - - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - Ok(FixedSizeListArray::try_new( - dict.into_array(), - padded_dim_u32, - Validity::NonNullable, - num_rows, - )? - .into_array()) -} - -/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`. +/// ```text +/// ScalarFnArray(L2Denorm, [ +/// ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +/// norms, +/// ]) +/// ``` /// -/// The input must be a non-nullable Vector extension array whose rows are already unit-norm. -/// **Null vectors are not supported.** The caller must normalize and strip nullability before -/// calling this function, for example via [`normalize_as_l2_denorm`]. +/// # Errors /// -/// This function validates that every row is L2-normalized (or is exactly 0.0). Use -/// [`turboquant_encode_unchecked`] to skip this check when the caller has just performed -/// normalization. -/// -/// The returned array is a `SorfTransform` ScalarFnArray wrapping `FSL(Dict)` that decompresses -/// to unit-norm vectors. The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray -/// if the original magnitudes need to be restored. -/// -/// [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm -/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm +/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or +/// if [`turboquant_encode_unchecked`] rejects the input shape. pub fn turboquant_encode( - ext: ArrayView, + input: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx, ) -> VortexResult { - let ext_dtype = ext.dtype().clone(); - - vortex_ensure!( - !ext_dtype.is_nullable(), - "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}", - ); - - validate_l2_normalized_rows(ext.as_ref(), ctx)?; - - // SAFETY: We just validated that the input is non-nullable and all rows are unit-norm. - unsafe { turboquant_encode_unchecked(ext, config, ctx) } + // We must normalize the array before we can encode it with TurboQuant. + let l2_denorm = normalize_as_l2_denorm(input, ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + let num_rows = l2_denorm.len(); + + let normalized_ext = normalized + .as_opt::() + .vortex_expect("normalize_as_l2_denorm always produces an Extension array child"); + + // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows). + let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?; + + // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally + // bypass the strict normalized-row validation when reattaching the stored norms. + Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } /// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a @@ -240,7 +156,7 @@ pub unsafe fn turboquant_encode_unchecked( Validity::NonNullable, 0, )?; - let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?; + let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?; let sorf_options = SorfOptions { seed, @@ -255,8 +171,8 @@ pub unsafe fn turboquant_encode_unchecked( let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; let quantized_fsl = - build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?; - let padded_vector = wrap_padded_as_vector(quantized_fsl)?; + build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?; + let padded_vector = Vector::try_new_vector_array(quantized_fsl)?; let sorf_options = SorfOptions { seed, @@ -267,9 +183,88 @@ pub unsafe fn turboquant_encode_unchecked( Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Wrap an `FSL` in a [`Vector`](crate::vector::Vector) extension so it can be -/// passed as the child of [`SorfTransform`], which expects a `Vector` input. -fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl).into_array()) +/// Shared intermediate results from the quantization loop. +struct QuantizationResult { + centroids: Buffer, + all_indices: Buffer, + padded_dim: usize, +} + +/// Core quantization: rotate and quantize already-normalized rows. +/// +/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null +/// vectors are not supported and must be zeroed out before reaching this function. The rotation +/// and centroid lookup happen in f32. +fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + seed: u64, + bit_width: u8, + num_rounds: u8, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dimension = fsl.list_size() as usize; + let num_rows = fsl.len(); + + let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; + let padded_dim = rotation.padded_dim(); + let padded_dim_u32 = + u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); + + let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + let f32_elements = cast_to_f32(elements_prim)?; + + let centroids = get_centroids(padded_dim_u32, bit_width)?; + let boundaries = compute_centroid_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice(); + for row in 0..num_rows { + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + + // Zero-pad to the next power of 2. + padded[..dimension].copy_from_slice(x); + padded[dimension..].fill(0.0); + + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + Ok(QuantizationResult { + centroids, + all_indices: all_indices.freeze(), + padded_dim, + }) +} + +/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. +/// +/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes indexes into the +/// centroid codebook. The Dict can be independently sliced, taken, or executed (dequantized) +/// without knowledge of the rotation. +fn build_quantized_fsl( + num_rows: usize, + all_indices: Buffer, + centroids: Buffer, + padded_dim: usize, +) -> VortexResult { + let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); + let centroids_array = PrimitiveArray::new::(centroids, Validity::NonNullable); + + let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; + + let padded_dim_u32 = + u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); + Ok(FixedSizeListArray::try_new( + dict.into_array(), + padded_dim_u32, + Validity::NonNullable, + num_rows, + )? + .into_array()) } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 49e53effd0d..51480955594 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -95,16 +95,12 @@ //! use vortex_array::arrays::ExtensionArray; //! use vortex_array::arrays::FixedSizeListArray; //! use vortex_array::arrays::PrimitiveArray; -//! use vortex_array::arrays::Extension; -//! use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -//! use vortex_array::dtype::extension::ExtDType; //! use vortex_array::extension::EmptyMetadata; +//! use vortex_array::session::ArraySession; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; -//! use vortex_array::session::ArraySession; //! use vortex_session::VortexSession; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_unchecked}; -//! use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; //! use vortex_tensor::vector::Vector; //! //! // Create a Vector extension array of 100 random 128-d vectors. @@ -118,22 +114,15 @@ //! let fsl = FixedSizeListArray::try_new( //! elements.into_array(), dim, Validity::NonNullable, num_rows, //! ).unwrap(); -//! let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) -//! .unwrap().erased(); -//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array()); +//! let vector = ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array()) +//! .map(|ext| ext.into_array()) +//! .unwrap(); //! -//! // Normalize, then quantize the normalized child at 2 bits per coordinate. +//! // Normalize and quantize at 2 bits per coordinate in one pass. //! let session = VortexSession::empty().with::(); //! let mut ctx = session.create_execution_ctx(); -//! let l2_denorm = normalize_as_l2_denorm(ext.into_array(), &mut ctx).unwrap(); -//! let normalized = l2_denorm.child_at(0).clone(); -//! -//! let normalized_ext = normalized.as_opt::().unwrap(); //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 }; -//! // SAFETY: We just normalized the input. -//! let tq = unsafe { -//! turboquant_encode_unchecked(normalized_ext, &config, &mut ctx).unwrap() -//! }; +//! let tq = turboquant_encode(vector, &config, &mut ctx).unwrap(); //! //! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! assert!(tq.nbytes() < 51200); diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index b603f12e16e..81fca209a8d 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -3,8 +3,7 @@ //! TurboQuant compression scheme. //! -//! The scheme first normalizes the input via [`normalize_as_l2_denorm`], then encodes the -//! normalized child via [`turboquant_encode_unchecked`]. The result is: +//! The scheme is a thin [`Scheme`] adapter over [`turboquant_encode`], which produces: //! //! ```text //! ScalarFnArray(L2Denorm, [ @@ -18,14 +17,10 @@ //! //! Decompression is automatic: executing the outer array walks the ScalarFn tree. //! -//! [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm -//! [`turboquant_encode_unchecked`]: crate::encodings::turboquant::turboquant_encode_unchecked +//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode use vortex_array::ArrayRef; use vortex_array::Canonical; -use vortex_array::IntoArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::estimate::CompressionEstimate; @@ -38,9 +33,7 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::MAX_CENTROIDS; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::tq_validate_vector_dtype; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::encodings::turboquant::turboquant_encode; /// TurboQuant compression scheme for [`Vector`] extension types. /// @@ -105,33 +98,8 @@ impl Scheme for TurboQuantScheme { data: &mut ArrayAndStats, _ctx: CompressorContext, ) -> VortexResult { - let ext_array = data - .array() - .as_opt::() - .vortex_expect("expected an extension array"); - let mut ctx = compressor.execution_ctx(); - - // 1. Normalize: produces L2Denorm(normalized_vectors, norms). - let l2_denorm = normalize_as_l2_denorm(ext_array.as_ref().clone(), &mut ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - // 2. Quantize the normalized child: SorfTransform(FSL(Dict)). - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - - let config = TurboQuantConfig::default(); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`, so all rows are - // guaranteed to be unit-norm (or zero for originally-null rows). - let sorf_dict = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? }; - - // 3. Wrap back in L2Denorm: the SorfTransform is the "normalized" child. - // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. - Ok(unsafe { L2Denorm::new_array_unchecked(sorf_dict, norms, num_rows) }?.into_array()) + turboquant_encode(data.array().clone(), &TurboQuantConfig::default(), &mut ctx) } } diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs index 0a9e0ab7a18..a26f0c54cfc 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -44,7 +44,7 @@ fn slice_preserves_data() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); @@ -89,7 +89,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let full_decoded = encoded.clone().execute::(&mut ctx)?; @@ -112,7 +112,7 @@ fn l2_norm_readthrough() -> VortexResult<()> { num_rounds: 5, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); // Stored norms should match the actual L2 norms of the input. @@ -150,7 +150,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; @@ -187,7 +187,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let encoded_cos = execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index b111c6e28ba..5b7e325b9b7 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -16,7 +16,6 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Dict; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -25,17 +24,13 @@ use vortex_array::arrays::dict::DictArraySlotsExt; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::encodings::turboquant::turboquant_encode; use crate::tests::SESSION; use crate::vector::Vector; @@ -71,31 +66,9 @@ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { } /// Wrap a `FixedSizeListArray` in a `Vector` extension array. -fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.clone().into_array()) -} - -/// Full encode pipeline: normalize → TQ-encode → wrap in L2Denorm. -fn normalize_and_encode( - ext: &ExtensionArray, - config: &TurboQuantConfig, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(ext.as_ref().clone(), ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? }; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) +fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { + Vector::try_new_vector_array(fsl.clone().into_array()) + .vortex_expect("test FSL satisfies Vector storage constraints") } /// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). @@ -103,17 +76,7 @@ fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded .as_opt::() .expect("expected ScalarFnArray (L2Denorm)"); - let sorf_child = sfn.child_at(0).clone(); - let norms_child = sfn.child_at(1).clone(); - (sorf_child, norms_child) -} - -/// Unwrap a SorfTransform ScalarFnArray to get the FSL(Dict) child. -fn unwrap_sorf(sorf: &ArrayRef) -> ArrayRef { - let sfn = sorf - .as_opt::() - .expect("expected ScalarFnArray (SorfTransform)"); - sfn.child_at(0).clone() + (sfn.child_at(0).clone(), sfn.child_at(1).clone()) } /// Navigate the full tree to get (codes, centroids, norms) as flat arrays. @@ -122,7 +85,11 @@ fn unwrap_codes_centroids_norms( ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { let (sorf_child, norms_child) = unwrap_l2denorm(encoded); - let padded_vector_child = unwrap_sorf(&sorf_child); + let padded_vector_child = sorf_child + .as_opt::() + .expect("expected SorfTransform ScalarFnArray") + .child_at(0) + .clone(); // Vector wrapping FSL(Dict(codes, centroids)) let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; @@ -177,8 +144,7 @@ fn encode_decode( let prim = fsl.elements().clone().execute::(&mut ctx)?; prim.as_slice::().to_vec() }; - let ext = make_vector_ext(fsl); - let encoded = normalize_and_encode(&ext, config, &mut ctx)?; + let encoded = turboquant_encode(make_vector_ext(fsl), config, &mut ctx)?; let decoded_ext = encoded.execute::(&mut ctx)?; let decoded_fsl = decoded_ext .storage_array() @@ -193,19 +159,3 @@ fn encode_decode( }; Ok((original, decoded_elements)) } - -fn make_fsl_small(dim: usize) -> FixedSizeListArray { - let mut buf = BufferMut::::with_capacity(dim); - for i in 0..dim { - buf.push(i as f32 + 1.0); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - 1, - ) - .unwrap() -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs index 6fc19bb93ec..41124c27c80 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -27,7 +27,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; assert_eq!(encoded.len(), 10); assert!(encoded.dtype().is_nullable()); @@ -88,7 +88,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let norms_validity = norms_child.validity()?; @@ -118,7 +118,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let norm_sfn = L2Norm::try_new_array(encoded, 5)?; let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; @@ -160,7 +160,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let sliced = encoded.slice(1..6)?; assert_eq!(sliced.len(), 5); diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index cd61d9193da..dbe04f2e606 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -4,6 +4,7 @@ use rstest::rstest; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -12,6 +13,8 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use super::*; +use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; #[rstest] #[case(128, 1)] @@ -130,7 +133,7 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); Ok(()) @@ -141,7 +144,17 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { #[case(64)] #[case(127)] fn rejects_dimension_below_128(#[case] dim: usize) { - let fsl = make_fsl_small(dim); + let elements = PrimitiveArray::new::( + BufferMut::from_iter((0..dim).map(|i| i as f32 + 1.0)).freeze(), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().expect("dim fits u32"), + Validity::NonNullable, + 1, + ) + .unwrap(); let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 2, @@ -149,9 +162,7 @@ fn rejects_dimension_below_128(#[case] dim: usize) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - assert!( - crate::encodings::turboquant::turboquant_encode(ext.as_view(), &config, &mut ctx).is_err() - ); + assert!(turboquant_encode(ext, &config, &mut ctx).is_err()); } #[rstest] @@ -166,7 +177,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx) + let normalized = normalize_as_l2_denorm(ext, &mut ctx) .unwrap() .child_at(0) .clone(); @@ -255,7 +266,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); Ok(()) @@ -288,7 +299,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); @@ -300,44 +311,3 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { assert_eq!(decoded_fsl.len(), num_rows); Ok(()) } - -/// Verify that the checked encode accepts normalized f16 input. -#[test] -fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(half::f16::from_f32(normal.sample(&mut rng))); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into().unwrap(), - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - num_rounds: 3, - }; - - let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx)? - .child_at(0) - .clone(); - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - - let encoded = - crate::encodings::turboquant::turboquant_encode(normalized_ext, &config, &mut ctx)?; - assert_eq!(encoded.len(), num_rows); - Ok(()) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 87b59836b38..7cd2cdfcf66 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -24,7 +24,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; let stored = centroids.as_slice::(); @@ -52,7 +52,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { // Encode twice with the same seed → should produce identical results. let mut ctx = SESSION.create_execution_ctx(); - let encoded1 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded1 = turboquant_encode(ext.clone(), &config, &mut ctx)?; let decoded1 = encoded1.execute::(&mut ctx)?; let fsl1 = decoded1 .storage_array() @@ -64,7 +64,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { .execute::(&mut ctx)?; let mut ctx = SESSION.create_execution_ctx(); - let encoded2 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded2 = turboquant_encode(ext, &config, &mut ctx)?; let decoded2 = encoded2.execute::(&mut ctx)?; let fsl2 = decoded2 .storage_array() @@ -94,7 +94,7 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; assert!( encoded.dtype().is_extension(), @@ -119,7 +119,7 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); @@ -176,7 +176,7 @@ fn dot_product_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_encode(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); diff --git a/vortex-tensor/src/fixed_shape/matcher.rs b/vortex-tensor/src/fixed_shape/matcher.rs index 5248ca6514b..f703ccba978 100644 --- a/vortex-tensor/src/fixed_shape/matcher.rs +++ b/vortex-tensor/src/fixed_shape/matcher.rs @@ -31,7 +31,7 @@ pub struct FixedShapeTensorMatcherMetadata<'a> { /// /// This matches the `FixedSizeList` list size in the storage dtype, which is the product of /// the logical shape dimensions. - flat_list_size: usize, + flat_list_size: u32, } impl Matcher for AnyFixedShapeTensor { @@ -64,7 +64,7 @@ impl Matcher for AnyFixedShapeTensor { Some(FixedShapeTensorMatcherMetadata { metadata, element_ptype: element_dtype.as_ptype(), - flat_list_size: *list_size as usize, + flat_list_size: *list_size, }) } } @@ -81,7 +81,7 @@ impl FixedShapeTensorMatcherMetadata<'_> { } /// Returns the flattened element count for each tensor row. - pub fn list_size(&self) -> usize { + pub fn flat_list_size(&self) -> u32 { self.flat_list_size } } @@ -118,7 +118,7 @@ mod tests { let metadata = ext_dtype.metadata::(); assert_eq!(metadata.element_ptype(), PType::F32); - assert_eq!(metadata.list_size(), 24); + assert_eq!(metadata.flat_list_size(), 24); assert_eq!(metadata.metadata().logical_shape(), &[2, 3, 4]); Ok(()) } diff --git a/vortex-tensor/src/fixed_shape/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs index 264d18453c4..757138d3e50 100644 --- a/vortex-tensor/src/fixed_shape/metadata.rs +++ b/vortex-tensor/src/fixed_shape/metadata.rs @@ -215,6 +215,7 @@ impl fmt::Display for FixedShapeTensorMetadata { } if let Some(perm) = &self.permutation { + write!(f, ", [")?; for (i, p) in perm.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -353,6 +354,44 @@ mod tests { Ok(()) } + // -- Display -- + + #[test] + fn display_shape_only() { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.to_string(), "Tensor(2, 3, 4)"); + } + + #[test] + fn display_scalar_0d() { + let m = FixedShapeTensorMetadata::new(vec![]); + assert_eq!(m.to_string(), "Tensor()"); + } + + #[test] + fn display_with_dim_names() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()])?; + assert_eq!(m.to_string(), "Tensor(rows: 3, cols: 4)"); + Ok(()) + } + + #[test] + fn display_with_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![1, 0, 2])?; + assert_eq!(m.to_string(), "Tensor(2, 3, 4, [1, 0, 2])"); + Ok(()) + } + + #[test] + fn display_with_dim_names_and_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?; + assert_eq!(m.to_string(), "Tensor(x: 2, y: 3, z: 4, [1, 2, 0])"); + Ok(()) + } + #[test] fn dim_names_wrong_length() { let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]); diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 973562be3da..9ff58d17c4f 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -42,10 +42,10 @@ impl TensorMatch<'_> { } /// Returns the flattened element count for each logical tensor row. - pub fn list_size(self) -> usize { + pub fn list_size(self) -> u32 { match self { - Self::FixedShapeTensor(metadata) => metadata.list_size(), - Self::Vector(metadata) => metadata.dimensions() as usize, + Self::FixedShapeTensor(metadata) => metadata.flat_list_size(), + Self::Vector(metadata) => metadata.dimensions(), } } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 9f8eff0b361..85d16236c8c 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -11,7 +11,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; @@ -32,16 +31,15 @@ use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; -use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::scalar_fns::inner_product::BinaryTensorOpMetadata; use crate::scalar_fns::inner_product::InnerProduct; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extract_l2_denorm_children; -use crate::utils::validate_tensor_float_input; +use crate::utils::validate_binary_tensor_float_inputs; /// Cosine similarity between two columns. /// @@ -59,6 +57,7 @@ use crate::utils::validate_tensor_float_input; /// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor /// [`Vector`]: crate::vector::Vector +/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm #[derive(Clone)] pub struct CosineSimilarity; @@ -84,7 +83,7 @@ impl ScalarFnVTable for CosineSimilarity { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.cosine_similarity") + ScalarFnId::new("vortex.tensor.cosine_similarity") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -116,16 +115,8 @@ impl ScalarFnVTable for CosineSimilarity { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // We don't need to look at rhs anymore since we know lhs and rhs are equal. - let tensor_match = validate_tensor_float_input(lhs)?; + let tensor_match = validate_binary_tensor_float_inputs("CosineSimilarity", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -141,32 +132,24 @@ impl ScalarFnVTable for CosineSimilarity { let len = args.row_count(); // If either side is a constant tensor-like extension array, eagerly normalize the single - // stored row and re-wrap it as an `L2Denorm` whose children are both [`ConstantArray`]s. + // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. // The L2Denorm fast path below then picks it up. - if let Some(lhs_constant) = - try_build_constant_l2_denorm(&lhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - lhs_ref = lhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&lhs_ref, len, ctx)? { + lhs_ref = sfn.into_array(); } - if let Some(rhs_constant) = - try_build_constant_l2_denorm(&rhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - rhs_ref = rhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&rhs_ref, len, ctx)? { + rhs_ref = sfn.into_array(); } - // Check if any of our children have be already normalized. - { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. + match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len); + } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); } + DenormOrientation::Neither => {} } // Compute combined validity. @@ -266,7 +249,6 @@ impl CosineSimilarity { lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, - _ctx: &mut ExecutionCtx, ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; @@ -347,9 +329,10 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; + use crate::vector::Vector; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { @@ -508,7 +491,7 @@ mod tests { 1.0, 0.0, 0.0, // vector 3 ], )?; - let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?; + let query = Vector::constant_array(&[1.0, 0.0, 0.0], 4)?; assert_close( &eval_cosine_similarity(data, query, 4)?, @@ -536,25 +519,13 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. - fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm_self_similarity() -> VortexResult<()> { // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8]. // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0]. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); @@ -565,8 +536,9 @@ mod tests { fn both_denorm_orthogonal() -> VortexResult<()> { // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0. // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0. - let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?; - let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); Ok(()) @@ -575,8 +547,9 @@ mod tests { #[test] fn both_denorm_zero_norm() -> VortexResult<()> { // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); @@ -588,7 +561,8 @@ mod tests { // LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS is plain [3.0, 4.0]. // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[3.0, 4.0])?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); @@ -599,8 +573,9 @@ mod tests { fn one_side_denorm_rhs() -> VortexResult<()> { // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6. + let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); Ok(()) @@ -609,11 +584,11 @@ mod tests { #[test] fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on rhs). - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); - let mut ctx = SESSION.create_execution_ctx(); let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); @@ -711,7 +686,7 @@ mod tests { #[test] fn vector_constant_matches_plain() -> VortexResult<()> { // Exercise the `Vector` extension variant through the new pre-pass. - let lhs = constant_vector_array(&[1.0, 2.0, 2.0], 4)?; + let lhs = Vector::constant_array(&[1.0, 2.0, 2.0], 4)?; let rhs = vector_array( 3, &[ diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index dd9c2a7381f..22b28f9e4d5 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -4,7 +4,6 @@ //! Inner product expression for tensor-like types. use std::fmt::Formatter; -use std::sync::Arc; use num_traits::Float; use prost::Message; @@ -32,13 +31,10 @@ use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; -use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; -use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; -use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::EmptyOptions; @@ -56,11 +52,13 @@ use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; +use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; +use crate::utils::validate_binary_tensor_float_inputs; use crate::vector::Vector; /// Inner product (dot product) between two columns. @@ -99,7 +97,7 @@ impl ScalarFnVTable for InnerProduct { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.inner_product") + ScalarFnId::new("vortex.tensor.inner_product") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -131,32 +129,9 @@ impl ScalarFnVTable for InnerProduct { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "InnerProduct requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // Both inputs must be tensor-like extension types. - let lhs_ext = lhs - .as_extension_opt() - .ok_or_else(|| vortex_err!("InnerProduct lhs must be an extension type, got {lhs}"))?; - - vortex_ensure!( - lhs_ext.is::(), - "InnerProduct inputs must be an `AnyTensor`, got {lhs}" - ); - - let tensor_match = lhs_ext - .metadata_opt::() - .ok_or_else(|| vortex_err!("InnerProduct inputs must be an `AnyTensor`, got {lhs}"))?; + // TODO(connor): relax the float-only gate once integer tensors are supported. + let tensor_match = validate_binary_tensor_float_inputs("InnerProduct", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - // TODO(connor): This should support integer tensors! - vortex_ensure!( - ptype.is_float(), - "InnerProduct element dtype must be a float primitive, got {ptype}" - ); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -167,23 +142,19 @@ impl ScalarFnVTable for InnerProduct { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let mut lhs_ref = args.get(0)?; - let mut rhs_ref = args.get(1)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let len = args.row_count(); - // Check if any of our children have be already normalized. - { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. + match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len, ctx); } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); + } + DenormOrientation::Neither => {} } // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to @@ -212,7 +183,7 @@ impl ScalarFnVTable for InnerProduct { let tensor_match = ext .metadata_opt::() .vortex_expect("we already validated this in `return_dtype`"); - let dimensions = tensor_match.list_size(); + let dimensions = tensor_match.list_size() as usize; // Extract the storage array from each extension input. We pass the storage (FSL) rather // than the extension array to avoid canonicalizing the extension wrapper. @@ -409,20 +380,21 @@ impl InnerProduct { /// Fast path when one side is `ExactScalarFn` and the other side is a /// constant-backed tensor-like extension. Rewrites to /// `InnerProduct(sorf_child, forward_rotate(zero_pad(const_query)))` because SORF is - /// orthogonal, so ` = ` where `T` is the truncation - /// from `padded_dim` to `dim` applied by `SorfTransform` and `R` is the SORF forward - /// matrix. See the proof in the crate-level docs and in the plan file. + /// orthogonal, so ` = ` where `T` is the truncation from + /// `padded_dim` to `dim` applied by `SorfTransform` and `R` is the SORF forward matrix. See the + /// proof in the crate-level docs and in the plan file. /// - /// Returns `Ok(None)` if neither side matches or when `element_ptype` is not `F32`. The - /// caller is expected to fall through to the standard path in that case. + /// Returns `Ok(None)` if neither side matches, when the operand element type is not `F32`, or + /// when the constant side is not a constant-backed tensor extension. The caller is expected to + /// fall through to the standard path in that case. /// - /// # TODO(connor): + /// # F32-only /// - /// This rewrite is only sound for `PType::F32` because `SorfTransform` applies an - /// `f32 -> element_ptype` cast at the end of its execute (see `sorf_transform/vtable.rs` - /// line ~218). For F16/F64 the cast changes the inner product's rounding and would - /// change the semantics of the rewrite. Until we push the cast through `InnerProduct`, - /// this path only fires for F32. + /// TODO(connor): this rewrite is only sound for `PType::F32` because `SorfTransform` applies an + /// `f32 -> element_ptype` cast at the end of its `execute`. For `F16`/`F64` the cast changes + /// the inner product's rounding and the rewrite would not be semantically equivalent. Until we + /// push the cast through `InnerProduct`, both the SorfTransform output ptype and the + /// constant-side element ptype must be `F32` here. fn try_execute_sorf_constant( &self, lhs_ref: &ArrayRef, @@ -440,10 +412,6 @@ impl InnerProduct { return Ok(None); }; - // TODO(connor): pull-through is only sound for F32 because SorfTransform applies an - // `f32 -> element_ptype` cast at the end of its execute. For F16/F64 the rewrite - // would change the inner product's rounding semantics. Fall through so the standard - // path (which does the cast before inner product) handles it. if sorf_view.options.element_ptype != PType::F32 { return Ok(None); } @@ -458,45 +426,24 @@ impl InnerProduct { let seed = sorf_view.options.seed; let padded_dim = dim.next_power_of_two(); - // Extract the single stored row of the constant via the stride-0 short-circuit. - let flat = extract_flat_elements(&const_storage, dim, ctx)?; + // Extract the single stored row of the constant. + let flat = extract_constant_flat_row(&const_storage, ctx)?; if flat.ptype() != PType::F32 { - // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The - // standard path handles them correctly. return Ok(None); } // Zero-pad the query from `dim` to `padded_dim` and forward-rotate. let mut padded_query = vec![0.0f32; padded_dim]; - padded_query[..dim].copy_from_slice(flat.row::(0)); + padded_query[..dim].copy_from_slice(flat.as_slice::()); let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); - // Build the rewritten constant as a `Vector` extension scalar. We reuse - // the original storage FSL nullability so the new extension dtype stays consistent with - // whatever the original tree expected. - let storage_fsl_nullability = const_storage.dtype().nullability(); - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = rotated_query - .into_iter() - .map(|v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let fsl_scalar = - Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability); - - // Build a fresh `Vector` extension dtype. We cannot reuse the - // original extension dtype because that one has `dim`, not `padded_dim`. - let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim fits u32"); - let new_fsl_dtype = DType::FixedSizeList( - Arc::new(element_dtype), - padded_dim_u32, - storage_fsl_nullability, - ); - let new_ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl_dtype)?.erased(); - let new_constant = - ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array(); + // Wrap the rotated query as a `Vector` constant broadcast to `len` + // rows. The new extension dtype has `padded_dim` instead of `dim`, matching the + // SorfTransform child we are about to dot it with. + let new_constant = Vector::constant_array(&rotated_query, len)?; // Extract the SorfTransform child (the already-padded Vector). let sorf_child = sorf_view @@ -575,8 +522,7 @@ impl InnerProduct { // Gate: u8 codes and f32 centroids. if codes_prim.ptype() != PType::U8 { - // TODO(connor): support wider code widths (u16, u32). TurboQuant only emits u8 - // codes today, so this is the only path we need for now. + // TODO(connor): Should we support wider codes? return Ok(None); } if values_prim.ptype() != PType::F32 { @@ -587,7 +533,7 @@ impl InnerProduct { let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize"); - let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?; + let flat = extract_constant_flat_row(&const_storage, ctx)?; if flat.ptype() != PType::F32 { // TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard // path, which computes the inner product with the correct element type. @@ -606,14 +552,14 @@ impl InnerProduct { return Ok(Some(empty.into_array())); } - let q: &[f32] = flat.row::(0); + let q: &[f32] = flat.as_slice::(); debug_assert_eq!(q.len(), padded_dim); let codes: &[u8] = codes_prim.as_slice::(); let values: &[f32] = values_prim.as_slice::(); debug_assert_eq!(codes.len(), len * padded_dim); - // The hot loop is extracted into [`execute_dict_constant_inner_product`] with - // unchecked indexing so the compiler can vectorize the inner gather-accumulate. + // The hot loop is extracted into [`execute_dict_constant_inner_product`] so the compiler + // can prove the chunked indices stay in bounds and vectorize the inner gather-accumulate. let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim); // SAFETY: the buffer length equals `len`, which matches the validity length. @@ -644,10 +590,10 @@ fn inner_product_row(a: &[T], b: &[T]) -> T { /// Compute inner products between a constant query vector and dictionary-encoded rows. /// -/// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook -/// `values` directly instead of decoding the dictionary into dense vectors. +/// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook `values` +/// directly instead of decoding the dictionary into dense vectors. /// -/// The inner loop uses four independent accumulators so the CPU can pipeline FP additions +/// The inner loop uses `PARTIAL_SUMS` independent accumulators so the CPU can pipeline FP additions /// instead of waiting for each `fadd` to retire before starting the next. fn execute_dict_constant_inner_product( q: &[f32], @@ -704,6 +650,7 @@ mod tests { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -818,28 +765,14 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. - fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - use vortex_array::IntoArray; - - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm() -> VortexResult<()> { // LHS: [3.0, 4.0] = L2Denorm([0.6, 0.8], 5.0). // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0). // dot([3.0, 4.0], [1.0, 0.0]) = 3.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; - let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); @@ -850,8 +783,9 @@ mod tests { fn both_denorm_multiple_rows() -> VortexResult<()> { // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0. // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); Ok(()) @@ -862,7 +796,8 @@ mod tests { // LHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS: plain [1.0, 2.0]. // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[1.0, 2.0])?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); @@ -874,8 +809,9 @@ mod tests { // LHS: plain [1.0, 2.0]. // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0. + let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 2.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -889,7 +825,7 @@ mod tests { let mut ctx = SESSION.create_execution_ctx(); let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; @@ -948,15 +884,11 @@ mod tests { reason = "tests build small fixtures with deterministic in-range indices" )] mod constant_query_optimizations { - use std::sync::LazyLock; - use rstest::rstest; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Constant; - use vortex_array::arrays::ConstantArray; - use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; @@ -964,67 +896,23 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::scalar::Scalar; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::inner_product::constant_tensor_storage; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; + use crate::tests::SESSION; use crate::utils::extract_flat_elements; + use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::vector_array; use crate::vector::Vector; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - - /// Compact f32 Vector extension over a column-major `elements` slice. - fn vector_f32(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - /// Compact constant-backed f32 Vector extension with a single stored row. - fn constant_vector_f32(elements: &[f32], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let storage = ConstantArray::new(storage_scalar, len).into_array(); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) - } - - /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension. - fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let vector_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(vector_scalar, len).into_array() - } - - /// Build an `ExtensionArray>` whose storage is - /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that - /// TurboQuant produces as the SorfTransform child. + /// Build a `Vector` whose storage is `FSL(DictArray(codes: u8, values: + /// f32))`. This mirrors the shape that TurboQuant produces as the SorfTransform child. fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult { let num_rows = codes.len() / list_size as usize; let codes_arr = @@ -1040,9 +928,7 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::try_new_vector_array(fsl.into_array()) } /// Execute an inner product and return the flat `f32` results. @@ -1128,7 +1014,7 @@ mod tests { #[test] fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { - let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5); + let literal = literal_vector_array(&[1.0f32, 2.0, 3.0], 5); let storage = constant_tensor_storage(&literal).expect("literal vector should be recognized"); @@ -1164,7 +1050,7 @@ mod tests { // Query has `dim` elements. let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth: decode LHS to plain f32 vectors, dot each with the query. let decoded = decode_sorf_dict( @@ -1202,7 +1088,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect(); - let const_lhs = constant_vector_f32(&query_elems, num_rows)?; + let const_lhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1240,7 +1126,7 @@ mod tests { assert_eq!(padded_dim, dim as usize); let query_elems: Vec = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1277,7 +1163,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = vec![0.0; dim as usize]; - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1300,7 +1186,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1330,7 +1216,7 @@ mod tests { let dict_rhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.5, -1.0, 2.5, -0.25]; - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1375,12 +1261,10 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - let dict_lhs = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + let dict_lhs = Vector::try_new_vector_array(fsl.into_array())?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; // Build expected by decoding by hand. let expected: Vec = (0..num_rows) @@ -1408,10 +1292,10 @@ mod tests { let lhs_elems: Vec = (0..num_rows * dim as usize) .map(|i| i as f32 * 0.25) .collect(); - let plain_lhs = vector_f32(dim, &lhs_elems)?; + let plain_lhs = vector_array(dim, &lhs_elems)?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1438,7 +1322,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.0; 4]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1458,7 +1342,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| ((i as f32) * 0.15).sin() * 0.4).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth via full decode + naive dot. let decoded = decode_sorf_dict( @@ -1531,7 +1415,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1575,7 +1459,7 @@ mod tests { // has cancellation. let mut rng = XorShift64::new(seed ^ 0xABCD_1234); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1621,7 +1505,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1671,7 +1555,7 @@ mod tests { SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1721,7 +1605,7 @@ mod tests { let mut rng = XorShift64::new(seed ^ (num_rounds as u64)); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1755,7 +1639,7 @@ mod tests { let mut rng = XorShift64::new(seed); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 673c454b56b..875d7f57e2f 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -22,6 +22,7 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; @@ -59,6 +60,7 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; +use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; use crate::utils::validate_tensor_float_input; @@ -113,7 +115,7 @@ impl L2Denorm { len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - validate_l2_denorm_children(&normalized, &norms, ctx)?; + validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; // SAFETY: We just validated that it is normalized. unsafe { Self::new_array_unchecked(normalized, norms, len) } @@ -212,7 +214,11 @@ impl ScalarFnVTable for L2Denorm { if let Some(const_norms) = norms_ref.as_opt::() { let norm_scalar = const_norms.scalar(); - vortex_ensure!(norm_scalar.dtype().is_float()); + vortex_ensure!( + norm_scalar.dtype().is_float(), + "L2Denorm constant norms must be a float scalar, got {}", + norm_scalar.dtype(), + ); if let Some(norm_value) = norm_scalar.value() { return execute_l2_denorm_constant_norms( @@ -235,13 +241,11 @@ impl ScalarFnVTable for L2Denorm { .as_extension() .metadata_opt::() .vortex_expect("we already validated this in `return_dtype`"); - let tensor_flat_size = tensor_match.list_size(); + let tensor_flat_size = tensor_match.list_size() as usize; let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; - // TODO(connor): Theoretically we could model this as a multiplication between the - // normalized array and a `RunEnd(Sequence(0, dimensions), norms)`. But since we have - // already canonicalized the array, it is probably not faster to do that. + // TODO(connor): Do we want a "broadcast" expression for the List types, or is this fine? match_each_float_ptype!(flat.ptype(), |T| { let norms = norms.as_slice::(); @@ -418,7 +422,7 @@ pub fn normalize_as_l2_denorm( ) -> VortexResult { let row_count = input.len(); let tensor_match = validate_tensor_float_input(input.dtype())?; - let tensor_flat_size = tensor_match.list_size(); + let tensor_flat_size = tensor_match.list_size() as usize; // Constant fast path: if the input is a constant-backed extension, normalize the single // stored row once and return an `L2Denorm` whose children are both `ConstantArray`s. @@ -515,17 +519,17 @@ pub(crate) fn try_build_constant_l2_denorm( .as_extension() .metadata_opt::() .vortex_expect("caller validated input has AnyTensor metadata"); - let list_size = tensor_match.list_size(); + let list_size = tensor_match.list_size() as usize; let original_nullability = input.dtype().nullability(); let ext_dtype = input.dtype().as_extension().clone(); let storage_fsl_nullability = storage.dtype().nullability(); - // `extract_flat_elements` takes the stride-0 single-row path for `Constant` storage, so - // this is cheap and does not expand the constant to the full column length. - let flat = extract_flat_elements(storage, list_size, ctx)?; + // Materialize just the single stored row; this does not expand the constant to the full + // column length. + let flat = extract_constant_flat_row(storage, ctx)?; let (normalized_fsl_scalar, norms_scalar) = match_each_float_ptype!(flat.ptype(), |T| { - let row = flat.row::(0); + let row = flat.as_slice::(); let mut sum_sq = T::zero(); for &x in row { @@ -600,25 +604,14 @@ fn unit_norm_tolerance(element_ptype: PType) -> f64 { } } -/// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0). -pub fn validate_l2_normalized_rows(input: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - validate_l2_normalized_rows_impl(input, None, ctx) -} - -/// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`] -/// invariants, which are: +/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the +/// [`L2Denorm`] invariants: /// -/// - All vectors in the normalized array have length 1 or 0. -/// - If the vector has a norm of 0, then the vector must be all 0s. -fn validate_l2_denorm_children( - normalized: &ArrayRef, - norms: &ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult<()> { - validate_l2_normalized_rows_impl(normalized, Some(norms), ctx) -} - -fn validate_l2_normalized_rows_impl( +/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision +/// tolerance). +/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored norm is +/// `0.0` is exactly the zero vector in `normalized`. +pub fn validate_l2_normalized_rows_against_norms( normalized: &ArrayRef, norms: Option<&ArrayRef>, ctx: &mut ExecutionCtx, @@ -631,7 +624,7 @@ fn validate_l2_normalized_rows_impl( let tensor_match = validate_tensor_float_input(normalized.dtype())?; let element_ptype = tensor_match.element_ptype(); let tolerance = unit_norm_tolerance(element_ptype); - let tensor_flat_size = tensor_match.list_size(); + let tensor_flat_size = tensor_match.list_size() as usize; if let Some(norms) = norms { vortex_ensure_eq!( @@ -697,6 +690,51 @@ fn validate_l2_normalized_rows_impl( Ok(()) } +/// Classification of a binary operand pair by which side (if any) is wrapped in [`L2Denorm`]. +/// +/// Symmetric binary tensor operators (e.g. [`CosineSimilarity`], [`InnerProduct`]) have identical +/// fast paths for "only the lhs is denormalized" and "only the rhs is denormalized", and a separate +/// fast path for "both are denormalized". Rather than hand-rolling the commutative swap at every +/// call site, callers classify their operands with [`Self::classify`] and pattern-match on the +/// returned variant. +/// +/// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity +/// [`InnerProduct`]: crate::scalar_fns::inner_product::InnerProduct +pub(crate) enum DenormOrientation<'a> { + /// Both operands are [`ExactScalarFn`] arrays. + Both { + lhs: &'a ArrayRef, + rhs: &'a ArrayRef, + }, + /// Exactly one operand is an [`ExactScalarFn`]; the other is plain. + One { + denorm: &'a ArrayRef, + plain: &'a ArrayRef, + }, + /// Neither operand is an [`ExactScalarFn`]. + Neither, +} + +impl<'a> DenormOrientation<'a> { + /// Classify `(lhs, rhs)` by which side (if any) is wrapped in [`L2Denorm`]. + pub(crate) fn classify(lhs: &'a ArrayRef, rhs: &'a ArrayRef) -> Self { + let lhs_denorm = lhs.is::>(); + let rhs_denorm = rhs.is::>(); + match (lhs_denorm, rhs_denorm) { + (true, true) => Self::Both { lhs, rhs }, + (true, false) => Self::One { + denorm: lhs, + plain: rhs, + }, + (false, true) => Self::One { + denorm: rhs, + plain: lhs, + }, + (false, false) => Self::Neither, + } + } +} + #[cfg(test)] mod tests { @@ -719,23 +757,18 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; use vortex_array::extension::datetime::Date; use vortex_array::extension::datetime::TimeUnit; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; - use vortex_buffer::Buffer; use vortex_error::VortexResult; - use crate::fixed_shape::FixedShapeTensor; - use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; - use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; + use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -747,20 +780,6 @@ mod tests { result.into_array().execute(&mut ctx) } - fn integer_tensor_array(shape: &[usize], elements: &[i32]) -> VortexResult { - let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn non_tensor_extension_array() -> VortexResult { let storage = PrimitiveArray::from_iter([1i32, 2]).into_array(); let ext_dtype = @@ -768,16 +787,6 @@ mod tests { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - fn f16_vector_array(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let values: Vec<_> = elements.iter().copied().map(half::f16::from_f32).collect(); - let elems: ArrayRef = Buffer::copy_from(values.as_slice()).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn tensor_snapshot(array: ArrayRef) -> VortexResult<(DType, Vec, Vec)> { let mut ctx = SESSION.create_execution_ctx(); let ext: ExtensionArray = array.execute(&mut ctx)?; @@ -866,7 +875,7 @@ mod tests { #[test] fn l2_denorm_rejects_integer_tensor_lhs() -> VortexResult<()> { - let lhs = integer_tensor_array(&[2], &[1, 2, 3, 4])?; + let lhs = tensor_array(&[2], &[1i32, 2, 3, 4])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); @@ -888,10 +897,10 @@ mod tests { #[test] fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { - let input = f16_vector_array(2, &[3.0, 4.0, 0.0, 0.0])?; + let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; - validate_l2_normalized_rows(&roundtrip.child_at(0).clone(), &mut ctx)?; + validate_l2_normalized_rows_against_norms(&roundtrip.child_at(0).clone(), None, &mut ctx)?; Ok(()) } @@ -899,7 +908,7 @@ mod tests { fn validate_l2_normalized_rows_rejects_unnormalized_input() -> VortexResult<()> { let input = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let result = validate_l2_normalized_rows(&input, &mut ctx); + let result = validate_l2_normalized_rows_against_norms(&input, None, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -982,7 +991,7 @@ mod tests { #[test] fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { - let input = constant_vector_array(&[3.0, 4.0], 2)?; + let input = Vector::constant_array(&[3.0, 4.0], 2)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -996,7 +1005,7 @@ mod tests { // The constant fast path in `normalize_as_l2_denorm` must produce an `L2Denorm` whose // normalized storage and norms child are both still `ConstantArray`s. This is what // allows downstream ops (cosine similarity, inner product) to short-circuit. - let input = constant_vector_array(&[3.0, 4.0], 16)?; + let input = Vector::constant_array(&[3.0, 4.0], 16)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59a2e5a45b1..47ba68a35c4 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -47,6 +47,7 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::extract_flat_elements; +use crate::utils::extract_l2_denorm_children; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -84,7 +85,7 @@ impl ScalarFnVTable for L2Norm { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.l2_norm") + ScalarFnId::new("vortex.tensor.l2_norm") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -131,7 +132,7 @@ impl ScalarFnVTable for L2Norm { let tensor_match = ext .metadata_opt::() .vortex_expect("we already validated this in `return_dtype`"); - let tensor_flat_size = tensor_match.list_size(); + let tensor_flat_size = tensor_match.list_size() as usize; let element_ptype = tensor_match.element_ptype(); let norm_dtype = DType::Primitive(element_ptype, ext.nullability()); @@ -139,13 +140,9 @@ impl ScalarFnVTable for L2Norm { // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics // instead of forcing a decode-and-recompute path here. - if let Some(sfn) = input_ref.as_opt::>() { - let norms = sfn - .nth_child(1) - .vortex_expect("L2Denom must have at 2 children"); - + if input_ref.is::>() { + let (_, norms) = extract_l2_denorm_children(&input_ref); vortex_ensure_eq!(norms.dtype(), &norm_dtype); - return Ok(norms); } @@ -290,6 +287,7 @@ mod tests { use crate::scalar_fns::l2_norm::L2Norm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::literal_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -363,27 +361,13 @@ mod tests { Ok(()) } - /// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a - /// fixed-size list of `elements`, broadcast to `len` rows. - fn constant_vector_extension_array(elements: &[f64], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(ext_scalar, len).into_array() - } - /// A constant input whose scalar is a non-null tensor should short-circuit to a /// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so /// execution stops at the [`Constant`] encoding instead of canonicalizing into a /// [`PrimitiveArray`]. #[test] fn constant_non_null_input_yields_constant_output() -> VortexResult<()> { - let input = constant_vector_extension_array(&[3.0, 4.0], 4); + let input = literal_vector_array(&[3.0f64, 4.0], 4); let scalar_fn = L2Norm::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array(); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 64f92da384f..b54158a753f 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -70,13 +70,13 @@ impl ScalarFnVTable for SorfTransform { fn fmt_sql( &self, - _options: &Self::Options, + options: &Self::Options, expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { write!(f, "sorf_transform(")?; expr.child(0).fmt_sql(f)?; - write!(f, ")") + write!(f, ", {options})") } fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { @@ -143,9 +143,7 @@ impl ScalarFnVTable for SorfTransform { validity, 0, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::try_new_vector_array(fsl.into_array()) }); } @@ -220,15 +218,8 @@ impl ScalarFnArrayVTable for SorfTransform { view: &ScalarFnArrayView, _session: &VortexSession, ) -> VortexResult>> { - let options = view.options; Ok(Some( - SorfTransformMetadata { - seed: options.seed, - num_rounds: u32::from(options.num_rounds), - dimension: options.dimension, - element_ptype: options.element_ptype as i32, - } - .encode_to_vec(), + SorfTransformMetadata::from(view.options).encode_to_vec(), )) } @@ -240,20 +231,9 @@ impl ScalarFnArrayVTable for SorfTransform { children: &dyn ArrayChildren, _session: &VortexSession, ) -> VortexResult> { - let metadata = SorfTransformMetadata::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; - let options = SorfOptions { - seed: metadata.seed, - num_rounds: u8::try_from(metadata.num_rounds).map_err(|_| { - vortex_err!( - "SorfTransform num_rounds {} does not fit in u8", - metadata.num_rounds - ) - })?, - dimension: metadata.dimension, - element_ptype: metadata.element_ptype(), - }; - validate_sorf_options(&options)?; + let options = SorfTransformMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))? + .to_options()?; // `return_dtype` sets the output FSL's nullability to the child's nullability (see // `return_dtype` above), so we read the child nullability back from the parent dtype. @@ -316,7 +296,37 @@ fn inverse_rotate_typed( let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; + Vector::try_new_vector_array(fsl.into_array()) +} - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) +impl From<&SorfOptions> for SorfTransformMetadata { + fn from(options: &SorfOptions) -> Self { + Self { + seed: options.seed, + num_rounds: u32::from(options.num_rounds), + dimension: options.dimension, + element_ptype: options.element_ptype as i32, + } + } +} + +impl SorfTransformMetadata { + /// Rebuild the [`SorfOptions`] this metadata was serialized from, validating that the wire + /// values are in range. + fn to_options(&self) -> VortexResult { + let num_rounds = u8::try_from(self.num_rounds).map_err(|_| { + vortex_err!( + "SorfTransform num_rounds {} does not fit in u8", + self.num_rounds + ) + })?; + let options = SorfOptions { + seed: self.seed, + num_rounds, + dimension: self.dimension, + element_ptype: self.element_ptype(), + }; + validate_sorf_options(&options)?; + Ok(options) + } } diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 76d65ce9eef..3434058f082 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::ConstantArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::primitive::PrimitiveArrayExt; use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; @@ -43,6 +44,23 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult( + op_name: &str, + lhs: &'a DType, + rhs: &DType, +) -> VortexResult> { + vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "{op_name} requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + validate_tensor_float_input(lhs) +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. /// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively @@ -80,12 +98,12 @@ pub fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a -/// constant input materializes only a single row (stride=0), while a full array uses -/// stride=list_size. +/// constant-backed input materializes only a single row that every index reads (`is_constant = +/// true`), while a full array stores one row per index. pub struct FlatElements { elems: PrimitiveArray, - stride: usize, list_size: usize, + is_constant: bool, } impl FlatElements { @@ -96,47 +114,101 @@ impl FlatElements { } /// Returns the `i`-th row as a typed slice of length `list_size`. + /// + /// When the source was a constant-backed storage, all indices resolve to the single stored + /// row. #[must_use] pub fn row(&self, i: usize) -> &[T] { + let row_idx = if self.is_constant { 0 } else { i }; let slice = self.elems.as_slice::(); - &slice[i * self.stride..][..self.list_size] + &slice[row_idx * self.list_size..][..self.list_size] } } -// TODO(connor): Usage of this function is sometimes incorrect / not performant. -// Make sure to fix them. /// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is -/// materialized to avoid expanding it to the full column length. +/// materialized to avoid expanding it to the full column length. Callers that have already +/// confirmed the storage is constant-backed should prefer [`extract_constant_flat_row`]. pub fn extract_flat_elements( storage: &ArrayRef, list_size: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - if let Some(constant) = storage.as_opt::() { - // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge - // amount of data. + // Constant-backed storage: materialize just the single stored row so canonicalization does + // not expand the array to the full column length. + let (source, is_constant) = if let Some(constant) = storage.as_opt::() { let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl: FixedSizeListArray = single.execute(ctx)?; - let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - return Ok(FlatElements { - elems, - stride: 0, - list_size, - }); - } + (single, true) + } else { + (storage.clone(), false) + }; - // Otherwise we have to fully expand all of the data. - let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let fsl: FixedSizeListArray = source.execute(ctx)?; let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + vortex_ensure!( + !elems.nullability().is_nullable(), + "tensor storage elements must be non-nullable, got {}", + elems.dtype(), + ); Ok(FlatElements { elems, - stride: list_size, list_size, + is_constant, }) } +/// The single stored row of a constant-backed tensor storage array. +/// +/// Contrast with [`FlatElements`], which exposes arbitrary row indices: a `FlatRow` statically +/// encodes "there is exactly one row available," so call sites that have gated on a constant input +/// read the row via [`Self::as_slice`] instead of `row(0)`. +pub struct FlatRow { + elems: PrimitiveArray, +} + +impl FlatRow { + /// Returns the [`PType`] of the underlying elements. + #[must_use] + pub fn ptype(&self) -> PType { + self.elems.ptype() + } + + /// Returns the stored row as a typed slice. Its length equals the storage scalar's + /// fixed-size-list size. + #[must_use] + pub fn as_slice(&self) -> &[T] { + self.elems.as_slice::() + } +} + +/// Extracts the single stored row from a [`Constant`]-backed tensor storage array. +/// +/// The caller must have confirmed that `storage` is a [`Constant`] encoding whose scalar is a +/// non-null fixed-size list. This is the fast path for constant query vectors: exactly one row is +/// materialized regardless of the column length. +/// +/// # Panics +/// +/// Panics if `storage` is not a [`Constant`] encoding. +pub fn extract_constant_flat_row( + storage: &ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let constant = storage + .as_opt::() + .vortex_expect("extract_constant_flat_row requires Constant-backed storage"); + let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); + let fsl: FixedSizeListArray = single.execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + vortex_ensure!( + !elems.nullability().is_nullable(), + "tensor storage elements must be non-nullable, got {}", + elems.dtype(), + ); + Ok(FlatRow { elems }) +} + /// Extracts the `(normalized, norms)` children from an [`L2Denorm`] scalar function array. /// /// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm @@ -154,15 +226,17 @@ pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) { #[cfg(test)] pub mod test_helpers { use vortex_array::ArrayRef; + use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; + use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; - use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -170,81 +244,86 @@ pub mod test_helpers { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::scalar_fns::l2_denorm::L2Denorm; use crate::vector::Vector; - /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. + /// Builds a `FixedSizeList` storage array from flat `elements`. The row count is + /// inferred from `elements.len() / list_size`. + fn flat_fsl(elements: &[T], list_size: u32) -> ArrayRef { + let row_count = elements.len() / list_size as usize; + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count).into_array() + } + + /// Builds an FSL-valued [`Scalar`] from `elements` for use as a constant query. + fn fsl_scalar>(elements: &[T]) -> Scalar { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable) + } + + /// Builds a [`FixedShapeTensor`] extension array from flat `elements` and a logical shape. /// /// The number of rows is inferred from the total element count divided by the product of the /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. - pub fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + pub fn tensor_array(shape: &[usize], elements: &[T]) -> VortexResult { let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - + let storage = flat_fsl(elements, list_size); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. - pub fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { - let row_count = elements.len() / dim as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + /// Builds a [`Vector`] extension array from flat `elements` and a vector dimension size. + pub fn vector_array(dim: u32, elements: &[T]) -> VortexResult { + Vector::try_new_vector_array(flat_fsl(elements, dim)) } /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], /// representing a single query tensor broadcast to `len` rows. - pub fn constant_tensor_array( + pub fn constant_tensor_array>( shape: &[usize], - elements: &[f64], + elements: &[T], len: usize, ) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - + let storage = ConstantArray::new(fsl_scalar(elements), len).into_array(); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a - /// single query vector broadcast to `len` rows. - pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + /// Builds a [`ConstantArray`] whose scalar is itself a [`Vector`] extension scalar, broadcast + /// to `len` rows. This is the shape produced by an `lit(vector_scalar)` literal expression — + /// the constant lives at the extension level rather than inside the FSL storage, in contrast + /// to [`Vector::constant_array`]. + pub fn literal_vector_array>( + elements: &[T], + len: usize, + ) -> ArrayRef { + use vortex_array::extension::EmptyMetadata; + let ext_scalar = Scalar::extension::(EmptyMetadata, fsl_scalar(elements)); + ConstantArray::new(ext_scalar, len).into_array() + } - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + /// Creates an [`L2Denorm`] scalar function array from pre-normalized tensor elements and + /// matching norms. The caller must ensure every row of `normalized_elements` is unit-norm or + /// zero. + pub fn l2_denorm_array( + shape: &[usize], + normalized_elements: &[T], + norms: &[T], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let len = norms.len(); + let normalized = tensor_array(shape, normalized_elements)?; + let norms = + PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); + Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) } /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` diff --git a/vortex-tensor/src/vector/matcher.rs b/vortex-tensor/src/vector/matcher.rs index 0da5384f303..e4838b77426 100644 --- a/vortex-tensor/src/vector/matcher.rs +++ b/vortex-tensor/src/vector/matcher.rs @@ -49,7 +49,7 @@ impl Matcher for AnyVector { let dimensions = *list_size; - assert!(element_dtype.is_float(), "element dtype must be primitive"); + assert!(element_dtype.is_float(), "element dtype must be float"); assert!( !element_dtype.is_nullable(), "element dtype must be non-nullable" diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs index 3c6a8a8c8cc..d077a183713 100644 --- a/vortex-tensor/src/vector/mod.rs +++ b/vortex-tensor/src/vector/mod.rs @@ -3,10 +3,54 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::PValue; +use vortex_array::scalar::Scalar; +use vortex_error::VortexResult; + /// The Vector extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Vector; +impl Vector { + /// Helper function for creating a new [`Vector`] [`ExtensionArray`]. + /// + /// # Errors + /// + /// Returns an error if the [`Vector`] extension dtype rejects the storage array. + pub(crate) fn try_new_vector_array(storage: ArrayRef) -> VortexResult { + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage) + .map(|ext| ext.into_array()) + } + + /// Helper function to build a [`Vector`] [`ExtensionArray`] whose storage is a + /// [`ConstantArray`], broadcasting a single vector `elements` across `len` rows. + /// + /// # Errors + /// + /// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. + pub(crate) fn constant_array>( + elements: &[T], + len: usize, + ) -> VortexResult { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + Self::try_new_vector_array(ConstantArray::new(storage_scalar, len).into_array()) + } +} + mod matcher; pub use matcher::AnyVector; diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 81a379683db..c5551c99f2d 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -4,23 +4,14 @@ //! Reusable helpers for building brute-force vector similarity search expressions over //! [`Vector`] extension arrays. //! -//! This module exposes three small building blocks that together make it straightforward to -//! stand up a cosine-similarity-plus-threshold scan on top of a prepared data array: +//! [`build_similarity_search_tree`] broadcasts the query into the shape expected by +//! [`CosineSimilarity`] via `Vector::constant_array` and returns a lazy +//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. The caller is responsible +//! for preparing `data` (e.g. by running it through [`turboquant_encode`]); this builder does not +//! compress. //! -//! - [`compress_turboquant`] applies the canonical TurboQuant encoding pipeline -//! (`L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)`) to a raw -//! `Vector` array without requiring the caller to plumb the -//! `unstable_encodings` feature flag on the `vortex` facade. -//! - [`build_constant_query_vector`] wraps a single query vector into a -//! [`Vector`] extension array whose storage is a [`ConstantArray`] broadcast -//! across `num_rows` rows. This is the shape expected by -//! [`CosineSimilarity::try_new_array`] for the RHS of a database-vs-query scan. -//! - [`build_similarity_search_tree`] wires everything together into a lazy -//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. -//! -//! Executing the tree from [`build_similarity_search_tree`] into a -//! [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean per row indicating whether -//! that row's cosine similarity to the query exceeds `threshold`. +//! Executing the tree into a [`BoolArray`] yields one boolean per row indicating whether that row's +//! cosine similarity to the query exceeds `threshold`. //! //! # Example //! @@ -28,11 +19,12 @@ //! use vortex_array::{ArrayRef, VortexSessionExecute}; //! use vortex_array::arrays::BoolArray; //! use vortex_session::VortexSession; -//! use vortex_tensor::vector_search::{build_similarity_search_tree, compress_turboquant}; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; +//! use vortex_tensor::vector_search::build_similarity_search_tree; //! //! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> { //! let mut ctx = session.create_execution_ctx(); -//! let data = compress_turboquant(data, &mut ctx)?; +//! let data = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?; //! let tree = build_similarity_search_tree(data, query, 0.8)?; //! let _matches: BoolArray = tree.execute(&mut ctx)?; //! Ok(()) @@ -40,98 +32,24 @@ //! ``` //! //! [`Vector`]: crate::vector::Vector -//! [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array +//! [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity +//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode +//! [`BoolArray`]: vortex_array::arrays::BoolArray use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode_unchecked; use crate::scalar_fns::cosine_similarity::CosineSimilarity; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::vector::Vector; -/// Apply the canonical TurboQuant encoding pipeline to a `Vector` array. -/// -/// The returned array has the shape -/// `L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)` — exactly what -/// [`crate::encodings::turboquant::TurboQuantScheme`] produces when invoked through -/// `BtrBlocksCompressorBuilder::with_turboquant()`, but without requiring callers to enable -/// the `unstable_encodings` feature on the `vortex` facade. -/// -/// The input `data` must be a [`Vector`] extension array whose element type is `f32` and whose -/// dimensionality is at least -/// [`turboquant::MIN_DIMENSION`](crate::encodings::turboquant::MIN_DIMENSION). The TurboQuant -/// configuration used is [`TurboQuantConfig::default()`] (8-bit codes, 3 SORF rounds, seed 42). -/// -/// # Errors -/// -/// Returns an error if `data` is not a [`Vector`] extension array, if normalization fails, or -/// if the underlying TurboQuant encoder rejects the input shape. -pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(data, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let Some(normalized_ext) = normalized.as_opt::() else { - vortex_bail!("normalize_as_l2_denorm must produce an Extension array child"); - }; - - let config = TurboQuantConfig::default(); - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is - // the invariant `turboquant_encode_unchecked` expects. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) -} - -/// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single -/// query vector across `num_rows` rows. -/// -/// The element type is inferred from `T` (e.g. `f32` or `f64`). This is the shape expected for -/// the RHS of a database-vs-query [`CosineSimilarity`] scan: the `ScalarFnArray` contract -/// requires both children to have the same length, so rather than hand-rolling a 1-row input we -/// broadcast the query across the whole database. -/// -/// # Errors -/// -/// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. -pub fn build_constant_query_vector>( - query: &[T], - num_rows: usize, -) -> VortexResult { - let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); - - let children: Vec = query - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, num_rows).into_array(); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) -} - /// Build the lazy similarity-search expression tree for a prepared database array and a /// single query vector. /// @@ -163,7 +81,7 @@ pub fn build_similarity_search_tree>( threshold: T, ) -> VortexResult { let num_rows = data.len(); - let query_vec = build_constant_query_vector(query, num_rows)?; + let query_vec = Vector::constant_array(query, num_rows)?; let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array(); @@ -175,64 +93,16 @@ pub fn build_similarity_search_tree>( #[cfg(test)] mod tests { - use vortex_array::ArrayRef; - use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; - use vortex_array::arrays::Extension; - use vortex_array::arrays::ExtensionArray; - use vortex_array::arrays::FixedSizeListArray; - use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::bool::BoolArrayExt; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::session::ArraySession; - use vortex_array::validity::Validity; - use vortex_buffer::BufferMut; use vortex_error::VortexResult; - use vortex_session::VortexSession; - use super::build_constant_query_vector; use super::build_similarity_search_tree; - use super::compress_turboquant; - use crate::vector::Vector; - - /// Build a `Vector` extension array from a flat f32 slice. Each contiguous - /// group of `DIM` values becomes one row. - fn vector_array(dim: u32, values: &[f32]) -> VortexResult { - let dim_usize = dim as usize; - assert_eq!(values.len() % dim_usize, 0); - let num_rows = values.len() / dim_usize; - - let mut buf = BufferMut::::with_capacity(values.len()); - for &v in values { - buf.push(v); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim, - Validity::NonNullable, - num_rows, - )?; - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - fn test_session() -> VortexSession { - VortexSession::empty().with::() - } - - #[test] - fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> { - let query = vec![1.0f32, 0.0, 0.0, 0.0]; - let rhs = build_constant_query_vector(&query, 5)?; - - assert_eq!(rhs.len(), 5); - assert!(rhs.as_opt::().is_some()); - Ok(()) - } + use crate::encodings::turboquant::TurboQuantConfig; + use crate::encodings::turboquant::turboquant_encode; + use crate::tests::SESSION; + use crate::utils::test_helpers::vector_array; #[test] fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> { @@ -240,7 +110,7 @@ mod tests { let data = vector_array( 3, &[ - 1.0, 0.0, 0.0, // + 1.0f32, 0.0, 0.0, // 0.0, 1.0, 0.0, // 0.0, 0.0, 1.0, // 1.0, 0.0, 0.0, // @@ -249,7 +119,7 @@ mod tests { let query = [1.0f32, 0.0, 0.0]; let tree = build_similarity_search_tree(data, &query, 0.5)?; - let mut ctx = test_session().create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let result: BoolArray = tree.execute(&mut ctx)?; let bits = result.to_bit_buffer(); @@ -287,8 +157,8 @@ mod tests { } let data = vector_array(DIM, &values)?; - let mut ctx = test_session().create_execution_ctx(); - let compressed = compress_turboquant(data, &mut ctx)?; + let mut ctx = SESSION.create_execution_ctx(); + let compressed = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?; assert_eq!(compressed.len(), NUM_ROWS); // Build a tree with a low threshold so row 0 (cosine=1.0 exact) matches. diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 4e2ef9b1beb..e74301cea9f 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -531,7 +531,7 @@ mod turboquant_benches { macro_rules! turboquant_bench { (compress, $dim:literal, $bits:literal, $name:ident) => { paste! { - #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] + #[divan::bench(name = concat!("turboquant_encode_dim", stringify!($dim), "_", stringify!($bits), "bit"))] fn $name(bencher: Bencher) { let normalized_ext = setup_normalized_vector_ext($dim); let config = turboquant_config($bits);