diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index a30d28ab886..0a3e31f32e3 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -56,8 +56,6 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use vortex_sequence::Sequence; use vortex_sparse::Sparse; -#[cfg(feature = "unstable_encodings")] -use vortex_tensor::encodings::turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; use vortex_zigzag::ZigZag; @@ -111,8 +109,6 @@ pub static ALLOWED_ENCODINGS: LazyLock> = LazyLock::new(|| { allowed.insert(RunEnd.id()); allowed.insert(Sequence.id()); allowed.insert(Sparse.id()); - #[cfg(feature = "unstable_encodings")] - allowed.insert(TurboQuant.id()); allowed.insert(ZigZag.id()); #[cfg(feature = "zstd")] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 9f94a0c2d3d..3dd463bf4df 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -29,8 +29,8 @@ half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } -rand = { workspace = true } [dev-dependencies] +rand = { workspace = true } rand_distr = { workspace = true } rstest = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 8a14e2204f0..bec8df1cb29 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -4,78 +4,6 @@ pub mod vortex_tensor::encodings pub mod vortex_tensor::encodings::turboquant -pub struct vortex_tensor::encodings::turboquant::TurboQuant - -impl vortex_tensor::encodings::turboquant::TurboQuant - -pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId - -pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_BIT_WIDTH: u8 - -pub const vortex_tensor::encodings::turboquant::TurboQuant::MAX_CENTROIDS: usize - -pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32 - -pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuant::new_array_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_tensor::encodings::turboquant::TurboQuantArray - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult - -impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuant - -impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquant::TurboQuant - -pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData - -pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant - -pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> vortex_array::buffer::BufferHandle - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer_name(_array: vortex_array::array::view::ArrayView<'_, Self>, _idx: usize) -> core::option::Option - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute(array: vortex_array::array::typed::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::id(&self) -> vortex_array::array::ArrayId - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::nbuffers(_array: vortex_array::array::view::ArrayView<'_, Self>) -> usize - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::reduce_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(array: vortex_array::array::view::ArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::slot_name(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> alloc::string::String - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate(&self, data: &Self::ArrayData, dtype: &vortex_array::dtype::DType, len: usize, slots: &[core::option::Option]) -> vortex_error::VortexResult<()> - -impl vortex_array::array::vtable::operations::OperationsVTable for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult - -impl vortex_array::array::vtable::validity::ValidityVTable for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::validity(_array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>) -> vortex_error::VortexResult - -impl vortex_array::arrays::dict::take::TakeExecute for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::take(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, indices: &vortex_array::array::erased::ArrayRef, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> - -impl vortex_array::arrays::slice::SliceReduce for vortex_tensor::encodings::turboquant::TurboQuant - -pub fn vortex_tensor::encodings::turboquant::TurboQuant::slice(array: vortex_array::array::view::ArrayView<'_, vortex_tensor::encodings::turboquant::TurboQuant>, range: core::ops::range::Range) -> vortex_error::VortexResult> - pub struct vortex_tensor::encodings::turboquant::TurboQuantConfig pub vortex_tensor::encodings::turboquant::TurboQuantConfig::bit_width: u8 @@ -96,44 +24,6 @@ impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantConfig pub fn vortex_tensor::encodings::turboquant::TurboQuantConfig::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub struct vortex_tensor::encodings::turboquant::TurboQuantData - -impl vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::bit_width(&self) -> u8 - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32 - -pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::num_rounds(&self) -> u8 - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32 - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> vortex_error::VortexResult - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()> - -impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantData - -impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl core::fmt::Display for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl vortex_array::hash::ArrayEq for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::array_eq(&self, other: &Self, _precision: vortex_array::hash::Precision) -> bool - -impl vortex_array::hash::ArrayHash for vortex_tensor::encodings::turboquant::TurboQuantData - -pub fn vortex_tensor::encodings::turboquant::TurboQuantData::array_hash(&self, state: &mut H, _precision: vortex_array::hash::Precision) - pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantScheme @@ -164,28 +54,18 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, ca pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str -pub trait vortex_tensor::encodings::turboquant::TurboQuantArrayExt: vortex_array::array::typed::TypedArrayRef - -pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::centroids(&self) -> &vortex_array::array::erased::ArrayRef +pub const vortex_tensor::encodings::turboquant::MAX_BIT_WIDTH: u8 -pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::codes(&self) -> &vortex_array::array::erased::ArrayRef +pub const vortex_tensor::encodings::turboquant::MAX_CENTROIDS: usize -pub fn vortex_tensor::encodings::turboquant::TurboQuantArrayExt::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef +pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 -impl> vortex_tensor::encodings::turboquant::TurboQuantArrayExt for T - -pub fn T::centroids(&self) -> &vortex_array::array::erased::ArrayRef - -pub fn T::codes(&self) -> &vortex_array::array::erased::ArrayRef - -pub fn T::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef +pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::array::typed::Array - pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::AnyFixedShapeTensor @@ -360,9 +240,9 @@ pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity impl vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::new() -> vortex_array::scalar_fn::typed::ScalarFn -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array(lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity @@ -370,13 +250,13 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&se impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity -pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_tensor::scalar_fns::ApproxOptions +pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -396,9 +276,9 @@ pub struct vortex_tensor::scalar_fns::inner_product::InnerProduct impl vortex_tensor::scalar_fns::inner_product::InnerProduct -pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::new() -> vortex_array::scalar_fn::typed::ScalarFn -pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::try_new_array(lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProduct @@ -406,13 +286,13 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> v impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct -pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_tensor::scalar_fns::ApproxOptions +pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_array::scalar_fn::vtable::EmptyOptions pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -432,11 +312,11 @@ pub struct vortex_tensor::scalar_fns::l2_denorm::L2Denorm impl vortex_tensor::scalar_fns::l2_denorm::L2Denorm -pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn +pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new() -> vortex_array::scalar_fn::typed::ScalarFn -pub unsafe fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new_array_unchecked(options: &vortex_tensor::scalar_fns::ApproxOptions, normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult +pub unsafe fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::new_array_unchecked(normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::try_new_array(normalized: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, len: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm @@ -444,7 +324,7 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_te impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm -pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_tensor::scalar_fns::ApproxOptions +pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_array::scalar_fn::vtable::EmptyOptions pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity @@ -464,9 +344,9 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::return_dtype(&self, _opti pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> -pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(options: &vortex_tensor::scalar_fns::ApproxOptions, input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows(input: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()> pub mod vortex_tensor::scalar_fns::l2_norm @@ -474,9 +354,9 @@ pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm impl vortex_tensor::scalar_fns::l2_norm::L2Norm -pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::new() -> vortex_array::scalar_fn::typed::ScalarFn -pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, child: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::try_new_array(child: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm @@ -484,7 +364,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm -pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_tensor::scalar_fns::ApproxOptions +pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity @@ -504,45 +384,87 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::return_dtype(&self, _options: pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> -pub enum vortex_tensor::scalar_fns::ApproxOptions +pub mod vortex_tensor::scalar_fns::sorf_transform + +pub struct vortex_tensor::scalar_fns::sorf_transform::SorfMatrix + +impl vortex_tensor::scalar_fns::sorf_transform::SorfMatrix + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::inverse_rotate(&self, input: &[f32], output: &mut [f32]) + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::padded_dim(&self) -> usize + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::rotate(&self, input: &[f32], output: &mut [f32]) + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfMatrix::try_new(seed: u64, dimension: usize, num_rounds: usize) -> vortex_error::VortexResult + +pub struct vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::dimension: u32 + +pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::element_ptype: vortex_array::dtype::ptype::PType + +pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::num_rounds: u8 + +pub vortex_tensor::scalar_fns::sorf_transform::SorfOptions::seed: u64 + +impl core::clone::Clone for vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfOptions::clone(&self) -> vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +impl core::cmp::Eq for vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +impl core::cmp::PartialEq for vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfOptions::eq(&self, other: &vortex_tensor::scalar_fns::sorf_transform::SorfOptions) -> bool + +impl core::fmt::Debug for vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_tensor::scalar_fns::sorf_transform::SorfOptions + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_tensor::scalar_fns::sorf_transform::SorfOptions -pub vortex_tensor::scalar_fns::ApproxOptions::Approximate +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) -pub vortex_tensor::scalar_fns::ApproxOptions::Exact +impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::sorf_transform::SorfOptions -impl vortex_tensor::scalar_fns::ApproxOptions +pub struct vortex_tensor::scalar_fns::sorf_transform::SorfTransform -pub fn vortex_tensor::scalar_fns::ApproxOptions::is_approx(&self) -> bool +impl vortex_tensor::scalar_fns::sorf_transform::SorfTransform -pub fn vortex_tensor::scalar_fns::ApproxOptions::is_exact(&self) -> bool +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::new(options: &vortex_tensor::scalar_fns::sorf_transform::SorfOptions) -> vortex_array::scalar_fn::typed::ScalarFn -impl core::clone::Clone for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::try_new_array(options: &vortex_tensor::scalar_fns::sorf_transform::SorfOptions, child: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::ApproxOptions::clone(&self) -> vortex_tensor::scalar_fns::ApproxOptions +impl core::clone::Clone for vortex_tensor::scalar_fns::sorf_transform::SorfTransform -impl core::cmp::Eq for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> vortex_tensor::scalar_fns::sorf_transform::SorfTransform -impl core::cmp::PartialEq for vortex_tensor::scalar_fns::ApproxOptions +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform -pub fn vortex_tensor::scalar_fns::ApproxOptions::eq(&self, other: &vortex_tensor::scalar_fns::ApproxOptions) -> bool +pub type vortex_tensor::scalar_fns::sorf_transform::SorfTransform::Options = vortex_tensor::scalar_fns::sorf_transform::SorfOptions -impl core::default::Default for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity -pub fn vortex_tensor::scalar_fns::ApproxOptions::default() -> vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -impl core::fmt::Debug for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl core::fmt::Display for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::id(&self) -> vortex_array::scalar_fn::ScalarFnId -pub fn vortex_tensor::scalar_fns::ApproxOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::is_fallible(&self, _options: &Self::Options) -> bool -impl core::hash::Hash for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::is_null_sensitive(&self, _options: &Self::Options) -> bool -pub fn vortex_tensor::scalar_fns::ApproxOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult -impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::ApproxOptions +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> pub mod vortex_tensor::vector diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 7c75269b632..084baf97e57 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -4,7 +4,6 @@ //! Encodings for the different tensor types. // TODO(connor): -// pub mod norm; // Unit-normalized vectors. // pub mod spherical; // Spherical transform on unit-normalized vectors. pub mod turboquant; diff --git a/vortex-tensor/src/encodings/turboquant/array/data.rs b/vortex-tensor/src/encodings/turboquant/array/data.rs deleted file mode 100644 index dd1867eb598..00000000000 --- a/vortex-tensor/src/encodings/turboquant/array/data.rs +++ /dev/null @@ -1,261 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt::Display; -use std::fmt::Formatter; -use std::sync::Arc; - -use vortex_array::ArrayRef; -use vortex_array::TypedArrayRef; -use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_ensure_eq; - -use crate::encodings::turboquant::array::slots::Slot; -use crate::encodings::turboquant::vtable::TurboQuant; - -/// TurboQuant array data. -/// -/// TurboQuant is a lossy vector quantization encoding for [`Vector`](crate::vector::Vector) -/// extension arrays. It stores quantized coordinate codes for unit-norm vectors, along with shared -/// codebook centroids and the parameters of the current structured rotation. -/// -/// Norms should be stored externally in the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) -/// `ScalarFnArray` wrapper. -/// -/// See the [module docs](crate::encodings::turboquant) for algorithmic details. -/// -/// Note that degenerate TurboQuant arrays have zero rows and `bit_width == 0`, with all slots -/// empty. -#[derive(Clone, Debug)] -pub struct TurboQuantData { - /// The vector dimension `d`, cached from the `FixedSizeList` storage dtype's list size. - /// - /// Stored as a convenience field to avoid repeatedly extracting it from `dtype`. - pub(crate) dimension: u32, - - /// The number of bits per coordinate (0-8), derived from `log2(centroids.len())`. - /// - /// This is 0 for degenerate empty arrays. - pub(crate) bit_width: u8, - - /// The number of sign-diagonal + WHT rounds in the structured rotation. - /// - /// This is 0 for degenerate empty arrays. - pub(crate) num_rounds: u8, -} - -impl Display for TurboQuantData { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "dimension: {}, bit_width: {}, num_rounds: {}", - self.dimension, self.bit_width, self.num_rounds - ) - } -} - -impl TurboQuantData { - /// Build a `TurboQuantData` with validation. - /// - /// # Errors - /// - /// Returns an error if: - /// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION). - /// - `bit_width` is greater than [`MAX_BIT_WIDTH`](TurboQuant::MAX_BIT_WIDTH). - pub fn try_new(dimension: u32, bit_width: u8, num_rounds: u8) -> VortexResult { - vortex_ensure!( - dimension >= TurboQuant::MIN_DIMENSION, - "TurboQuant requires dimension >= {}, got {dimension}", - TurboQuant::MIN_DIMENSION - ); - vortex_ensure!( - bit_width <= TurboQuant::MAX_BIT_WIDTH, - "bit_width is expected to be between 0 and {}, got {bit_width}", - TurboQuant::MAX_BIT_WIDTH - ); - - Ok(Self { - dimension, - bit_width, - num_rounds, - }) - } - - /// Build a `TurboQuantData` without validation. - /// - /// # Safety - /// - /// The caller must ensure: - /// - /// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION). - /// - `bit_width` is in the range `[0, MAX_BIT_WIDTH]`. - /// - `num_rounds` is >= 1 (or 0 for degenerate empty arrays). - /// - /// Violating these invariants may produce incorrect results during decompression. - pub unsafe fn new_unchecked(dimension: u32, bit_width: u8, num_rounds: u8) -> Self { - Self { - dimension, - bit_width, - num_rounds, - } - } - - /// Validates the components that would be used to create a `TurboQuantData`. - /// - /// This function checks all the invariants required by [`new_unchecked`](Self::new_unchecked). - pub fn validate( - dtype: &DType, - codes: &ArrayRef, - centroids: &ArrayRef, - rotation_signs: &ArrayRef, - ) -> VortexResult<()> { - let vector_metadata = TurboQuant::validate_dtype(dtype)?; - let dimension = vector_metadata.dimensions(); - let padded_dim = dimension.next_power_of_two(); - - // TurboQuant arrays are always non-nullable. Nullability should be handled by the external - // L2Denorm ScalarFnArray wrapper. - vortex_ensure!( - !dtype.is_nullable(), - "TurboQuant dtype must be non-nullable, got {dtype}", - ); - - // Codes must be a non-nullable FixedSizeList with list_size == padded_dim. - let expected_codes_dtype = DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - Nullability::NonNullable, - ); - vortex_ensure_eq!( - *codes.dtype(), - expected_codes_dtype, - "codes dtype does not match expected {expected_codes_dtype}", - ); - - // Centroids are always f32 regardless of element type. - let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - vortex_ensure_eq!( - *centroids.dtype(), - centroids_dtype, - "centroids dtype must be non-nullable f32", - ); - - // Rotation signs must be a FixedSizeList with list_size == padded_dim. The FSL length - // is the number of rotation rounds. - let expected_signs_dtype = DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - Nullability::NonNullable, - ); - vortex_ensure_eq!( - *rotation_signs.dtype(), - expected_signs_dtype, - "rotation_signs dtype does not match expected {expected_signs_dtype}", - ); - // Degenerate (empty) case: all children must be empty, and bit_width is 0. - let num_rows = codes.len(); - if num_rows == 0 { - vortex_ensure!( - centroids.is_empty(), - "degenerate TurboQuant must have empty centroids, got length {}", - centroids.len() - ); - vortex_ensure!( - rotation_signs.is_empty(), - "degenerate TurboQuant must have empty rotation_signs, got length {}", - rotation_signs.len() - ); - return Ok(()); - } - - vortex_ensure!( - !rotation_signs.is_empty(), - "rotation_signs must have at least 1 round" - ); - - // Non-degenerate: derive and validate bit_width from centroids. - let num_centroids = centroids.len(); - vortex_ensure!( - num_centroids.is_power_of_two() - && (2..=TurboQuant::MAX_CENTROIDS).contains(&num_centroids), - "centroids length must be a power of 2 in [2, {}], got {num_centroids}", - TurboQuant::MAX_CENTROIDS - ); - - #[expect( - clippy::cast_possible_truncation, - reason = "Guaranteed to be [1,8] by the preceding power-of-2 and range checks." - )] - let bit_width = num_centroids.trailing_zeros() as u8; - vortex_ensure!( - (1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width), - "derived bit_width must be 1-{}, got {bit_width}", - TurboQuant::MAX_BIT_WIDTH - ); - - Ok(()) - } - - pub(crate) fn make_slots( - codes: ArrayRef, - centroids: ArrayRef, - rotation_signs: ArrayRef, - ) -> Vec> { - let mut slots = vec![None; Slot::COUNT]; - slots[Slot::Codes as usize] = Some(codes); - slots[Slot::Centroids as usize] = Some(centroids); - slots[Slot::RotationSigns as usize] = Some(rotation_signs); - slots - } - - /// The vector dimension `d`, as stored in the [`Vector`](crate::vector::Vector) extension - /// dtype's `FixedSizeList` storage. - pub fn dimension(&self) -> u32 { - self.dimension - } - - /// MSE bits per coordinate (1-MAX_BIT_WIDTH for non-empty arrays, 0 for degenerate empty arrays). - pub fn bit_width(&self) -> u8 { - self.bit_width - } - - /// The number of sign-diagonal + WHT rounds in the structured rotation. - pub fn num_rounds(&self) -> u8 { - self.num_rounds - } - - /// Padded dimension (next power of 2 >= [`dimension`](Self::dimension)). - /// - /// The current Walsh-Hadamard-based structured rotation requires power-of-2 input, so - /// non-power-of-2 dimensions are zero-padded to this value. - pub fn padded_dim(&self) -> u32 { - self.dimension.next_power_of_two() - } -} - -pub trait TurboQuantArrayExt: TypedArrayRef { - fn codes(&self) -> &ArrayRef { - self.as_ref().slots()[Slot::Codes as usize] - .as_ref() - .vortex_expect("TurboQuantArray codes slot") - } - - fn centroids(&self) -> &ArrayRef { - self.as_ref().slots()[Slot::Centroids as usize] - .as_ref() - .vortex_expect("TurboQuantArray centroids slot") - } - - fn rotation_signs(&self) -> &ArrayRef { - self.as_ref().slots()[Slot::RotationSigns as usize] - .as_ref() - .vortex_expect("TurboQuantArray rotation_signs slot") - } -} - -impl> TurboQuantArrayExt for T {} diff --git a/vortex-tensor/src/encodings/turboquant/array/mod.rs b/vortex-tensor/src/encodings/turboquant/array/mod.rs deleted file mode 100644 index e82313f1dc6..00000000000 --- a/vortex-tensor/src/encodings/turboquant/array/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant array definition: stores quantized coordinate codes, norms, centroids (codebook), -//! and rotation signs. - -pub(crate) mod data; -pub(crate) mod slots; - -pub(crate) mod centroids; -pub(crate) mod rotation; diff --git a/vortex-tensor/src/encodings/turboquant/array/slots.rs b/vortex-tensor/src/encodings/turboquant/array/slots.rs deleted file mode 100644 index c4fe7e0c5bf..00000000000 --- a/vortex-tensor/src/encodings/turboquant/array/slots.rs +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -/// Slot positions for TurboQuantArray children. -/// -/// Norms are not stored in the TurboQuantArray. They live in the external [`L2Denorm`] -/// ScalarFnArray wrapper returned by [`turboquant_encode`]. -/// -/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm -/// [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode -#[repr(usize)] -#[derive(Clone, Copy, Debug)] -pub(crate) enum Slot { - Codes = 0, - Centroids = 1, - RotationSigns = 2, -} - -impl Slot { - pub(crate) const COUNT: usize = 3; - - pub(crate) fn name(self) -> &'static str { - match self { - Self::Codes => "codes", - Self::Centroids => "centroids", - Self::RotationSigns => "rotation_signs", - } - } - - pub(crate) fn from_index(idx: usize) -> Self { - match idx { - 0 => Self::Codes, - 1 => Self::Centroids, - 2 => Self::RotationSigns, - _ => vortex_error::vortex_panic!("invalid slot index {idx}"), - } - } -} diff --git a/vortex-tensor/src/encodings/turboquant/array/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs similarity index 96% rename from vortex-tensor/src/encodings/turboquant/array/centroids.rs rename to vortex-tensor/src/encodings/turboquant/centroids.rs index e3027d0a58a..cd7b8b889ce 100644 --- a/vortex-tensor/src/encodings/turboquant/array/centroids.rs +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -15,7 +15,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_utils::aliases::dash_map::DashMap; -use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::MAX_BIT_WIDTH; +use crate::encodings::turboquant::MIN_DIMENSION; /// The maximum iterations for Max-Lloyd algorithm when computing centroids. const MAX_ITERATIONS: usize = 200; @@ -37,14 +38,14 @@ static CENTROID_CACHE: LazyLock>> = LazyLock::new(Da /// `dimension`-dimensional space. pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { vortex_ensure!( - (1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width), + (1..=MAX_BIT_WIDTH).contains(&bit_width), "TurboQuant bit_width must be 1-{}, got {bit_width}", - TurboQuant::MAX_BIT_WIDTH + MAX_BIT_WIDTH ); vortex_ensure!( - dimension >= TurboQuant::MIN_DIMENSION, + dimension >= MIN_DIMENSION, "TurboQuant dimension must be >= {}, got {dimension}", - TurboQuant::MIN_DIMENSION + MIN_DIMENSION ); if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { @@ -92,7 +93,7 @@ impl HalfIntExponent { /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec { - debug_assert!((1..=TurboQuant::MAX_BIT_WIDTH).contains(&bit_width)); + debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. @@ -175,7 +176,6 @@ fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { /// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. /// This is significantly faster than the general `powf` which goes through /// `exp(exponent * ln(base))`. -#[inline] fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { let base = (1.0 - x_val * x_val).max(0.0); @@ -199,7 +199,7 @@ pub fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { /// Find the index of the nearest centroid using precomputed decision boundaries. /// -/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding +/// `boundaries` must be the output of [`compute_centroid_boundaries`] for the corresponding /// centroids. Uses binary search on the midpoints, avoiding distance comparisons /// in the inner loop. #[inline] @@ -210,7 +210,7 @@ pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { ); debug_assert!( boundaries.len() <= 256, // 1 << 8 - "boundaries must be sorted" + "too many boundaries" ); #[expect( diff --git a/vortex-tensor/src/encodings/turboquant/scheme/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs similarity index 52% rename from vortex-tensor/src/encodings/turboquant/scheme/compress.rs rename to vortex-tensor/src/encodings/turboquant/compress.rs index f787a80cbf7..b0173bbe36c 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -6,35 +6,42 @@ //! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector) //! extension array whose rows are already L2-normalized (unit norm). Normalization is handled //! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm), -//! which the [`TurboQuantScheme`](super::TurboQuantScheme) calls before invoking this function. +//! which the [`TurboQuantScheme`] calls before invoking this function. +//! +//! [`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_fastlanes::bitpack_compress::bitpack_encode; - -use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::array::centroids::compute_centroid_boundaries; -use crate::encodings::turboquant::array::centroids::find_nearest_centroid; -use crate::encodings::turboquant::array::centroids::get_centroids; -use crate::encodings::turboquant::array::rotation::RotationMatrix; -use crate::encodings::turboquant::vtable::TurboQuantArray; + +use crate::encodings::turboquant::MAX_BIT_WIDTH; +use crate::encodings::turboquant::MIN_DIMENSION; +use crate::encodings::turboquant::centroids::compute_centroid_boundaries; +use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::encodings::turboquant::centroids::get_centroids; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; +use crate::scalar_fns::sorf_transform::SorfMatrix; +use crate::scalar_fns::sorf_transform::SorfOptions; +use crate::scalar_fns::sorf_transform::SorfTransform; +use crate::utils::cast_to_f32; +use crate::vector::AnyVector; +use crate::vector::Vector; /// Configuration for TurboQuant encoding. #[derive(Clone, Debug)] @@ -50,50 +57,15 @@ pub struct TurboQuantConfig { impl Default for TurboQuantConfig { fn default() -> Self { Self { - bit_width: TurboQuant::MAX_BIT_WIDTH, + bit_width: MAX_BIT_WIDTH, seed: Some(42), num_rounds: 3, } } } -/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray for quantization. -/// -/// All quantization (rotation, centroid lookup) happens in f32. f16 is upcast; f64 is truncated. -fn extract_f32_elements( - fsl: &FixedSizeListArray, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let elements = fsl.elements(); - let primitive = elements.clone().execute::(ctx)?; - let ptype = primitive.ptype(); - - match ptype { - PType::F16 => Ok(primitive - .as_slice::() - .iter() - .map(|&v| f32::from(v)) - .collect()), - PType::F32 => Ok(primitive), - PType::F64 => Ok(primitive - .as_slice::() - .iter() - .map(|&v| { - #[expect( - clippy::cast_possible_truncation, - reason = "TurboQuant quantization operates in f32, so f64 inputs are intentionally downcast" - )] - let v = v as f32; - v - }) - .collect()), - _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"), - } -} - /// Shared intermediate results from the quantization loop. struct QuantizationResult { - rotation: RotationMatrix, centroids: Vec, all_indices: BufferMut, padded_dim: usize, @@ -115,12 +87,13 @@ fn turboquant_quantize_core( usize::try_from(fsl.list_size()).vortex_expect("u32 FixedSizeList dimension fits in usize"); let num_rows = fsl.len(); - let rotation = RotationMatrix::try_new(seed, dimension, num_rounds as usize)?; + let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; let padded_dim = rotation.padded_dim(); let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - let f32_elements = extract_f32_elements(fsl, ctx)?; + let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + let f32_elements = cast_to_f32(elements_prim)?; let centroids = get_centroids(padded_dim_u32, bit_width)?; let boundaries = compute_centroid_boundaries(¢roids); @@ -129,7 +102,7 @@ fn turboquant_quantize_core( let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; - let f32_slice = f32_elements.as_slice::(); + let f32_slice = f32_elements.as_slice(); for row in 0..num_rows { let x = &f32_slice[row * dimension..(row + 1) * dimension]; @@ -145,48 +118,44 @@ fn turboquant_quantize_core( } Ok(QuantizationResult { - rotation, centroids, all_indices, padded_dim, }) } -/// Build a `TurboQuantArray` from quantization results. +/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. /// -/// The `ext_dtype` must be a non-nullable [`Vector`](crate::vector::Vector) extension dtype. -fn build_turboquant( +/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes +/// indexes into the centroid codebook. The Dict can be independently sliced, taken, +/// or executed (dequantized) without knowledge of the rotation. +fn build_quantized_fsl( num_rows: usize, - core: QuantizationResult, - ext_dtype: &DType, -) -> VortexResult { - let padded_dim = core.padded_dim; + all_indices: BufferMut, + centroids: &[f32], + padded_dim: usize, +) -> VortexResult { + let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); + + let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); + centroids_buf.extend_from_slice(centroids); + let centroids_array = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + + let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; + let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - let codes_elements = - PrimitiveArray::new::(core.all_indices.freeze(), Validity::NonNullable); - let codes = FixedSizeListArray::try_new( - codes_elements.into_array(), + Ok(FixedSizeListArray::try_new( + dict.into_array(), padded_dim_u32, Validity::NonNullable, num_rows, )? - .into_array(); - - // TODO(perf): `get_centroids` returns Vec; could avoid the copy by - // supporting Buffer::from(Vec) or caching as Buffer directly. - let mut centroids_buf = BufferMut::::with_capacity(core.centroids.len()); - centroids_buf.extend_from_slice(&core.centroids); - let centroids_array = - PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable).into_array(); - - let rotation_signs = bitpack_rotation_signs(&core.rotation)?; - - TurboQuant::try_new_array(ext_dtype.clone(), codes, centroids_array, rotation_signs) + .into_array()) } /// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// [`TurboQuantArray`]. +/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`. /// /// The input must be a non-nullable Vector extension array whose rows are already unit-norm. /// **Null vectors are not supported.** The caller must normalize and strip nullability before @@ -196,9 +165,9 @@ fn build_turboquant( /// [`turboquant_encode_unchecked`] to skip this check when the caller has just performed /// normalization. /// -/// The returned array is a plain [`TurboQuantArray`] that decompresses to unit-norm vectors. -/// The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray if the original -/// magnitudes need to be restored. +/// The returned array is a `SorfTransform` ScalarFnArray wrapping `FSL(Dict)` that decompresses +/// to unit-norm vectors. The caller is responsible for wrapping it in an [`L2Denorm`] ScalarFnArray +/// if the original magnitudes need to be restored. /// /// [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm /// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm @@ -214,14 +183,15 @@ pub fn turboquant_encode( "TurboQuant input must be non-nullable (normalize first via L2Denorm), got {ext_dtype}", ); - validate_l2_normalized_rows(ext.as_ref().clone(), ctx)?; + validate_l2_normalized_rows(ext.as_ref(), ctx)?; // SAFETY: We just validated that the input is non-nullable and all rows are unit-norm. unsafe { turboquant_encode_unchecked(ext, config, ctx) } } /// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// [`TurboQuantArray`], without validating the unit-norm precondition. +/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm +/// precondition. /// /// # Safety /// @@ -242,71 +212,64 @@ pub unsafe fn turboquant_encode_unchecked( let fsl = storage.clone().execute::(ctx)?; vortex_ensure!( - config.bit_width >= 1 && config.bit_width <= TurboQuant::MAX_BIT_WIDTH, - "bit_width must be 1-{}, got {}", - TurboQuant::MAX_BIT_WIDTH, + config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH, + "bit_width must be 1-{MAX_BIT_WIDTH}, got {}", config.bit_width ); let dimension = fsl.list_size(); vortex_ensure!( - dimension >= TurboQuant::MIN_DIMENSION, - "TurboQuant requires dimension >= {}, got {dimension}", - TurboQuant::MIN_DIMENSION + dimension >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}", ); + let vector_metadata = ext_dtype.as_extension().metadata::(); + let element_ptype = vector_metadata.element_ptype(); + + let seed = config.seed.unwrap_or(42); + let num_rows = fsl.len(); + if fsl.is_empty() { let padded_dim = dimension.next_power_of_two(); - let empty_codes = FixedSizeListArray::try_new( - PrimitiveArray::empty::(Nullability::NonNullable).into_array(), - padded_dim, - Validity::NonNullable, - 0, - )?; - + let empty_codes = PrimitiveArray::empty::(Nullability::NonNullable); let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_signs = FixedSizeListArray::try_new( - PrimitiveArray::empty::(Nullability::NonNullable).into_array(), + let empty_dict = + DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?; + let empty_fsl = FixedSizeListArray::try_new( + empty_dict.into_array(), padded_dim, Validity::NonNullable, 0, )?; - - return Ok(TurboQuant::try_new_array( - ext_dtype, - empty_codes.into_array(), - empty_centroids.into_array(), - empty_signs.into_array(), - )? - .into_array()); + let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?; + + let sorf_options = SorfOptions { + seed, + num_rounds: config.num_rounds, + dimension, + element_ptype, + }; + return Ok( + SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(), + ); } - let seed = config.seed.unwrap_or(42); - let num_rows = fsl.len(); let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; - - Ok(build_turboquant(num_rows, core, &ext_dtype)?.into_array()) + let quantized_fsl = + build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?; + let padded_vector = wrap_padded_as_vector(quantized_fsl)?; + + let sorf_options = SorfOptions { + seed, + num_rounds: config.num_rounds, + dimension, + element_ptype, + }; + Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Export rotation signs as a `FixedSizeListArray` wrapping a 1-bit [`BitPackedArray`]. -/// -/// The rotation matrix's `num_rounds * padded_dim` sign values are exported as 0/1 u8 values in -/// inverse application order, bitpacked to 1 bit per sign, then wrapped in a -/// `FixedSizeListArray` with `list_size = padded_dim` and `len = num_rounds`. -fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult { - let signs_u8 = rotation.export_inverse_signs_u8(); - let num_rounds = rotation.num_rounds(); - let padded_dim = u32::try_from(rotation.padded_dim()).vortex_expect("padded_dim fits in u32"); - - let mut buf = BufferMut::::with_capacity(signs_u8.len()); - buf.extend_from_slice(&signs_u8); - let prim = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let bitpacked = bitpack_encode(&prim, 1, None)?; - - let fsl = FixedSizeListArray::try_new( - bitpacked.into_array(), - padded_dim, - Validity::NonNullable, - num_rounds, - )?; - Ok(fsl.into_array()) +/// Wrap an `FSL` in a [`Vector`](crate::vector::Vector) extension so it can be +/// passed as the child of [`SorfTransform`], which expects a `Vector` input. +fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl).into_array()) } diff --git a/vortex-tensor/src/encodings/turboquant/compute/mod.rs b/vortex-tensor/src/encodings/turboquant/compute/mod.rs deleted file mode 100644 index eab759ef0b9..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compute/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Compute pushdown implementations for TurboQuant. - -mod ops; -mod slice; -mod take; - -pub(crate) mod rules; - -use num_traits::Float; -use num_traits::FromPrimitive; -use vortex_error::VortexExpect; - -/// Convert an f32 value to a float type `T`. -/// -/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the -/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening. -pub(crate) fn float_from_f32(v: f32) -> T { - FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible") -} diff --git a/vortex-tensor/src/encodings/turboquant/compute/ops.rs b/vortex-tensor/src/encodings/turboquant/compute/ops.rs deleted file mode 100644 index 4999816319b..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compute/ops.rs +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::ArrayView; -use vortex_array::ExecutionCtx; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::slice::SliceReduce; -use vortex_array::scalar::Scalar; -use vortex_array::vtable::OperationsVTable; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; - -use crate::encodings::turboquant::TurboQuant; - -impl OperationsVTable for TurboQuant { - fn scalar_at( - array: ArrayView<'_, TurboQuant>, - index: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - // Slice to single row, decompress that one row. - let Some(sliced) = ::slice(array, index..index + 1)? else { - vortex_bail!("slice returned None for index {index}") - }; - let decoded = sliced.execute::(ctx)?; - decoded.scalar_at(0) - } -} diff --git a/vortex-tensor/src/encodings/turboquant/compute/rules.rs b/vortex-tensor/src/encodings/turboquant/compute/rules.rs deleted file mode 100644 index 39919a8c1ec..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compute/rules.rs +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::arrays::dict::TakeExecuteAdaptor; -use vortex_array::arrays::slice::SliceReduceAdaptor; -use vortex_array::kernel::ParentKernelSet; -use vortex_array::optimizer::rules::ParentRuleSet; - -use crate::encodings::turboquant::TurboQuant; - -pub(crate) static RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); - -pub(crate) static PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(TurboQuant))]); diff --git a/vortex-tensor/src/encodings/turboquant/compute/slice.rs b/vortex-tensor/src/encodings/turboquant/compute/slice.rs deleted file mode 100644 index c19f604e36a..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compute/slice.rs +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ops::Range; - -use vortex_array::ArrayRef; -use vortex_array::ArrayView; -use vortex_array::IntoArray; -use vortex_array::arrays::slice::SliceReduce; -use vortex_error::VortexResult; - -use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantArrayExt; - -impl SliceReduce for TurboQuant { - fn slice( - array: ArrayView<'_, TurboQuant>, - range: Range, - ) -> VortexResult> { - let sliced_codes = array.codes().slice(range)?; - - Ok(Some( - TurboQuant::try_new_array( - array.dtype().clone(), - sliced_codes, - array.centroids().clone(), - array.rotation_signs().clone(), - )? - .into_array(), - )) - } -} diff --git a/vortex-tensor/src/encodings/turboquant/compute/take.rs b/vortex-tensor/src/encodings/turboquant/compute/take.rs deleted file mode 100644 index 19a2e65e393..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compute/take.rs +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::ArrayRef; -use vortex_array::ArrayView; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::dict::TakeExecute; -use vortex_error::VortexResult; - -use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantArrayExt; - -impl TakeExecute for TurboQuant { - fn take( - array: ArrayView<'_, TurboQuant>, - indices: &ArrayRef, - _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - // FSL children handle per-row take natively. - let taken_codes = array.codes().take(indices.clone())?; - - Ok(Some( - TurboQuant::try_new_array( - array.dtype().clone(), - taken_codes, - array.centroids().clone(), - array.rotation_signs().clone(), - )? - .into_array(), - )) - } -} diff --git a/vortex-tensor/src/encodings/turboquant/metadata.rs b/vortex-tensor/src/encodings/turboquant/metadata.rs deleted file mode 100644 index b4a7a8aef08..00000000000 --- a/vortex-tensor/src/encodings/turboquant/metadata.rs +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Protobuf-backed metadata for TurboQuant encoding. - -use prost::Message; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; - -use crate::encodings::turboquant::TurboQuant; - -/// Serialized metadata for TurboQuant arrays. -#[derive(Clone, PartialEq, Message)] -pub(super) struct TurboQuantMetadata { - /// The number of bits per coordinate, which must be <= [`TurboQuant::MAX_BIT_WIDTH`]. - #[prost(uint32, required, tag = "1")] - bit_width: u32, - - /// The number of sign-diagonal + WHT rounds in the structured rotation. - #[prost(uint32, required, tag = "2")] - num_rounds: u32, -} - -impl TurboQuantMetadata { - /// Creates metadata for the given bit width and number of rotation rounds. - pub(super) fn new(bit_width: u8, num_rounds: u8) -> Self { - Self { - bit_width: u32::from(bit_width), - num_rounds: u32::from(num_rounds), - } - } - - /// Returns the validated TurboQuant bit width. - pub(super) fn bit_width(&self) -> VortexResult { - let bit_width = u8::try_from(self.bit_width).map_err(|_| { - vortex_err!( - "TurboQuant bit_width must fit into u8, got {}", - self.bit_width - ) - })?; - vortex_ensure!( - bit_width <= TurboQuant::MAX_BIT_WIDTH, - "bit_width is expected to be between 0 and {}, got {bit_width}", - TurboQuant::MAX_BIT_WIDTH - ); - - Ok(bit_width) - } - - /// Returns the validated number of rotation rounds. - /// - /// Returns 0 for degenerate (empty) arrays, which is validated at a higher level. - pub(super) fn num_rounds(&self) -> VortexResult { - u8::try_from(self.num_rounds).map_err(|_| { - vortex_err!( - "TurboQuant num_rounds must fit into u8, got {}", - self.num_rounds - ) - }) - } -} - -#[cfg(test)] -mod tests { - use prost::Message; - use rstest::rstest; - use vortex_error::VortexResult; - - use super::TurboQuantMetadata; - - #[rstest] - #[case(0, 0)] - #[case(0, 3)] - #[case(3, 1)] - #[case(8, 3)] - #[case(8, 5)] - fn protobuf_metadata_roundtrip( - #[case] bit_width: u8, - #[case] num_rounds: u8, - ) -> VortexResult<()> { - let bytes = TurboQuantMetadata::new(bit_width, num_rounds).encode_to_vec(); - let decoded = TurboQuantMetadata::decode(bytes.as_slice())?; - assert_eq!(decoded.bit_width()?, bit_width); - assert_eq!(decoded.num_rounds()?, num_rounds); - - Ok(()) - } -} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index ff3259788e9..49e53effd0d 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -17,25 +17,34 @@ //! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) //! using MSE-optimal scalar quantization on coordinates of a rotated unit vector. //! -//! The `TurboQuantArray` stores only the quantized unit-norm vector data (codes, centroids, -//! rotation signs). Per-vector L2 norms are stored separately in an [`L2Denorm`] ScalarFnArray -//! wrapper. The [`turboquant_encode`] function returns this wrapper: +//! The encoding is decomposed into independently swappable layers: +//! +//! - **Normalization**: [`L2Denorm`] stores per-vector norms and wraps the compressed child. +//! - **Orthogonal transform**: [`SorfTransform`] records the SORF structured orthogonal +//! transform and applies the inverse at decode time. +//! - **Quantization**: `DictArray(codes, centroids)` wrapped in `FixedSizeListArray` stores +//! the per-coordinate codebook indices. +//! +//! The full encoded tree is: //! //! ```text -//! ScalarFnArray(L2Denorm, [TurboQuantArray, norms]) +//! ScalarFnArray(L2Denorm, [ +//! ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +//! norms +//! ]) //! ``` //! -//! When executed, the TQ array decompresses to unit-norm vectors, and the [`L2Denorm`] function -//! lazily re-applies the stored norms to reconstruct the original magnitudes. +//! When executed, the tree automatically decompresses: Dict dequantizes codes → SorfTransform +//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). //! //! [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm -//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode +//! [`SorfTransform`]: crate::scalar_fns::sorf_transform::SorfTransform //! //! The TurboQuant paper analyzes a full random orthogonal rotation. The current Vortex //! implementation instead uses a fixed 3-round Walsh-Hadamard-based structured transform with -//! random sign diagonals. This is a practical approximation chosen for encode/decode efficiency, -//! and should be understood as an implementation choice rather than the exact construction used in -//! the paper's proofs. +//! random sign diagonals generated by Vortex's frozen local SplitMix64 stream. This is a practical +//! approximation chosen for encode/decode efficiency, and should be understood as an +//! implementation choice rather than the exact construction used in the paper's proofs. //! //! The current encoding is also intentionally MSE-only. It does not yet implement the paper's QJL //! residual correction for unbiased inner-product estimation, and it still uses internal @@ -95,7 +104,6 @@ //! use vortex_array::session::ArraySession; //! use vortex_session::VortexSession; //! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_unchecked}; -//! use vortex_tensor::scalar_fns::ApproxOptions; //! use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; //! use vortex_tensor::vector::Vector; //! @@ -117,9 +125,7 @@ //! // Normalize, then quantize the normalized child at 2 bits per coordinate. //! let session = VortexSession::empty().with::(); //! let mut ctx = session.create_execution_ctx(); -//! let l2_denorm = normalize_as_l2_denorm( -//! &ApproxOptions::Exact, ext.into_array(), &mut ctx, -//! ).unwrap(); +//! let l2_denorm = normalize_as_l2_denorm(ext.into_array(), &mut ctx).unwrap(); //! let normalized = l2_denorm.child_at(0).clone(); //! //! let normalized_ext = normalized.as_opt::().unwrap(); @@ -133,24 +139,55 @@ //! assert!(tq.nbytes() < 51200); //! ``` -mod array; -pub use array::data::TurboQuantArrayExt; -pub use array::data::TurboQuantData; +pub(crate) mod centroids; +pub(crate) mod compress; -pub(crate) mod compute; +mod scheme; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode; +pub use compress::turboquant_encode_unchecked; +pub use scheme::TurboQuantScheme; -mod metadata; +/// Minimum vector dimension for TurboQuant encoding. +/// +/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total +/// amount of distortion. +pub const MIN_DIMENSION: u32 = 128; -mod vtable; +/// Maximum supported number of bits per quantized coordinate. +pub const MAX_BIT_WIDTH: u8 = 8; -pub use vtable::TurboQuant; -pub use vtable::TurboQuantArray; +/// Maximum supported number of centroids in the scalar quantizer codebook. +pub const MAX_CENTROIDS: usize = 1usize << (MAX_BIT_WIDTH as usize); -mod scheme; -pub use scheme::TurboQuantScheme; -pub use scheme::compress::TurboQuantConfig; -pub use scheme::compress::turboquant_encode; -pub use scheme::compress::turboquant_encode_unchecked; +use vortex_array::dtype::DType; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::vector::AnyVector; +use crate::vector::VectorMatcherMetadata; + +/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with +/// dimension >= [`MIN_DIMENSION`]. +/// +/// Returns the validated vector metadata on success. +pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { + let vector_metadata = dtype + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") + })?; + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + + Ok(vector_metadata) +} #[cfg(test)] mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/scheme/mod.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs similarity index 58% rename from vortex-tensor/src/encodings/turboquant/scheme/mod.rs rename to vortex-tensor/src/encodings/turboquant/scheme.rs index c3beb2a6993..ba9b95dd1e0 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -1,16 +1,25 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! TurboQuant compression scheme and decompression. +//! TurboQuant compression scheme. //! //! The scheme first normalizes the input via [`normalize_as_l2_denorm`], then encodes the -//! normalized child via [`turboquant_encode`]. The result is: +//! normalized child via [`turboquant_encode_unchecked`]. The result is: //! //! ```text -//! ScalarFnArray(L2Denorm, [TurboQuantArray, norms]) +//! ScalarFnArray(L2Denorm, [ +//! ScalarFnArray( +//! SorfTransform, +//! FSL(Dict(codes, centroids)) +//! ), +//! norms +//! ]) //! ``` //! +//! Decompression is automatic: executing the outer array walks the ScalarFn tree. +//! //! [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm +//! [`turboquant_encode_unchecked`]: crate::encodings::turboquant::turboquant_encode_unchecked use vortex_array::ArrayRef; use vortex_array::Canonical; @@ -25,22 +34,20 @@ use vortex_compressor::stats::ArrayAndStats; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use crate::encodings::turboquant::TurboQuant; +use crate::encodings::turboquant::MAX_CENTROIDS; use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::tq_validate_vector_dtype; use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; -pub(super) mod compress; -pub(super) mod decompress; - /// TurboQuant compression scheme for [`Vector`] extension types. /// -/// Applies lossy vector quantization to [`Vector`] extension arrays using the TurboQuant -/// algorithm with MSE-optimal encoding. +/// Applies lossy vector quantization to [`Vector`] extension arrays using the TurboQuant algorithm +/// with MSE-optimal encoding. /// /// Register this scheme with the compressor builder via `with_scheme`: +/// /// ```ignore /// use vortex_btrblocks::BtrBlocksCompressorBuilder; /// use vortex_tensor::encodings::turboquant::TurboQuantScheme; @@ -64,7 +71,7 @@ impl Scheme for TurboQuantScheme { return false; }; - TurboQuant::validate_dtype(ext.dtype()).is_ok() + tq_validate_vector_dtype(ext.dtype()).is_ok() } fn expected_compression_ratio( @@ -76,15 +83,19 @@ impl Scheme for TurboQuantScheme { let dtype = data.array().dtype(); let vector_metadata = - TurboQuant::validate_dtype(dtype).vortex_expect("invalid dtype for TurboQuant"); + tq_validate_vector_dtype(dtype).vortex_expect("invalid dtype for TurboQuant"); let element_ptype = vector_metadata.element_ptype(); - let bit_width: u8 = element_ptype + let element_bit_width: u8 = element_ptype .bit_width() .try_into() .vortex_expect("invalid bit width for TurboQuant"); let dimension = vector_metadata.dimensions(); - CompressionEstimate::Ratio(estimate_compression_ratio(bit_width, dimension, len)) + CompressionEstimate::Ratio(estimate_compression_ratio( + element_bit_width, + dimension, + len, + )) } fn compress( @@ -100,50 +111,50 @@ impl Scheme for TurboQuantScheme { let mut ctx = compressor.execution_ctx(); - // Normalize first: produces L2Denorm(normalized_vectors, norms). - let l2_denorm = - normalize_as_l2_denorm(&ApproxOptions::Exact, ext_array.as_ref().clone(), &mut ctx)?; + // 1. Normalize: produces L2Denorm(normalized_vectors, norms). + let l2_denorm = normalize_as_l2_denorm(ext_array.as_ref().clone(), &mut ctx)?; let normalized = l2_denorm.child_at(0).clone(); let norms = l2_denorm.child_at(1).clone(); let num_rows = l2_denorm.len(); - // Quantize the normalized child. + // 2. Quantize the normalized child: SorfTransform(FSL(Dict)). let normalized_ext = normalized .as_opt::() .vortex_expect("normalized child should be an Extension array"); + let config = TurboQuantConfig::default(); // SAFETY: We just normalized the input via `normalize_as_l2_denorm`, so all rows are // guaranteed to be unit-norm (or zero for originally-null rows). - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? }; + let sorf_dict = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? }; + // 3. Wrap back in L2Denorm: the SorfTransform is the "normalized" child. // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally // bypass the strict normalized-row validation when reattaching the stored norms. - Ok( - unsafe { L2Denorm::new_array_unchecked(&ApproxOptions::Exact, tq, norms, num_rows) }? - .into_array(), - ) + Ok(unsafe { L2Denorm::new_array_unchecked(sorf_dict, norms, num_rows) }?.into_array()) } } +// TODO(connor): If we ever add scheme vtables with metadata, we would need to pass in the config as +// a parameter here. /// Estimate the compression ratio for TurboQuant MSE encoding with the default config. -fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors: usize) -> f64 { +fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vectors: usize) -> f64 { let config = TurboQuantConfig::default(); let padded_dim = dimensions.next_power_of_two() as usize; - // Per-vector: MSE codes per padded coordinate, plus one f32 norm. - let compressed_bits_per_vector = 32 // norm is always f32 - + (config.bit_width as usize) * padded_dim; // MSE codes + // Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element + // float width. + let compressed_bits_per_vector = + usize::from(element_bit_width) + usize::from(config.bit_width) * padded_dim; - // Shared overhead: codebook centroids (2^bit_width f32 values) and - // rotation signs (num_rounds * padded_dim bits). + // Shared overhead: codebook centroids (2^bit_width f32 values). + // Note: rotation signs are no longer stored — rotation is deterministic from seed. let num_centroids = 1usize << config.bit_width; - debug_assert!(num_centroids <= TurboQuant::MAX_CENTROIDS); - let overhead_bits = num_centroids * 32 // centroids are always f32 - + config.num_rounds as usize * padded_dim; // rotation signs, 1 bit each + debug_assert!(num_centroids <= MAX_CENTROIDS); + let overhead_bits = num_centroids * 32; // centroids are always f32 let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits; - let uncompressed_size_bits = bits_per_element as usize * dimensions as usize * num_vectors; + let uncompressed_size_bits = usize::from(element_bit_width) * dimensions as usize * num_vectors; uncompressed_size_bits as f64 / compressed_size_bits as f64 } @@ -155,27 +166,27 @@ mod tests { /// Verify compression ratio for typical embedding dimensions. /// - /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~4-6x. - /// f32 input at 1024-d (no padding) should give higher ratio since no waste. + /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~3x. + /// f32 input at 1024-d (no padding) should give ~4x since no padding waste. #[rstest] - #[case::f32_768d(32, 768, 1000, 2.5, 4.0)] + #[case::f32_768d(32, 768, 1000, 2.5, 4.5)] #[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)] - #[case::f32_1536d(32, 1536, 1000, 2.5, 4.0)] + #[case::f32_1536d(32, 1536, 1000, 2.5, 4.5)] #[case::f32_128d(32, 128, 1000, 3.0, 5.0)] - #[case::f64_768d(64, 768, 1000, 5.0, 7.0)] - #[case::f16_768d(16, 768, 1000, 1.2, 2.0)] + #[case::f64_768d(64, 768, 1000, 5.0, 9.0)] + #[case::f16_768d(16, 768, 1000, 1.2, 2.5)] fn compression_ratio_in_expected_range( - #[case] bits_per_element: u8, + #[case] element_bit_width: u8, #[case] dim: u32, #[case] num_vectors: usize, #[case] min_ratio: f64, #[case] max_ratio: f64, ) { - let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors); assert!( ratio > min_ratio && ratio < max_ratio, "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \ - {bits_per_element}-bit elements, dim={dim}, n={num_vectors}" + {element_bit_width}-bit elements, dim={dim}, n={num_vectors}" ); } @@ -186,14 +197,38 @@ mod tests { #[case(32, 768, 10)] #[case(64, 256, 50)] fn ratio_always_greater_than_one( - #[case] bits_per_element: u8, + #[case] element_bit_width: u8, #[case] dim: u32, #[case] num_vectors: usize, ) { - let ratio = estimate_compression_ratio(bits_per_element, dim, num_vectors); + let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors); assert!( ratio > 1.0, - "ratio {ratio:.4} <= 1.0 for {bits_per_element}-bit, dim={dim}, n={num_vectors}" + "ratio {ratio:.4} <= 1.0 for {element_bit_width}-bit, dim={dim}, n={num_vectors}" + ); + } + + #[rstest] + #[case(16)] + #[case(32)] + #[case(64)] + fn ratio_accounts_for_norm_storage_width(#[case] element_bit_width: u8) { + let dim = 128u32; + let num_vectors = 1usize; + let padded_dim = dim.next_power_of_two() as usize; + let config = TurboQuantConfig::default(); + let num_centroids = 1usize << config.bit_width; + + let expected_compressed_bits = usize::from(element_bit_width) + + usize::from(config.bit_width) * padded_dim + + num_centroids * 32; + let expected_uncompressed_bits = + usize::from(element_bit_width) * dim as usize * num_vectors; + let expected = expected_uncompressed_bits as f64 / expected_compressed_bits as f64; + + assert_eq!( + estimate_compression_ratio(element_bit_width, dim, num_vectors), + expected ); } diff --git a/vortex-tensor/src/encodings/turboquant/scheme/decompress.rs b/vortex-tensor/src/encodings/turboquant/scheme/decompress.rs deleted file mode 100644 index 2e92ced88e1..00000000000 --- a/vortex-tensor/src/encodings/turboquant/scheme/decompress.rs +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant decoding (dequantization) logic. -//! -//! Decompression produces unit-norm vectors. The original magnitudes are restored externally -//! by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) ScalarFnArray wrapper. - -use num_traits::Float; -use num_traits::FromPrimitive; -use vortex_array::Array; -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::NativePType; -use vortex_array::dtype::Nullability; -use vortex_array::match_each_float_ptype; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexResult; - -use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantArrayExt; -use crate::encodings::turboquant::array::rotation::RotationMatrix; -use crate::encodings::turboquant::compute::float_from_f32; -use crate::vector::AnyVector; - -/// Decompress a `TurboQuantArray` into a unit-norm [`Vector`] extension array. -/// -/// The returned array is an [`ExtensionArray`] with the (non-nullable) Vector dtype wrapping a -/// `FixedSizeListArray` of the original vector element type. Each vector has unit L2 norm; the -/// original magnitudes are restored by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) -/// ScalarFnArray wrapper. -/// -/// [`Vector`]: crate::vector::Vector -pub fn execute_decompress( - array: Array, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let dim = array.dimension() as usize; - let padded_dim = array.padded_dim() as usize; - let num_rows = array.len(); - let ext_dtype = array.dtype().as_extension().clone(); - let element_ptype = ext_dtype.metadata::().element_ptype(); - - if num_rows == 0 { - match_each_float_ptype!(element_ptype, |T| { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - array.dimension(), - Validity::NonNullable, - 0, - )?; - - return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()); - }) - } - - // Read stored centroids (always f32). - let centroids_prim = array.centroids().clone().execute::(ctx)?; - let centroids = centroids_prim.as_slice::(); - - // The rotation signs are stored as a FixedSizeListArray wrapping bitpacked u8 values. - // We unwrap to the flat elements, then FastLanes SIMD-unpacks the 1-bit values into u8 0/1. - // These are expanded to u32 XOR masks once (amortized over all rows), enabling branchless - // XOR-based sign application in the per-row structured-rotation hot loop. - let num_rounds = array.num_rounds() as usize; - let signs_fsl = array - .rotation_signs() - .clone() - .execute::(ctx)?; - let signs_prim = signs_fsl - .elements() - .clone() - .execute::(ctx)?; - let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim, num_rounds)?; - - // Unpack codes from FixedSizeListArray -> flat u8 elements. - let codes_fsl = array.codes().clone().execute::(ctx)?; - let codes_prim = codes_fsl - .elements() - .clone() - .execute::(ctx)?; - let indices = codes_prim.as_slice::(); - - // MSE decode: dequantize (f32) -> inverse rotate (f32) -> cast to T. - // The rotation and centroid lookup always happen in f32. The final output is cast to the - // Vector's element type to match the original storage dtype. No norm scaling is applied here; - // that is handled by the external L2Denorm wrapper. - match_each_float_ptype!(element_ptype, |T| { - decompress_typed::(centroids, &rotation, indices, dim, padded_dim, num_rows).and_then( - |elements| { - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - array.dimension(), - Validity::NonNullable, - num_rows, - )?; - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - }, - ) - }) -} - -/// Typed decompress: dequantizes in f32 and produces unit-norm output as `T`. -fn decompress_typed( - centroids: &[f32], - rotation: &RotationMatrix, - indices: &[u8], - dim: usize, - padded_dim: usize, - num_rows: usize, -) -> VortexResult { - let mut output = BufferMut::::with_capacity(num_rows * dim); - let mut dequantized = vec![0.0f32; padded_dim]; - let mut unrotated = vec![0.0f32; padded_dim]; - - for row in 0..num_rows { - let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; - - for idx in 0..padded_dim { - dequantized[idx] = centroids[row_indices[idx] as usize]; - } - - rotation.inverse_rotate(&dequantized, &mut unrotated); - - for idx in 0..dim { - output.push(float_from_f32::(unrotated[idx])); - } - } - - Ok(PrimitiveArray::new::( - output.freeze(), - Validity::NonNullable, - )) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests.rs b/vortex-tensor/src/encodings/turboquant/tests.rs deleted file mode 100644 index 8e633469a3d..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests.rs +++ /dev/null @@ -1,1260 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::sync::LazyLock; - -use rand::SeedableRng; -use rand::rngs::StdRng; -use rand_distr::Distribution; -use rand_distr::Normal; -use rstest::rstest; -use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFnVTable; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; -use vortex_array::session::ArraySession; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_session::VortexSession; - -use crate::encodings::turboquant::TurboQuant; -use crate::encodings::turboquant::TurboQuantArrayExt; -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::array::rotation::RotationMatrix; -use crate::encodings::turboquant::turboquant_encode; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::ApproxOptions; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; -use crate::scalar_fns::l2_norm::L2Norm; -use crate::vector::Vector; - -static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - -/// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal) with the given -/// validity. -fn make_fsl_with_validity( - num_rows: usize, - dim: usize, - seed: u64, - validity: Validity, -) -> FixedSizeListArray { - let mut rng = StdRng::seed_from_u64(seed); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - validity, - num_rows, - ) - .unwrap() -} - -/// Create a non-nullable FixedSizeListArray of random f32 vectors (i.i.d. standard normal). -fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { - let mut rng = StdRng::seed_from_u64(seed); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - num_rows, - ) - .unwrap() -} - -/// Wrap a `FixedSizeListArray` in a `Vector` extension array. -fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.clone().into_array()) -} - -/// Full encode pipeline: normalize, then TQ-encode, then wrap in L2Denorm. -/// -/// This mirrors what `TurboQuantScheme::compress()` does: normalize via `normalize_as_l2_denorm`, -/// then quantize the normalized child via `turboquant_encode_unchecked`, then reassemble. -fn normalize_and_encode( - ext: &ExtensionArray, - config: &TurboQuantConfig, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(&ApproxOptions::Exact, ext.as_ref().clone(), ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? }; - - Ok( - unsafe { L2Denorm::new_array_unchecked(&ApproxOptions::Exact, tq, norms, num_rows) }? - .into_array(), - ) -} - -/// Unwrap an L2Denorm ScalarFnArray into its TQ child and norms child. -fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { - let sfn = encoded - .as_opt::() - .expect("expected ScalarFnArray"); - let tq_child = sfn.child_at(0).clone(); - let norms_child = sfn.child_at(1).clone(); - (tq_child, norms_child) -} - -fn theoretical_mse_bound(bit_width: u8) -> f32 { - let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; - sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) -} - -fn per_vector_normalized_mse( - original: &[f32], - reconstructed: &[f32], - dim: usize, - num_rows: usize, -) -> f32 { - let mut total = 0.0f32; - for row in 0..num_rows { - let orig = &original[row * dim..(row + 1) * dim]; - let recon = &reconstructed[row * dim..(row + 1) * dim]; - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - if norm_sq < 1e-10 { - continue; - } - let err_sq: f32 = orig - .iter() - .zip(recon.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - total += err_sq / norm_sq; - } - total / num_rows as f32 -} - -/// Normalize, encode, and decode, returning (original, decoded) flat f32 slices. -fn encode_decode( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, -) -> VortexResult<(Vec, Vec)> { - let mut ctx = SESSION.create_execution_ctx(); - let original: Vec = { - let prim = fsl.elements().clone().execute::(&mut ctx)?; - prim.as_slice::().to_vec() - }; - let ext = make_vector_ext(fsl); - let encoded = normalize_and_encode(&ext, config, &mut ctx)?; - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - let decoded_elements: Vec = { - let prim = decoded_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - prim.as_slice::().to_vec() - }; - Ok((original, decoded_elements)) -} - -fn empty_turboquant_parts( - dim: u32, -) -> VortexResult<(vortex_array::dtype::DType, ArrayRef, ArrayRef, ArrayRef)> { - let fsl = make_fsl(0, dim as usize, 42); - let ext = make_vector_ext(&fsl); - - let codes = FixedSizeListArray::try_new( - PrimitiveArray::empty::(Nullability::NonNullable).into_array(), - dim, - Validity::NonNullable, - 0, - )? - .into_array(); - let centroids = PrimitiveArray::empty::(Nullability::NonNullable).into_array(); - let rotation_signs = FixedSizeListArray::try_new( - PrimitiveArray::empty::(Nullability::NonNullable).into_array(), - dim, - Validity::NonNullable, - 0, - )? - .into_array(); - - // TQ dtype is non-nullable. - Ok(( - ext.dtype().as_nonnullable(), - codes, - centroids, - rotation_signs, - )) -} - -fn normalized_child( - ext: &ExtensionArray, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - Ok( - normalize_as_l2_denorm(&ApproxOptions::Exact, ext.as_ref().clone(), ctx)? - .child_at(0) - .clone(), - ) -} - -// ----------------------------------------------------------------------- -// Roundtrip tests -// ----------------------------------------------------------------------- - -#[rstest] -#[case(128, 1)] -#[case(128, 2)] -#[case(128, 3)] -#[case(128, 4)] -#[case(128, 6)] -#[case(128, 8)] -#[case(256, 2)] -fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let fsl = make_fsl(10, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - assert_eq!(decoded.len(), original.len()); - Ok(()) -} - -#[test] -fn empty_try_new_rejects_invalid_centroids_dtype() -> VortexResult<()> { - let (dtype, codes, _centroids, rotation_signs) = empty_turboquant_parts(128)?; - let wrong_centroids = PrimitiveArray::empty::(Nullability::NonNullable).into_array(); - - let err = TurboQuant::try_new_array(dtype, codes, wrong_centroids, rotation_signs).unwrap_err(); - - assert!( - err.to_string() - .contains("centroids dtype must be non-nullable f32") - ); - Ok(()) -} - -#[test] -fn empty_try_new_rejects_invalid_rotation_signs_dtype() -> VortexResult<()> { - let (dtype, codes, centroids, _rotation_signs) = empty_turboquant_parts(128)?; - let wrong_rotation_signs = PrimitiveArray::empty::(Nullability::NonNullable).into_array(); - - let err = TurboQuant::try_new_array(dtype, codes, centroids, wrong_rotation_signs).unwrap_err(); - - assert!( - err.to_string() - .contains("rotation_signs dtype does not match") - ); - Ok(()) -} - -// ----------------------------------------------------------------------- -// MSE quality tests -// ----------------------------------------------------------------------- - -#[rstest] -#[case(128, 1)] -#[case(128, 2)] -#[case(128, 3)] -#[case(128, 4)] -#[case(256, 2)] -#[case(256, 4)] -fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - let bound = theoretical_mse_bound(bit_width); - - assert!( - normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ - for dim={dim}, bits={bit_width}", - ); - Ok(()) -} - -#[rstest] -#[case(128, 6)] -#[case(128, 8)] -#[case(256, 6)] -#[case(256, 8)] -fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - - let config_4bit = TurboQuantConfig { - bit_width: 4, - seed: Some(123), - num_rounds: 3, - }; - let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; - let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); - - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - - assert!( - mse < mse_4bit, - "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" - ); - assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); - Ok(()) -} - -#[test] -fn mse_decreases_with_bits() -> VortexResult<()> { - let dim = 128; - let num_rows = 50; - let fsl = make_fsl(num_rows, dim, 99); - - let mut prev_mse = f32::MAX; - for bit_width in 1..=8u8 { - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - assert!( - mse <= prev_mse * 1.01, - "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" - ); - prev_mse = mse; - } - Ok(()) -} - -// ----------------------------------------------------------------------- -// Edge cases -// ----------------------------------------------------------------------- - -#[rstest] -#[case(0)] -#[case(1)] -fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { - let fsl = make_fsl(num_rows, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let decoded = encoded.execute::(&mut ctx)?; - assert_eq!(decoded.len(), num_rows); - Ok(()) -} - -#[rstest] -#[case(1)] -#[case(64)] -#[case(127)] -fn rejects_dimension_below_128(#[case] dim: usize) { - let fsl = make_fsl_small(dim); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(0), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - assert!(turboquant_encode(ext.as_view(), &config, &mut ctx).is_err()); -} - -#[test] -fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(half::f16::from_f32(normal.sample(&mut rng))); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - num_rounds: 3, - }; - - let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalized_child(&ext, &mut ctx)?; - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - - let encoded = turboquant_encode(normalized_ext, &config, &mut ctx)?; - assert_eq!(encoded.len(), num_rows); - Ok(()) -} - -fn make_fsl_small(dim: usize) -> FixedSizeListArray { - let mut buf = BufferMut::::with_capacity(dim); - for i in 0..dim { - buf.push(i as f32 + 1.0); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - 1, - ) - .unwrap() -} - -/// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). -#[test] -fn all_zero_vectors_roundtrip() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let buf = BufferMut::::full(0.0f32, num_rows * dim); - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - num_rows, - )?; - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - // All-zero vectors should decode to all-zero (norm=0 -> 0 * anything = 0). - for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { - assert_eq!(o, 0.0, "original[{i}] not zero"); - assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); - } - Ok(()) -} - -/// Verify that f64 input is accepted and encoded (converted to f32 internally). -#[test] -fn f64_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - num_rounds: 3, - }; - // Verify encoding succeeds with f64 input (f64->f32 conversion). - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, norms_child) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - assert_eq!(norms_child.len(), num_rows); - assert_eq!(tq.dimension() as usize, dim); - Ok(()) -} - -/// Verify that f16 input is accepted and encoded (upcast to f32 internally). -#[test] -fn f16_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(half::f16::from_f32(normal.sample(&mut rng))); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(42), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, norms_child) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - assert_eq!(norms_child.len(), num_rows); - assert_eq!(tq.dimension() as usize, dim); - - // Verify roundtrip: decode and check reconstruction is reasonable. - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - assert_eq!(decoded_fsl.len(), num_rows); - Ok(()) -} - -// ----------------------------------------------------------------------- -// Verification tests for stored metadata -// ----------------------------------------------------------------------- - -/// Verify that the centroids stored in the array match what `get_centroids()` computes. -#[test] -fn stored_centroids_match_computed() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, _norms) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - - let mut ctx = SESSION.create_execution_ctx(); - let stored_centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; - let stored = stored_centroids_prim.as_slice::(); - - let padded_dim = tq.padded_dim(); - let computed = crate::encodings::turboquant::array::centroids::get_centroids(padded_dim, 3)?; - - assert_eq!(stored.len(), computed.len()); - for i in 0..stored.len() { - assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); - } - Ok(()) -} - -/// Verify that stored rotation signs produce identical decode to seed-based decode. -#[test] -fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 4, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, _norms) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - - // Decode via the full L2Denorm path (TQ decompress + norm scaling). - let mut ctx = SESSION.create_execution_ctx(); - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - let decoded = decoded_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let decoded_slice = decoded.as_slice::(); - - // Verify stored signs match seed-derived signs. - let rot_from_seed = RotationMatrix::try_new(123, 128, 4)?; - let expected_u8 = rot_from_seed.export_inverse_signs_u8(); - let stored_signs_fsl = tq - .rotation_signs() - .clone() - .execute::(&mut ctx)?; - let stored_signs = stored_signs_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let stored_u8 = stored_signs.as_slice::(); - - assert_eq!(expected_u8.len(), stored_u8.len()); - for i in 0..expected_u8.len() { - assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); - } - - // Also verify decode output is non-empty and has expected size. - assert_eq!(decoded_slice.len(), 20 * 128); - Ok(()) -} - -// ----------------------------------------------------------------------- -// Compute pushdown tests -// ----------------------------------------------------------------------- - -#[test] -fn slice_preserves_data() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 4, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - // Full decompress then slice. - let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - let full_fsl = full_decoded - .storage_array() - .clone() - .execute::(&mut ctx)?; - let expected = full_fsl.slice(5..10)?; - let expected_fsl = expected.execute::(&mut ctx)?; - let expected_elements = expected_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - - // Slice then decompress. - let sliced = encoded.slice(5..10)?; - let sliced_decoded = sliced.execute::(&mut ctx)?; - let sliced_fsl = sliced_decoded - .storage_array() - .clone() - .execute::(&mut ctx)?; - let actual_elements = sliced_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - - assert_eq!( - expected_elements.as_slice::(), - actual_elements.as_slice::() - ); - Ok(()) -} - -#[test] -fn scalar_at_matches_decompress() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - let full_decoded = encoded.clone().execute::(&mut ctx)?; - - for i in [0, 1, 5, 9] { - let expected = full_decoded.scalar_at(i)?; - let actual = encoded.scalar_at(i)?; - assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); - } - Ok(()) -} - -#[test] -fn l2_norm_readthrough() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 5, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (_tq_child, norms_child) = unwrap_l2denorm(&encoded); - - // Stored norms should match the actual L2 norms of the input. - let norms_prim = norms_child.execute::(&mut ctx)?; - let stored_norms = norms_prim.as_slice::(); - - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - for row in 0..10 { - let vec = &input_f32[row * 128..(row + 1) * 128]; - let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); - assert!( - (stored_norms[row] - actual_norm).abs() < 1e-5, - "norm mismatch at row {row}: stored={}, actual={}", - stored_norms[row], - actual_norm - ); - } - Ok(()) -} - -#[test] -fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 4, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, norms_child) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - - // Compute exact cosine similarity from original data. - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - - // Read quantized codes, norms, and centroids for approximate computation. - let mut ctx = SESSION.create_execution_ctx(); - let pd = tq.padded_dim() as usize; - let norms_prim = norms_child.execute::(&mut ctx)?; - let norms = norms_prim.as_slice::(); - let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; - let codes_prim = codes_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let all_codes = codes_prim.as_slice::(); - let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; - let centroid_vals = centroids_prim.as_slice::(); - - for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { - let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; - let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; - - let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); - let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); - let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); - let exact_cos = dot / (norm_a * norm_b); - - // Approximate cosine similarity in quantized domain. - let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { - 0.0 - } else { - let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; - let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; - codes_a - .iter() - .zip(codes_b.iter()) - .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) - .sum::() - }; - - // At 4-bit, the theoretical MSE bound per coordinate is ~0.0106 (Theorem 1). For cosine - // similarity (bounded [-1, 1]), the error is bounded roughly by 2*sqrt(MSE) ~ 0.2. We use - // 0.15 as a tighter empirical bound. - let error = (exact_cos - approx_cos).abs(); - assert!( - error < 0.15, - "cosine similarity error too large for ({row_a}, {row_b}): \ - exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" - ); - } - Ok(()) -} - -/// Verify approximate dot product in the quantized domain. -/// -/// NOTE: The MSE quantizer (TurboQuant_mse) has inherent **multiplicative bias** for inner -/// products — the quantized dot product systematically over- or under-estimates the true value. -/// This is a fundamental property: the paper's `TurboQuant_prod` variant adds QJL specifically -/// to debias inner products, but we only implement the MSE-only variant. -/// -/// Even at 8-bit (near-lossless reconstruction, MSE ~4e-5), the quantized-domain dot product -/// can have ~10-15% relative error due to this bias. This tolerance is therefore intentionally -/// loose — we're testing that the approximation is in the right ballpark, not that it's precise. -/// -/// TODO(connor): Revisit these tolerances when we have TurboQuant_prod (QJL debiasing). -#[test] -fn dot_product_quantized_accuracy() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 8, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, norms_child) = unwrap_l2denorm(&encoded); - let tq = tq_child.as_opt::().unwrap(); - - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - - let mut ctx = SESSION.create_execution_ctx(); - let pd = tq.padded_dim() as usize; - let norms_prim = norms_child.execute::(&mut ctx)?; - let norms = norms_prim.as_slice::(); - let codes_fsl = tq.codes().clone().execute::(&mut ctx)?; - let codes_prim = codes_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let all_codes = codes_prim.as_slice::(); - let centroids_prim = tq.centroids().clone().execute::(&mut ctx)?; - let centroid_vals = centroids_prim.as_slice::(); - - for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { - let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; - let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; - - let exact_dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); - - let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; - let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; - let unit_dot: f32 = codes_a - .iter() - .zip(codes_b.iter()) - .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) - .sum(); - let approx_dot = norms[row_a] * norms[row_b] * unit_dot; - - // See doc comment above: 15% relative error is expected due to MSE quantizer bias. - let scale = exact_dot.abs().max(1.0); - let rel_error = (exact_dot - approx_dot).abs() / scale; - assert!( - rel_error < 0.15, - "dot product error too large for ({row_a}, {row_b}): \ - exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}" - ); - } - Ok(()) -} - -/// Roundtrip at large embedding dimensions to validate padding and SRHT at common sizes. -/// -/// NOTE: The theoretical MSE bound (Theorem 1) is proved for Haar-distributed random orthogonal -/// matrices, not SRHT. The SRHT is a practical O(d log d) approximation that doesn't exactly -/// satisfy the Haar assumption, so empirical MSE can slightly exceed the theoretical bound. We -/// use a 2x multiplier to account for this gap. -/// -/// The 1024-d case uses 5-bit instead of 4-bit because at 4-bit the SRHT approximation error -/// at d=1024 pushes MSE ~20% above the 1x theoretical bound (0.0127 vs bound 0.0106). -/// -/// TODO(connor): Revisit after Stage 2 block decomposition — at d=768 with block_size=256, -/// the per-block SRHT will be lower-dimensional and may have different error characteristics. -#[rstest] -#[case(768, 4)] -#[case(1024, 5)] -fn large_dimension_roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 10; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: Some(123), - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - assert_eq!(decoded.len(), original.len()); - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - // 2x slack for the SRHT-vs-Haar gap (see doc comment above). - let bound = 2.0 * theoretical_mse_bound(bit_width); - assert!( - normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}", - ); - Ok(()) -} - -/// Verify that the encoded array's dtype is a Vector extension type. -#[test] -fn encoded_dtype_is_vector_extension() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - // The encoded TurboQuant array should claim a Vector extension dtype. - assert!( - encoded.dtype().is_extension(), - "TurboQuant dtype should be an extension type, got {}", - encoded.dtype() - ); - assert!( - encoded.dtype().as_extension().is::(), - "TurboQuant dtype should be a Vector extension type" - ); - Ok(()) -} - -// ----------------------------------------------------------------------- -// Nullable vector tests -// ----------------------------------------------------------------------- - -/// Encode a nullable Vector array and verify roundtrip preserves validity and non-null values. -#[test] -fn nullable_vectors_roundtrip() -> VortexResult<()> { - // Rows 2, 5, 7 are null. - let validity = Validity::from_iter([ - true, true, false, true, true, false, true, false, true, true, - ]); - let fsl = make_fsl_with_validity(10, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 4, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - assert_eq!(encoded.len(), 10); - assert!(encoded.dtype().is_nullable()); - - // Check validity of the encoded array. - let encoded_validity = encoded.validity()?; - for i in 0..10 { - let expected = ![2, 5, 7].contains(&i); - assert_eq!( - encoded_validity.is_valid(i)?, - expected, - "validity mismatch at row {i}" - ); - } - - // Decode and verify non-null rows have correct data. - let decoded_ext = encoded.execute::(&mut ctx)?; - assert_eq!(decoded_ext.len(), 10); - - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - let decoded_prim = decoded_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let decoded_f32 = decoded_prim.as_slice::(); - - // Original f32 elements for non-null row comparison. - let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; - let orig_f32 = orig_prim.as_slice::(); - - // Non-null rows should have reasonable reconstruction (within MSE bounds). - for row in [0, 1, 3, 4, 6, 8, 9] { - let orig_vec = &orig_f32[row * 128..(row + 1) * 128]; - let dec_vec = &decoded_f32[row * 128..(row + 1) * 128]; - let norm_sq: f32 = orig_vec.iter().map(|&v| v * v).sum(); - let err_sq: f32 = orig_vec - .iter() - .zip(dec_vec.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - // 3-bit normalized MSE should be well under the theoretical bound. - assert!( - err_sq / norm_sq < 0.1, - "non-null row {row} has excessive reconstruction error" - ); - } - Ok(()) -} - -/// Verify that norms carry the validity: null vectors have null norms. -#[test] -fn nullable_norms_match_validity() -> VortexResult<()> { - let validity = Validity::from_iter([true, false, true, false, true]); - let fsl = make_fsl_with_validity(5, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (_tq_child, norms_child) = unwrap_l2denorm(&encoded); - - let norms_validity = norms_child.validity()?; - for i in 0..5 { - let expected = i % 2 == 0; // rows 0, 2, 4 are valid - assert_eq!( - norms_validity.is_valid(i)?, - expected, - "norms validity mismatch at row {i}" - ); - } - Ok(()) -} - -/// Verify that L2Norm readthrough works correctly on nullable TurboQuant arrays. -#[test] -fn nullable_l2_norm_readthrough() -> VortexResult<()> { - let validity = Validity::from_iter([true, false, true, false, true]); - let fsl = make_fsl_with_validity(5, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - // Compute L2Norm on the encoded array. - let norm_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, encoded, 5)?; - let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; - - // Null rows should have null norms, valid rows should have correct norms. - let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; - let orig_f32 = orig_prim.as_slice::(); - for row in 0..5 { - if row % 2 == 0 { - assert!(norms.is_valid(row)?, "row {row} should be valid"); - let expected: f32 = orig_f32[row * 128..(row + 1) * 128] - .iter() - .map(|&v| v * v) - .sum::() - .sqrt(); - let actual = norms.as_slice::()[row]; - assert!( - (actual - expected).abs() < 1e-5, - "norm mismatch at valid row {row}: actual={actual}, expected={expected}" - ); - } else { - assert!(!norms.is_valid(row)?, "row {row} should be null"); - } - } - Ok(()) -} - -/// Verify that slicing a nullable TurboQuant array preserves validity. -#[test] -fn nullable_slice_preserves_validity() -> VortexResult<()> { - // Rows 2, 5, 7 are null. - let validity = Validity::from_iter([ - true, true, false, true, true, false, true, false, true, true, - ]); - let fsl = make_fsl_with_validity(10, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - - // Slice rows 1..6 -> [true, false, true, true, false]. - let sliced = encoded.slice(1..6)?; - assert_eq!(sliced.len(), 5); - - let sliced_validity = sliced.validity()?; - let expected = [true, false, true, true, false]; - for (i, &exp) in expected.iter().enumerate() { - assert_eq!( - sliced_validity.is_valid(i)?, - exp, - "sliced validity mismatch at index {i}" - ); - } - Ok(()) -} - -// ----------------------------------------------------------------------- -// Serde roundtrip tests -// ----------------------------------------------------------------------- - -/// Verify that a TurboQuant array (extracted from the L2Denorm wrapper) survives -/// serialize/deserialize. -/// -/// TODO(connor): ScalarFnArray cannot be serialized yet, so we test the TQ child directly. -#[test] -fn serde_roundtrip() -> VortexResult<()> { - use vortex_array::ArrayContext; - use vortex_array::ArrayEq; - use vortex_array::Precision; - use vortex_array::serde::SerializeOptions; - use vortex_array::serde::SerializedArray; - use vortex_array::session::ArraySessionExt; - use vortex_buffer::ByteBufferMut; - use vortex_fastlanes::BitPacked; - use vortex_session::registry::ReadContext; - - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: Some(123), - num_rounds: 5, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - let (tq_child, _norms) = unwrap_l2denorm(&encoded); - - let dtype = tq_child.dtype().clone(); - let len = tq_child.len(); - - // Serialize the TQ child. - let array_ctx = ArrayContext::empty(); - let serde_session = VortexSession::empty().with::(); - serde_session.arrays().register(TurboQuant); - let serialized = - tq_child.serialize(&array_ctx, &serde_session, &SerializeOptions::default())?; - - let mut concat = ByteBufferMut::empty(); - for buf in serialized { - concat.extend_from_slice(buf.as_ref()); - } - - // Deserialize. The session needs TurboQuant and BitPacked (for rotation signs) registered. - serde_session.arrays().register(BitPacked); - - let parts = SerializedArray::try_from(concat.freeze())?; - let decoded = parts.decode( - &dtype, - len, - &ReadContext::new(array_ctx.to_ids()), - &serde_session, - )?; - - assert!( - decoded.array_eq(&tq_child, Precision::Value), - "serde roundtrip did not preserve array equality" - ); - Ok(()) -} - -/// Verify that a degenerate (empty) TurboQuant array survives serialize/deserialize. -#[test] -fn serde_roundtrip_empty() -> VortexResult<()> { - use vortex_array::ArrayContext; - use vortex_array::ArrayEq; - use vortex_array::Precision; - use vortex_array::serde::SerializeOptions; - use vortex_array::serde::SerializedArray; - use vortex_array::session::ArraySessionExt; - use vortex_buffer::ByteBufferMut; - use vortex_fastlanes::BitPacked; - use vortex_session::registry::ReadContext; - - let fsl = make_fsl(0, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: Some(123), - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; - assert_eq!(encoded.len(), 0); - let (tq_child, _norms) = unwrap_l2denorm(&encoded); - - let dtype = tq_child.dtype().clone(); - let len = tq_child.len(); - - let serde_session = VortexSession::empty().with::(); - serde_session.arrays().register(TurboQuant); - serde_session.arrays().register(BitPacked); - - let array_ctx = ArrayContext::empty(); - let serialized = - tq_child.serialize(&array_ctx, &serde_session, &SerializeOptions::default())?; - - let mut concat = ByteBufferMut::empty(); - for buf in serialized { - concat.extend_from_slice(buf.as_ref()); - } - - let parts = SerializedArray::try_from(concat.freeze())?; - let decoded = parts.decode( - &dtype, - len, - &ReadContext::new(array_ctx.to_ids()), - &serde_session, - )?; - - assert!( - decoded.array_eq(&tq_child, Precision::Value), - "serde roundtrip did not preserve array equality" - ); - Ok(()) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs new file mode 100644 index 00000000000..ac0389048f4 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_error::VortexResult; + +use super::*; +use crate::scalar_fns::cosine_similarity::CosineSimilarity; +use crate::scalar_fns::l2_norm::L2Norm; + +fn execute_l2_norm( + input: ArrayRef, + len: usize, + ctx: &mut vortex_array::ExecutionCtx, +) -> VortexResult { + L2Norm::try_new_array(input, len)?.into_array().execute(ctx) +} + +fn execute_cosine_similarity( + lhs: ArrayRef, + rhs: ArrayRef, + len: usize, + ctx: &mut vortex_array::ExecutionCtx, +) -> VortexResult { + CosineSimilarity::try_new_array(lhs, rhs, len)? + .into_array() + .execute(ctx) +} + +#[test] +fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 4, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let full_fsl = full_decoded + .storage_array() + .clone() + .execute::(&mut ctx)?; + let expected = full_fsl.slice(5..10)?; + let expected_fsl = expected.execute::(&mut ctx)?; + let expected_elements = expected_fsl + .elements() + .clone() + .execute::(&mut ctx)?; + + // Slice then decompress. + let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let sliced_fsl = sliced_decoded + .storage_array() + .clone() + .execute::(&mut ctx)?; + let actual_elements = sliced_fsl + .elements() + .clone() + .execute::(&mut ctx)?; + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) +} + +#[test] +fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 2, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) +} + +#[test] +fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 5, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); + + // Stored norms should match the actual L2 norms of the input. + let norms_prim = norms_child.execute::(&mut ctx)?; + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().clone().execute::(&mut ctx)?; + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + + // Also verify L2Norm readthrough shortcut works. + let norms = execute_l2_norm(encoded, 10, &mut ctx)?; + assert_eq!(norms.as_slice::(), stored_norms); + assert_eq!(norms.len(), 10); + Ok(()) +} + +#[test] +fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> { + let num_rows = 12; + let fsl = make_fsl(num_rows, 128, 7); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 1, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); + + let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; + let encoded_norms = execute_l2_norm(encoded.clone(), num_rows, &mut ctx)?; + assert_eq!( + encoded_norms.as_slice::(), + stored_norms.as_slice::() + ); + + let decoded = encoded.execute::(&mut ctx)?.into_array(); + let decoded_norms = execute_l2_norm(decoded, num_rows, &mut ctx)?; + let max_gap = stored_norms + .as_slice::() + .iter() + .zip(decoded_norms.as_slice::().iter()) + .map(|(&stored, &decoded)| (stored - decoded).abs()) + .fold(0.0f32, f32::max); + + assert!( + max_gap > 1e-3, + "expected at least one decoded norm to drift from the authoritative stored norms, got max gap {max_gap:.6}", + ); + Ok(()) +} + +#[test] +fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> { + let num_rows = 12; + let fsl = make_fsl(num_rows, 128, 11); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 1, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let encoded_cos = + execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; + let decoded = encoded.execute::(&mut ctx)?.into_array(); + let decoded_cos = execute_cosine_similarity(decoded.clone(), decoded, num_rows, &mut ctx)?; + + let decoded_values = decoded_cos.as_slice::(); + assert!( + decoded_values + .iter() + .all(|&value| (value - 1.0).abs() < 1e-5), + "decoded cosine(x, x) should stay at 1.0", + ); + + let max_gap = encoded_cos + .as_slice::() + .iter() + .zip(decoded_values.iter()) + .map(|(&encoded, &decoded)| (encoded - decoded).abs()) + .fold(0.0f32, f32::max); + assert!( + max_gap > 1e-3, + "expected encoded cosine readthrough to differ from decoded recomputation, got max gap {max_gap:.6}", + ); + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs new file mode 100644 index 00000000000..4c01affa27d --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests for TurboQuant encoding with decomposed SorfTransform + DictArray tree. + +mod compute; +mod nullable; +mod roundtrip; +mod structural; + +use std::sync::LazyLock; + +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand_distr::Distribution; +use rand_distr::Normal; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Dict; +use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnVTable; +use vortex_array::arrays::dict::DictArraySlotsExt; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_session::VortexSession; + +use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::vector::Vector; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +/// Create a FixedSizeListArray of random f32 vectors with the given validity. +fn make_fsl_with_validity( + num_rows: usize, + dim: usize, + seed: u64, + validity: Validity, +) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + validity, + num_rows, + ) + .unwrap() +} + +/// Create a non-nullable FixedSizeListArray of random f32 vectors. +fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + make_fsl_with_validity(num_rows, dim, seed, Validity::NonNullable) +} + +/// Wrap a `FixedSizeListArray` in a `Vector` extension array. +fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .unwrap() + .erased(); + ExtensionArray::new(ext_dtype, fsl.clone().into_array()) +} + +/// Full encode pipeline: normalize → TQ-encode → wrap in L2Denorm. +fn normalize_and_encode( + ext: &ExtensionArray, + config: &TurboQuantConfig, + ctx: &mut vortex_array::ExecutionCtx, +) -> VortexResult { + let l2_denorm = normalize_as_l2_denorm(ext.as_ref().clone(), ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + let num_rows = l2_denorm.len(); + + let normalized_ext = normalized + .as_opt::() + .vortex_expect("normalized child should be an Extension array"); + // SAFETY: We just normalized the input via `normalize_as_l2_denorm`. + let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? }; + + Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) +} + +/// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). +fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { + let sfn = encoded + .as_opt::() + .expect("expected ScalarFnArray (L2Denorm)"); + let sorf_child = sfn.child_at(0).clone(); + let norms_child = sfn.child_at(1).clone(); + (sorf_child, norms_child) +} + +/// Unwrap a SorfTransform ScalarFnArray to get the FSL(Dict) child. +fn unwrap_sorf(sorf: &ArrayRef) -> ArrayRef { + let sfn = sorf + .as_opt::() + .expect("expected ScalarFnArray (SorfTransform)"); + sfn.child_at(0).clone() +} + +/// Navigate the full tree to get (codes, centroids, norms) as flat arrays. +fn unwrap_codes_centroids_norms( + encoded: &ArrayRef, + ctx: &mut vortex_array::ExecutionCtx, +) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { + let (sorf_child, norms_child) = unwrap_l2denorm(encoded); + let padded_vector_child = unwrap_sorf(&sorf_child); + + // Vector wrapping FSL(Dict(codes, centroids)) + let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; + let fsl: FixedSizeListArray = padded_vector.storage_array().clone().execute(ctx)?; + let dict = fsl + .elements() + .as_opt::() + .vortex_expect("FSL elements should be a DictArray"); + let codes: PrimitiveArray = dict.codes().clone().execute(ctx)?; + let centroids: PrimitiveArray = dict.values().clone().execute(ctx)?; + let norms: PrimitiveArray = norms_child.execute(ctx)?; + + Ok((codes, centroids, norms)) +} + +fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) +} + +fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, +) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 +} + +/// Normalize, encode, and decode, returning (original, decoded) flat f32 slices. +fn encode_decode( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult<(Vec, Vec)> { + let mut ctx = SESSION.create_execution_ctx(); + let original: Vec = { + let prim = fsl.elements().clone().execute::(&mut ctx)?; + prim.as_slice::().to_vec() + }; + let ext = make_vector_ext(fsl); + let encoded = normalize_and_encode(&ext, config, &mut ctx)?; + let decoded_ext = encoded.execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .clone() + .execute::(&mut ctx)?; + let decoded_elements: Vec = { + let prim = decoded_fsl + .elements() + .clone() + .execute::(&mut ctx)?; + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) +} + +fn make_fsl_small(dim: usize) -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(dim); + for i in 0..dim { + buf.push(i as f32 + 1.0); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + 1, + ) + .unwrap() +} diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs new file mode 100644 index 00000000000..8d406239019 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::*; + +/// Encode a nullable Vector array and verify roundtrip preserves validity and non-null values. +#[test] +fn nullable_vectors_roundtrip() -> VortexResult<()> { + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 4, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + assert_eq!(encoded.len(), 10); + assert!(encoded.dtype().is_nullable()); + + let encoded_validity = encoded.validity()?; + for i in 0..10 { + let expected = ![2, 5, 7].contains(&i); + assert_eq!( + encoded_validity.is_valid(i)?, + expected, + "validity mismatch at row {i}" + ); + } + + let decoded_ext = encoded.execute::(&mut ctx)?; + assert_eq!(decoded_ext.len(), 10); + + let decoded_fsl = decoded_ext + .storage_array() + .clone() + .execute::(&mut ctx)?; + let decoded_prim = decoded_fsl + .elements() + .clone() + .execute::(&mut ctx)?; + let decoded_f32 = decoded_prim.as_slice::(); + + let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; + let orig_f32 = orig_prim.as_slice::(); + + for row in [0, 1, 3, 4, 6, 8, 9] { + let orig_vec = &orig_f32[row * 128..(row + 1) * 128]; + let dec_vec = &decoded_f32[row * 128..(row + 1) * 128]; + let norm_sq: f32 = orig_vec.iter().map(|&v| v * v).sum(); + let err_sq: f32 = orig_vec + .iter() + .zip(dec_vec.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + assert!( + err_sq / norm_sq < 0.1, + "non-null row {row} has excessive reconstruction error" + ); + } + Ok(()) +} + +/// Verify that norms carry the validity: null vectors have null norms. +#[test] +fn nullable_norms_match_validity() -> VortexResult<()> { + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); + + let norms_validity = norms_child.validity()?; + for i in 0..5 { + let expected = i % 2 == 0; + assert_eq!( + norms_validity.is_valid(i)?, + expected, + "norms validity mismatch at row {i}" + ); + } + Ok(()) +} + +/// Verify that L2Norm readthrough works correctly on nullable TurboQuant arrays. +#[test] +fn nullable_l2_norm_readthrough() -> VortexResult<()> { + use crate::scalar_fns::l2_norm::L2Norm; + + let validity = Validity::from_iter([true, false, true, false, true]); + let fsl = make_fsl_with_validity(5, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let norm_sfn = L2Norm::try_new_array(encoded, 5)?; + let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; + + let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; + let orig_f32 = orig_prim.as_slice::(); + for row in 0..5 { + if row % 2 == 0 { + assert!(norms.is_valid(row)?, "row {row} should be valid"); + let expected: f32 = orig_f32[row * 128..(row + 1) * 128] + .iter() + .map(|&v| v * v) + .sum::() + .sqrt(); + let actual = norms.as_slice::()[row]; + assert!( + (actual - expected).abs() < 1e-5, + "norm mismatch at valid row {row}: actual={actual}, expected={expected}" + ); + } else { + assert!(!norms.is_valid(row)?, "row {row} should be null"); + } + } + Ok(()) +} + +/// Verify that slicing a nullable TurboQuant array preserves validity. +#[test] +fn nullable_slice_preserves_validity() -> VortexResult<()> { + let validity = Validity::from_iter([ + true, true, false, true, true, false, true, false, true, true, + ]); + let fsl = make_fsl_with_validity(10, 128, 42, validity); + let ext = make_vector_ext(&fsl); + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 2, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let sliced = encoded.slice(1..6)?; + assert_eq!(sliced.len(), 5); + + let sliced_validity = sliced.validity()?; + let expected = [true, false, true, true, false]; + for (i, &exp) in expected.iter().enumerate() { + assert_eq!( + sliced_validity.is_valid(i)?, + exp, + "sliced validity mismatch at index {i}" + ); + } + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs new file mode 100644 index 00000000000..cd61d9193da --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use super::*; + +#[rstest] +#[case(128, 1)] +#[case(128, 2)] +#[case(128, 3)] +#[case(128, 4)] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 2)] +fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) +} + +#[rstest] +#[case(128, 1)] +#[case(128, 2)] +#[case(128, 3)] +#[case(128, 4)] +#[case(256, 2)] +#[case(256, 4)] +fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) +} + +#[rstest] +#[case(128, 6)] +#[case(128, 8)] +#[case(256, 6)] +#[case(256, 8)] +fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + num_rounds: 3, + }; + let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) +} + +#[test] +fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) +} + +#[rstest] +#[case(0)] +#[case(1)] +fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) +} + +#[rstest] +#[case(1)] +#[case(64)] +#[case(127)] +fn rejects_dimension_below_128(#[case] dim: usize) { + let fsl = make_fsl_small(dim); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + assert!( + crate::encodings::turboquant::turboquant_encode(ext.as_view(), &config, &mut ctx).is_err() + ); +} + +#[rstest] +#[case(0)] +#[case(9)] +fn rejects_invalid_bit_width(#[case] bit_width: u8) { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width, + seed: Some(0), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx) + .unwrap() + .child_at(0) + .clone(); + let normalized_ext = normalized + .as_opt::() + .expect("normalized child should be Extension"); + assert!(unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx) }.is_err()); +} + +#[test] +fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into() + .expect("somehow got dimension greater than u32::MAX"), + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) +} + +/// Roundtrip at large embedding dimensions. +#[rstest] +#[case(768, 4)] +#[case(1024, 5)] +fn large_dimension_roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 10; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + num_rounds: 3, + }; + let (original, decoded) = encode_decode(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + // 2x slack for the SRHT-vs-Haar gap. + let bound = 2.0 * theoretical_mse_bound(bit_width); + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) +} + +/// Verify that f64 input is accepted and encoded. +#[test] +fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().unwrap(), + Validity::NonNullable, + num_rows, + )?; + + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); + assert_eq!(norms_child.len(), num_rows); + Ok(()) +} + +/// Verify that f16 input is accepted and encoded. +#[test] +fn f16_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(half::f16::from_f32(normal.sample(&mut rng))); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().unwrap(), + Validity::NonNullable, + num_rows, + )?; + + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); + assert_eq!(norms_child.len(), num_rows); + + let decoded_ext = encoded.execute::(&mut ctx)?; + let decoded_fsl = decoded_ext + .storage_array() + .clone() + .execute::(&mut ctx)?; + assert_eq!(decoded_fsl.len(), num_rows); + Ok(()) +} + +/// Verify that the checked encode accepts normalized f16 input. +#[test] +fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(half::f16::from_f32(normal.sample(&mut rng))); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().unwrap(), + Validity::NonNullable, + num_rows, + )?; + + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + num_rounds: 3, + }; + + let mut ctx = SESSION.create_execution_ctx(); + let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx)? + .child_at(0) + .clone(); + let normalized_ext = normalized + .as_opt::() + .vortex_expect("normalized child should be an Extension array"); + + let encoded = + crate::encodings::turboquant::turboquant_encode(normalized_ext, &config, &mut ctx)?; + assert_eq!(encoded.len(), num_rows); + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs new file mode 100644 index 00000000000..87b59836b38 --- /dev/null +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -0,0 +1,333 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests that verify the internal structure of the encoded tree. + +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_error::VortexResult; + +use super::*; + +/// Verify that the centroids stored in the DictArray match what `get_centroids()` computes. +#[test] +fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; + let stored = centroids.as_slice::(); + + // padded_dim for dim=128 is 128. + let computed = crate::encodings::turboquant::centroids::get_centroids(128, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) +} + +/// Verify that the rotation is deterministic from seed by checking decode output. +#[test] +fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 4, + }; + + // Encode twice with the same seed → should produce identical results. + let mut ctx = SESSION.create_execution_ctx(); + let encoded1 = normalize_and_encode(&ext, &config, &mut ctx)?; + let decoded1 = encoded1.execute::(&mut ctx)?; + let fsl1 = decoded1 + .storage_array() + .clone() + .execute::(&mut ctx)?; + let elems1 = fsl1 + .elements() + .clone() + .execute::(&mut ctx)?; + + let mut ctx = SESSION.create_execution_ctx(); + let encoded2 = normalize_and_encode(&ext, &config, &mut ctx)?; + let decoded2 = encoded2.execute::(&mut ctx)?; + let fsl2 = decoded2 + .storage_array() + .clone() + .execute::(&mut ctx)?; + let elems2 = fsl2 + .elements() + .clone() + .execute::(&mut ctx)?; + + assert_eq!( + elems1.as_slice::(), + elems2.as_slice::(), + "Two encodes with same seed should produce identical decode output" + ); + Ok(()) +} + +/// Verify that the encoded array's dtype is a Vector extension type. +#[test] +fn encoded_dtype_is_vector_extension() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + num_rounds: 2, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + assert!( + encoded.dtype().is_extension(), + "TurboQuant dtype should be an extension type, got {}", + encoded.dtype() + ); + assert!( + encoded.dtype().as_extension().is::(), + "TurboQuant dtype should be a Vector extension type" + ); + Ok(()) +} + +/// Verify approximate cosine similarity in the quantized domain. +#[test] +fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let input_prim = fsl.elements().clone().execute::(&mut ctx)?; + let input_f32 = input_prim.as_slice::(); + + // Navigate tree to get codes, centroids, norms. + let (codes_prim, centroids_prim, norms_prim) = + unwrap_codes_centroids_norms(&encoded, &mut ctx)?; + let all_codes = codes_prim.as_slice::(); + let centroid_vals = centroids_prim.as_slice::(); + let norms = norms_prim.as_slice::(); + + // padded_dim for dim=128. + let pd = 128usize; + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { + 0.0 + } else { + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum::() + }; + + let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) +} + +/// Verify approximate dot product in the quantized domain. +#[test] +fn dot_product_quantized_accuracy() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let ext = make_vector_ext(&fsl); + let config = TurboQuantConfig { + bit_width: 8, + seed: Some(123), + num_rounds: 3, + }; + let mut ctx = SESSION.create_execution_ctx(); + let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + + let input_prim = fsl.elements().clone().execute::(&mut ctx)?; + let input_f32 = input_prim.as_slice::(); + + let (codes_prim, centroids_prim, norms_prim) = + unwrap_codes_centroids_norms(&encoded, &mut ctx)?; + let all_codes = codes_prim.as_slice::(); + let centroid_vals = centroids_prim.as_slice::(); + let norms = norms_prim.as_slice::(); + + let pd = 128usize; + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let exact_dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); + + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + let unit_dot: f32 = codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) + .sum(); + let approx_dot = norms[row_a] * norms[row_b] * unit_dot; + + let scale = exact_dot.abs().max(1.0); + let rel_error = (exact_dot - approx_dot).abs() / scale; + assert!( + rel_error < 0.15, + "dot product error too large for ({row_a}, {row_b}): \ + exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}" + ); + } + Ok(()) +} + +/// Verify SorfTransform in isolation: manually forward-rotate known data, wrap in +/// FSL(Dict), execute SorfTransform, and check inverse rotation recovers the original. +#[test] +#[expect( + clippy::cast_possible_truncation, + reason = "test uses known small dimensions" +)] +fn sorf_transform_roundtrip_isolation() -> VortexResult<()> { + use vortex_array::IntoArray; + use vortex_array::arrays::dict::DictArray; + use vortex_array::dtype::extension::ExtDType; + use vortex_array::extension::EmptyMetadata; + use vortex_array::validity::Validity; + use vortex_buffer::BufferMut; + + use crate::encodings::turboquant::centroids::compute_centroid_boundaries; + use crate::encodings::turboquant::centroids::find_nearest_centroid; + use crate::encodings::turboquant::centroids::get_centroids; + use crate::scalar_fns::sorf_transform::SorfMatrix; + use crate::scalar_fns::sorf_transform::SorfOptions; + use crate::scalar_fns::sorf_transform::SorfTransform; + use crate::vector::Vector; + + let dim = 128usize; + let seed = 99u64; + let num_rounds = 3u8; + let num_rows = 5; + + // Build a known input: simple increasing values, then normalize each row to unit norm. + let mut input_f32 = vec![0.0f32; num_rows * dim]; + for row in 0..num_rows { + let mut norm_sq = 0.0f32; + for i in 0..dim { + let val = ((row * dim + i) as f32 + 1.0) * 0.01; + input_f32[row * dim + i] = val; + norm_sq += val * val; + } + let norm = norm_sq.sqrt(); + for i in 0..dim { + input_f32[row * dim + i] /= norm; + } + } + + // Forward transform + quantize (mimicking what turboquant_quantize_core does). + let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?; + let padded_dim = rotation.padded_dim(); + let centroids = get_centroids(padded_dim as u32, 8)?; + let boundaries = compute_centroid_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + padded[..dim].copy_from_slice(&input_f32[row * dim..(row + 1) * dim]); + padded[dim..].fill(0.0); + rotation.rotate(&padded, &mut rotated); + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + // Build FSL(Dict(codes, centroids)). + let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); + let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); + centroids_buf.extend_from_slice(¢roids); + let centroids_arr = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + let dict = DictArray::try_new(codes.into_array(), centroids_arr.into_array())?; + let fsl = FixedSizeListArray::try_new( + dict.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )?; + + // Wrap the padded FSL in a Vector extension so it can be the SorfTransform child. + let padded_vector_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let padded_vector = ExtensionArray::new(padded_vector_dtype, fsl.into_array()); + + // Wrap in SorfTransform and execute. + let sorf_options = SorfOptions { + seed, + num_rounds, + dimension: dim as u32, + element_ptype: vortex_array::dtype::PType::F32, + }; + let sorf_array = + SorfTransform::try_new_array(&sorf_options, padded_vector.into_array(), num_rows)?; + + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf_array.into_array().execute(&mut ctx)?; + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; + let result_f32 = result_prim.as_slice::(); + + assert_eq!(result_f32.len(), num_rows * dim); + + // At 8-bit quantization, reconstruction should be very close to input. + for row in 0..num_rows { + let orig = &input_f32[row * dim..(row + 1) * dim]; + let recon = &result_f32[row * dim..(row + 1) * dim]; + let err_sq: f32 = orig + .iter() + .zip(recon) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + assert!( + err_sq / norm_sq < 1e-3, + "SorfTransform isolation: row {row} MSE too high: {:.6}", + err_sq / norm_sq + ); + } + Ok(()) +} diff --git a/vortex-tensor/src/encodings/turboquant/vtable.rs b/vortex-tensor/src/encodings/turboquant/vtable.rs deleted file mode 100644 index b673dea6ba5..00000000000 --- a/vortex-tensor/src/encodings/turboquant/vtable.rs +++ /dev/null @@ -1,382 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! VTable implementation for TurboQuant encoding. - -use std::hash::Hash; -use std::hash::Hasher; -use std::sync::Arc; - -use prost::Message; -use vortex_array::Array; -use vortex_array::ArrayEq; -use vortex_array::ArrayHash; -use vortex_array::ArrayId; -use vortex_array::ArrayParts; -use vortex_array::ArrayRef; -use vortex_array::ArrayView; -use vortex_array::ExecutionCtx; -use vortex_array::ExecutionResult; -use vortex_array::Precision; -use vortex_array::buffer::BufferHandle; -use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::serde::ArrayChildren; -use vortex_array::validity::Validity; -use vortex_array::vtable::VTable; -use vortex_array::vtable::ValidityVTable; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_ensure_eq; -use vortex_error::vortex_err; -use vortex_error::vortex_panic; -use vortex_session::VortexSession; - -use crate::encodings::turboquant::TurboQuantData; -use crate::encodings::turboquant::array::slots::Slot; -use crate::encodings::turboquant::compute::rules::PARENT_KERNELS; -use crate::encodings::turboquant::compute::rules::RULES; -use crate::encodings::turboquant::metadata::TurboQuantMetadata; -use crate::encodings::turboquant::scheme::decompress::execute_decompress; -use crate::vector::AnyVector; -use crate::vector::VectorMatcherMetadata; - -/// Encoding marker type for TurboQuant. -#[derive(Clone, Debug)] -pub struct TurboQuant; - -impl TurboQuant { - pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); - - /// Minimum vector dimension for TurboQuant encoding. - /// - /// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total - /// amount of distortion. - pub const MIN_DIMENSION: u32 = 128; - - /// Maximum supported number of bits per quantized coordinate. - pub const MAX_BIT_WIDTH: u8 = 8; - - /// Maximum supported number of centroids in the scalar quantizer codebook. - pub const MAX_CENTROIDS: usize = 1usize << (Self::MAX_BIT_WIDTH as usize); - - /// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with - /// dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION). - /// - /// Returns the validated vector metadata on success. - pub fn validate_dtype(dtype: &DType) -> VortexResult { - let vector_metadata = dtype - .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") - })?; - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= Self::MIN_DIMENSION, - "TurboQuant requires dimension >= {}, got {dimensions}", - Self::MIN_DIMENSION - ); - - Ok(vector_metadata) - } - - /// Creates a new [`TurboQuantArray`]. - /// - /// The `dtype` must be a non-nullable [`Vector`](crate::vector::Vector) extension type. - /// Nullability is handled externally by the [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm) - /// ScalarFnArray wrapper. - /// - /// Internally calls [`TurboQuantData::validate`] and [`TurboQuantData::try_new`], then - /// delegates to [`new_array_unchecked`](Self::new_array_unchecked). - pub fn try_new_array( - dtype: DType, - codes: ArrayRef, - centroids: ArrayRef, - rotation_signs: ArrayRef, - ) -> VortexResult { - TurboQuantData::validate(&dtype, &codes, ¢roids, &rotation_signs)?; - - Ok(unsafe { Self::new_array_unchecked(dtype, codes, centroids, rotation_signs) }) - } - - /// Creates a new [`TurboQuantArray`] without validation. - /// - /// # Safety - /// - /// The caller must ensure all invariants required by [`TurboQuantData::validate`] hold: - /// - /// - `dtype` is a non-nullable [`Vector`](crate::vector::Vector) extension type with - /// dimension >= [`MIN_DIMENSION`](Self::MIN_DIMENSION). - /// - `codes` is a non-nullable `FixedSizeList` with `list_size == padded_dim`. - /// - `centroids` is a non-nullable `Primitive` with a power-of-2 length in - /// `[2, MAX_CENTROIDS]` (or empty for degenerate arrays). - /// - `rotation_signs` is a non-nullable `FixedSizeList` with `list_size == padded_dim`. - /// - /// Violating these invariants may produce incorrect results during decompression or panics - /// during array access. - pub unsafe fn new_array_unchecked( - dtype: DType, - codes: ArrayRef, - centroids: ArrayRef, - rotation_signs: ArrayRef, - ) -> TurboQuantArray { - #[cfg(debug_assertions)] - TurboQuantData::validate(&dtype, &codes, ¢roids, &rotation_signs) - .vortex_expect("[DEBUG ASSERTION]: TurboQuantData arrays are invalid"); - - let len = codes.len(); - - let dimension = dtype - .as_extension_opt() - .vortex_expect("we validated the dtype") - .metadata_opt::() - .vortex_expect("we validated that this is a vector") - .dimensions(); - - let bit_width = if centroids.is_empty() { - 0 - } else { - #[expect( - clippy::cast_possible_truncation, - reason = "bit_width is guaranteed <= 8" - )] - (centroids.len().trailing_zeros() as u8) - }; - - #[expect( - clippy::cast_possible_truncation, - reason = "num_rounds fits in u8 by the caller's invariants" - )] - let num_rounds = rotation_signs.len() as u8; - - // SAFETY: The caller guarantees that dimension, bit_width, and num_rounds satisfy the - // invariants documented on `TurboQuantData::new_unchecked`. - let data = unsafe { TurboQuantData::new_unchecked(dimension, bit_width, num_rounds) }; - let parts = ArrayParts::new(TurboQuant, dtype, len, data) - .with_slots(TurboQuantData::make_slots(codes, centroids, rotation_signs)); - - // SAFETY: The caller guarantees the parts are logically consistent. - unsafe { Array::from_parts_unchecked(parts) } - } -} - -/// A [`TurboQuant`]-encoded Vortex array. -pub type TurboQuantArray = Array; - -impl VTable for TurboQuant { - type ArrayData = TurboQuantData; - type OperationsVTable = TurboQuant; - type ValidityVTable = TurboQuant; - - fn id(&self) -> ArrayId { - Self::ID - } - - fn validate( - &self, - data: &Self::ArrayData, - dtype: &DType, - len: usize, - slots: &[Option], - ) -> VortexResult<()> { - vortex_ensure_eq!( - slots.len(), - Slot::COUNT, - "TurboQuantArray got incorrect amount of slots", - ); - - // Even if the array is degenerate (empty), the arrays still have to exist - // (they will be empty). - let codes = slots[Slot::Codes as usize] - .as_ref() - .ok_or_else(|| vortex_err!("TurboQuantArray missing codes slot"))?; - let centroids = slots[Slot::Centroids as usize] - .as_ref() - .ok_or_else(|| vortex_err!("TurboQuantArray missing centroids slot"))?; - let rotation_signs = slots[Slot::RotationSigns as usize] - .as_ref() - .ok_or_else(|| vortex_err!("TurboQuantArray missing rotation_signs slot"))?; - - vortex_ensure_eq!( - codes.len(), - len, - "TurboQuant codes length does not match outer length", - ); - - TurboQuantData::validate(dtype, codes, centroids, rotation_signs)?; - - vortex_ensure_eq!(data.dimension, Self::validate_dtype(dtype)?.dimensions()); - - let expected_bit_width = if centroids.is_empty() { - 0 - } else { - u8::try_from(centroids.len().trailing_zeros()) - .map_err(|_| vortex_err!("centroids bit_width does not fit in u8"))? - }; - vortex_ensure_eq!( - data.bit_width, - expected_bit_width, - "TurboQuant bit_width does not match centroids slot", - ); - - // Verify num_rounds matches the rotation_signs FSL length. - let expected_num_rounds = u8::try_from(rotation_signs.len()) - .map_err(|_| vortex_err!("rotation_signs num_rounds does not fit in u8"))?; - vortex_ensure_eq!( - data.num_rounds, - expected_num_rounds, - "TurboQuant num_rounds does not match rotation_signs slot", - ); - - Ok(()) - } - - fn nbuffers(_array: ArrayView) -> usize { - 0 - } - - fn buffer(_array: ArrayView, idx: usize) -> BufferHandle { - vortex_panic!("TurboQuantArray buffer index {idx} out of bounds") - } - - fn buffer_name(_array: ArrayView, _idx: usize) -> Option { - None - } - - fn serialize( - array: ArrayView<'_, Self>, - _session: &VortexSession, - ) -> VortexResult>> { - Ok(Some( - TurboQuantMetadata::new(array.bit_width, array.num_rounds).encode_to_vec(), - )) - } - - fn deserialize( - &self, - dtype: &DType, - len: usize, - metadata: &[u8], - _buffers: &[BufferHandle], - children: &dyn ArrayChildren, - _session: &VortexSession, - ) -> VortexResult> { - let metadata = TurboQuantMetadata::decode(metadata)?; - let bit_width = metadata.bit_width()?; - let num_rounds = metadata.num_rounds()?; - - // bit_width == 0 and num_rounds == 0 are only valid for degenerate (empty) arrays. - vortex_ensure!( - bit_width > 0 || len == 0, - "bit_width == 0 is only valid for empty arrays, got len={len}" - ); - vortex_ensure!( - num_rounds > 0 || len == 0, - "num_rounds == 0 is only valid for empty arrays, got len={len}" - ); - - // Validate and derive dimension from the Vector extension dtype. - let vector_metadata = TurboQuant::validate_dtype(dtype)?; - let dimensions = vector_metadata.dimensions(); - - // TurboQuant arrays are always non-nullable. - vortex_ensure!( - !dtype.is_nullable(), - "TurboQuant dtype must be non-nullable during deserialization" - ); - - let padded_dim = dimensions.next_power_of_two(); - - // Get the codes array (indices into the codebook). Codes are always non-nullable. - let codes_ptype = DType::Primitive(PType::U8, Nullability::NonNullable); - let codes_dtype = - DType::FixedSizeList(Arc::new(codes_ptype), padded_dim, Nullability::NonNullable); - let codes_array = children.get(0, &codes_dtype, len)?; - - // Get the centroids array (codebook). - let num_centroids = if bit_width == 0 { - 0 // A degenerate TQ array. - } else { - 1usize << bit_width - }; - let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let centroids = children.get(1, ¢roids_dtype, num_centroids)?; - - // Get the rotation signs array (FixedSizeList with list_size = padded_dim). - let signs_len = if len == 0 { 0 } else { num_rounds as usize }; - let signs_dtype = DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - Nullability::NonNullable, - ); - let rotation_signs = children.get(2, &signs_dtype, signs_len)?; - - Ok(ArrayParts::new( - TurboQuant, - dtype.clone(), - len, - TurboQuantData { - dimension: dimensions, - bit_width, - num_rounds, - }, - ) - .with_slots(TurboQuantData::make_slots( - codes_array, - centroids, - rotation_signs, - ))) - } - - fn slot_name(_array: ArrayView, idx: usize) -> String { - Slot::from_index(idx).name().to_string() - } - fn execute(array: Array, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionResult::done(execute_decompress(array, ctx)?)) - } - - fn execute_parent( - array: ArrayView, - parent: &ArrayRef, - child_idx: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - PARENT_KERNELS.execute(array, parent, child_idx, ctx) - } - - fn reduce_parent( - array: ArrayView, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - RULES.evaluate(array, parent, child_idx) - } -} - -impl ValidityVTable for TurboQuant { - fn validity(_array: ArrayView<'_, TurboQuant>) -> VortexResult { - // TurboQuant arrays are always non-nullable. This method is only called when the dtype is - // nullable, which should never happen for TQ arrays. - Ok(Validity::NonNullable) - } -} - -impl ArrayHash for TurboQuantData { - fn array_hash(&self, state: &mut H, _precision: Precision) { - self.dimension.hash(state); - self.bit_width.hash(state); - self.num_rounds.hash(state); - } -} - -impl ArrayEq for TurboQuantData { - fn array_eq(&self, other: &Self, _precision: Precision) -> bool { - self.dimension == other.dimension - && self.bit_width == other.bit_width - && self.num_rounds == other.num_rounds - } -} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index a4577f6d262..3d3563aa8e4 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -7,15 +7,14 @@ use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; -use vortex_array::session::ArraySessionExt; use vortex_session::VortexSession; -use crate::encodings::turboquant::TurboQuant; use crate::fixed_shape::FixedShapeTensor; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_norm::L2Norm; +use crate::scalar_fns::sorf_transform::SorfTransform; use crate::vector::Vector; pub mod matcher; @@ -33,10 +32,9 @@ pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); session.dtypes().register(FixedShapeTensor); - session.arrays().register(TurboQuant); - session.scalar_fns().register(CosineSimilarity); session.scalar_fns().register(InnerProduct); session.scalar_fns().register(L2Denorm); session.scalar_fns().register(L2Norm); + session.scalar_fns().register(SorfTransform); } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 39c8b135e43..f4f129c2b27 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -20,6 +20,7 @@ use vortex_array::expr::and; use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::EmptyOptions; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; @@ -29,7 +30,6 @@ use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; -use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_norm::L2Norm; @@ -45,16 +45,20 @@ use crate::utils::validate_tensor_float_input; /// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the /// same dtype and a float element type. The output is a float column of the same float type. /// +/// When either input is wrapped in [`L2Denorm`], this operator treats the stored norms and +/// normalized children as authoritative. For lossy encodings such as TurboQuant, that means the +/// optimized readthrough path may intentionally differ slightly from decoding both sides to dense +/// coordinates and recomputing cosine from scratch. +/// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor /// [`Vector`]: crate::vector::Vector #[derive(Clone)] pub struct CosineSimilarity; impl CosineSimilarity { - /// Creates a new [`ScalarFn`] wrapping the cosine similarity operation with the given - /// [`ApproxOptions`] controlling approximation behavior. - pub fn new(options: &ApproxOptions) -> ScalarFn { - ScalarFn::new(CosineSimilarity, options.clone()) + /// Creates a new [`ScalarFn`] wrapping the cosine similarity operation. + pub fn new() -> ScalarFn { + ScalarFn::new(CosineSimilarity, EmptyOptions) } /// Constructs a [`ScalarFnArray`] that lazily computes the cosine similarity between `lhs` and @@ -64,18 +68,13 @@ impl CosineSimilarity { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array( - options: &ApproxOptions, - lhs: ArrayRef, - rhs: ArrayRef, - len: usize, - ) -> VortexResult { - ScalarFnArray::try_new(CosineSimilarity::new(options).erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { + ScalarFnArray::try_new(CosineSimilarity::new().erased(), vec![lhs, rhs], len) } } impl ScalarFnVTable for CosineSimilarity { - type Options = ApproxOptions; + type Options = EmptyOptions; fn id(&self) -> ScalarFnId { ScalarFnId::from("vortex.tensor.cosine_similarity") @@ -126,7 +125,7 @@ impl ScalarFnVTable for CosineSimilarity { fn execute( &self, - options: &Self::Options, + _options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { @@ -140,12 +139,12 @@ impl ScalarFnVTable for CosineSimilarity { let rhs_is_denorm = rhs_ref.is::>(); if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); } else if lhs_is_denorm || rhs_is_denorm { if rhs_is_denorm { (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); } - return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); } } @@ -153,9 +152,9 @@ impl ScalarFnVTable for CosineSimilarity { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; // Compute inner product and norms as columnar operations, and propagate the options. - let norm_lhs_arr = L2Norm::try_new_array(options, lhs_ref.clone(), len)?; - let norm_rhs_arr = L2Norm::try_new_array(options, rhs_ref.clone(), len)?; - let dot_arr = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?; + let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone(), len)?; + let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(lhs_ref, rhs_ref, len)?; // Execute to get the inner product and norms of the arrays. We only fully decompress // because we need to perform special logic (guard against 0) during division. @@ -208,10 +207,10 @@ impl ScalarFnVTable for CosineSimilarity { } impl CosineSimilarity { - /// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`. + /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so + /// `cosine_similarity = dot(n_l, n_r)`. fn execute_both_denorm( &self, - options: &ApproxOptions, lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, @@ -222,9 +221,9 @@ impl CosineSimilarity { let (normalized_l, _) = extract_l2_denorm_children(lhs_ref); let (normalized_r, _) = extract_l2_denorm_children(rhs_ref); - // Dot product of already-normalized children IS the cosine similarity. - let dot = - InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?.into_array(); + // `L2Denorm` makes the normalized children authoritative, so their dot product is the + // cosine similarity even for lossy storage wrappers. + let dot = InnerProduct::try_new_array(normalized_l, normalized_r, len)?.into_array(); if !matches!(validity, Validity::NonNullable) { // Masking always changes the nullability to nullable. @@ -234,12 +233,12 @@ impl CosineSimilarity { } } - /// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`. + /// One side is `L2Denorm`: treat the normalized child as authoritative, so + /// `cosine_similarity = dot(n, b) / ||b||`. /// /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. fn execute_one_denorm( &self, - options: &ApproxOptions, denorm_ref: &ArrayRef, plain_ref: &ArrayRef, len: usize, @@ -249,8 +248,8 @@ impl CosineSimilarity { let (normalized, _) = extract_l2_denorm_children(denorm_ref); - let dot_arr = InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)?; - let norm_arr = L2Norm::try_new_array(options, plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?; + let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?; @@ -286,13 +285,11 @@ mod tests { use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; - use vortex_array::scalar_fn::ScalarFn; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_session::VortexSession; - use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::test_helpers::assert_close; @@ -306,7 +303,7 @@ mod tests { /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { - let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased(); + let scalar_fn = CosineSimilarity::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -477,7 +474,7 @@ mod tests { let rhs = tensor_array(&[2], &[3.0, 4.0, 0.0, 1.0])?; let rhs = MaskedArray::try_new(rhs, Validity::from_iter([true, false]))?.into_array(); - let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased(); + let scalar_fn = CosineSimilarity::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -499,10 +496,7 @@ mod tests { let normalized = tensor_array(shape, normalized_elements)?; let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); let mut ctx = SESSION.create_execution_ctx(); - Ok( - L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)? - .into_array(), - ) + Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) } #[test] @@ -570,11 +564,9 @@ mod tests { let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let rhs = - L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_r, norms_r, 2, &mut ctx)? - .into_array(); + let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); - let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased(); + let scalar_fn = CosineSimilarity::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index da1b62e6ca1..71288c241ae 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -22,6 +22,7 @@ use vortex_array::expr::and; use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::EmptyOptions; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; @@ -33,7 +34,6 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_err; use crate::matcher::AnyTensor; -use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; @@ -53,10 +53,9 @@ use crate::utils::extract_l2_denorm_children; pub struct InnerProduct; impl InnerProduct { - /// Creates a new [`ScalarFn`] wrapping the inner product operation with the given - /// [`ApproxOptions`] controlling approximation behavior. - pub fn new(options: &ApproxOptions) -> ScalarFn { - ScalarFn::new(InnerProduct, options.clone()) + /// Creates a new [`ScalarFn`] wrapping the inner product operation. + pub fn new() -> ScalarFn { + ScalarFn::new(InnerProduct, EmptyOptions) } /// Constructs a [`ScalarFnArray`] that lazily computes the inner product between `lhs` and @@ -66,18 +65,13 @@ impl InnerProduct { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array( - options: &ApproxOptions, - lhs: ArrayRef, - rhs: ArrayRef, - len: usize, - ) -> VortexResult { - ScalarFnArray::try_new(InnerProduct::new(options).erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { + ScalarFnArray::try_new(InnerProduct::new().erased(), vec![lhs, rhs], len) } } impl ScalarFnVTable for InnerProduct { - type Options = ApproxOptions; + type Options = EmptyOptions; fn id(&self) -> ScalarFnId { ScalarFnId::from("vortex.tensor.inner_product") @@ -144,7 +138,7 @@ impl ScalarFnVTable for InnerProduct { fn execute( &self, - options: &Self::Options, + _options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { @@ -158,12 +152,12 @@ impl ScalarFnVTable for InnerProduct { let rhs_is_denorm = rhs_ref.is::>(); if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); } else if lhs_is_denorm || rhs_is_denorm { if rhs_is_denorm { (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); } - return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx); + return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); } } @@ -225,7 +219,6 @@ impl InnerProduct { /// Both sides are `L2Denorm`: `inner_product = s_l * s_r * dot(n_l, n_r)`. fn execute_both_denorm( &self, - options: &ApproxOptions, lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, @@ -239,10 +232,9 @@ impl InnerProduct { let norms_l: PrimitiveArray = norms_l.execute(ctx)?; let norms_r: PrimitiveArray = norms_r.execute(ctx)?; - let dot: PrimitiveArray = - InnerProduct::try_new_array(options, normalized_l, normalized_r, len)? - .into_array() - .execute(ctx)?; + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + .into_array() + .execute(ctx)?; match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); @@ -260,7 +252,6 @@ impl InnerProduct { /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`. fn execute_one_denorm( &self, - options: &ApproxOptions, denorm_ref: &ArrayRef, plain_ref: &ArrayRef, len: usize, @@ -271,10 +262,9 @@ impl InnerProduct { let (normalized, norms) = extract_l2_denorm_children(denorm_ref); let denorm_norms: PrimitiveArray = norms.execute(ctx)?; - let dot: PrimitiveArray = - InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)? - .into_array() - .execute(ctx)?; + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)? + .into_array() + .execute(ctx)?; match_each_float_ptype!(dot.ptype(), |T| { let dots = dot.as_slice::(); @@ -308,13 +298,11 @@ mod tests { use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; - use vortex_array::scalar_fn::ScalarFn; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_session::VortexSession; - use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::test_helpers::assert_close; @@ -326,7 +314,7 @@ mod tests { /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { - let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased(); + let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -404,7 +392,7 @@ mod tests { let rhs = tensor_array(&[2], &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0])?; let lhs = MaskedArray::try_new(lhs, Validity::from_iter([true, false, true]))?.into_array(); - let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased(); + let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -422,7 +410,7 @@ mod tests { fn rejects_non_extension_dtype() { let lhs = PrimitiveArray::from_iter([1.0_f64, 2.0]).into_array(); let rhs = PrimitiveArray::from_iter([3.0_f64, 4.0]).into_array(); - let result = InnerProduct::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2); + let result = InnerProduct::try_new_array(lhs, rhs, 2); assert!(result.is_err()); } @@ -430,7 +418,7 @@ mod tests { fn rejects_mismatched_dtypes() -> VortexResult<()> { let lhs = tensor_array(&[2], &[1.0_f64, 2.0])?; let rhs = vector_array(2, &[3.0_f64, 4.0])?; - let result = InnerProduct::try_new_array(&ApproxOptions::Exact, lhs, rhs, 1); + let result = InnerProduct::try_new_array(lhs, rhs, 1); assert!(result.is_err()); Ok(()) } @@ -447,10 +435,7 @@ mod tests { let normalized = tensor_array(shape, normalized_elements)?; let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); let mut ctx = SESSION.create_execution_ctx(); - Ok( - L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)? - .into_array(), - ) + Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) } #[test] @@ -508,12 +493,10 @@ mod tests { let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let lhs = - L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_l, norms_l, 2, &mut ctx)? - .into_array(); + let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased(); + let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 04a348ff7d8..d72940c8da8 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -23,6 +23,7 @@ use vortex_array::expr::and; use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::EmptyOptions; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; @@ -37,12 +38,11 @@ use vortex_error::vortex_ensure; use vortex_error::vortex_ensure_eq; use crate::matcher::AnyTensor; -use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extract_flat_elements; use crate::utils::validate_tensor_float_input; -/// Re-applies L2 norms to a normalized tensor column. +/// Re-applies authoritative L2 norms to a normalized tensor column. /// /// Computes `normalized * norm` on each row over the flat backing buffer of each tensor-like type. /// @@ -51,17 +51,28 @@ use crate::utils::validate_tensor_float_input; /// /// The norms input must be a primitive float column with the same element type as the normalized /// tensor elements. +/// +/// [`L2Denorm`] is the norm-splitting wrapper used throughout the tensor crate. Callers that build +/// it through [`try_new_array`](Self::try_new_array) get an exact unit-norm invariant on the +/// `normalized` child. +/// +/// Advanced callers can also use [`new_array_unchecked`](Self::new_array_unchecked) to attach +/// authoritative stored norms to a lossy approximation of that child, such as quantized normalized +/// vectors. +/// +/// Downstream readthrough rules intentionally treat the stored norms and normalized child as the +/// encoding contract, even when that differs slightly from recomputing over fully decoded +/// coordinates. #[derive(Clone)] pub struct L2Denorm; impl L2Denorm { - /// Creates a new [`ScalarFn`] wrapping the L2 denormalization operation with the given - /// [`ApproxOptions`] controlling approximation behavior. + /// Creates a new [`ScalarFn`] wrapping the L2 denormalization operation. /// /// This is a low-level scalar-function descriptor constructor. To build a semantically valid /// [`L2Denorm`] array, prefer [`try_new_array`](Self::try_new_array). - pub fn new(options: &ApproxOptions) -> ScalarFn { - ScalarFn::new(L2Denorm, options.clone()) + pub fn new() -> ScalarFn { + ScalarFn::new(L2Denorm, EmptyOptions) } /// Constructs a validated [`ScalarFnArray`] that lazily re-applies `norms` to `normalized`. @@ -77,21 +88,15 @@ impl L2Denorm { /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches) or if the `normalized` child is not row-wise L2-normalized. pub fn try_new_array( - options: &ApproxOptions, normalized: ArrayRef, norms: ArrayRef, len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - let result = ScalarFnArray::try_new( - L2Denorm::new(options).erased(), - vec![normalized.clone(), norms.clone()], - len, - )?; - - validate_l2_denorm_children(normalized, norms, ctx)?; + validate_l2_denorm_children(&normalized, &norms, ctx)?; - Ok(result) + // SAFETY: We just validated that it is normalized. + unsafe { Self::new_array_unchecked(normalized, norms, len) } } /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually @@ -104,24 +109,24 @@ impl L2Denorm { /// # Safety /// /// The caller must ensure the `normalized` child is semantically suitable for L2 - /// denormalization, which typically means every valid row is unit-norm or zero. Violating this - /// invariant will not cause memory unsafety, but may produce incorrect results. + /// denormalization. For exact wrappers, that means every valid row is unit-norm or zero. + /// + /// Lossy encodings may deliberately relax that invariant while still treating the stored norms + /// as authoritative. + /// + /// Violating the intended contract will not cause memory unsafety, but may produce incorrect + /// results. pub unsafe fn new_array_unchecked( - options: &ApproxOptions, normalized: ArrayRef, norms: ArrayRef, len: usize, ) -> VortexResult { - ScalarFnArray::try_new( - L2Denorm::new(options).erased(), - vec![normalized, norms], - len, - ) + ScalarFnArray::try_new(L2Denorm::new().erased(), vec![normalized, norms], len) } } impl ScalarFnVTable for L2Denorm { - type Options = ApproxOptions; + type Options = EmptyOptions; fn id(&self) -> ScalarFnId { ScalarFnId::new_ref("vortex.tensor.l2_denorm") @@ -258,8 +263,10 @@ impl ScalarFnVTable for L2Denorm { /// [`L2Norm`]'s validity propagation. When the [`L2Denorm`] wrapper is executed, its validity is /// `and(normalized_validity, norms_validity)`, which correctly identifies originally-null rows /// since the normalized child is all-valid and the norms child carries the original nulls. +/// +/// Because this helper computes exact norms first and then divides by those norms, the returned +/// `normalized` child satisfies the strict unit-norm invariant required by [`L2Denorm`]. pub fn normalize_as_l2_denorm( - options: &ApproxOptions, input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { @@ -267,7 +274,7 @@ pub fn normalize_as_l2_denorm( let tensor_match = validate_tensor_float_input(input.dtype())?; let tensor_flat_size = tensor_match.list_size(); - let norms_sfn = L2Norm::try_new_array(options, input.clone(), row_count)?; + let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; let norms: PrimitiveArray = norms_array.clone().execute(ctx)?; let norms_validity = norms.validity()?; @@ -307,8 +314,15 @@ pub fn normalize_as_l2_denorm( ) })?; - // TODO(connor): Need to figure out a way to not run validation. - L2Denorm::try_new_array(options, normalized, norms_array, row_count, ctx) + // SAFETY: + // - `norms_array` was produced by `L2Norm(input)`, so every stored norm is non-negative and + // null rows already carry null validity through that child. + // - For every valid row, we either emit all zeros when the norm is zero or divide every + // element by the exact stored norm, so the normalized child is unit-norm (or zero) by + // construction. + // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values into + // downstream lossy encodings. + unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } } /// Rebuilds a tensor-like extension array from flat primitive elements. @@ -347,7 +361,7 @@ fn unit_norm_tolerance(element_ptype: PType) -> f64 { } /// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0). -pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { +pub fn validate_l2_normalized_rows(input: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { validate_l2_normalized_rows_impl(input, None, ctx) } @@ -357,16 +371,16 @@ pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> V /// - All vectors in the normalized array have length 1 or 0. /// - If the vector has a norm of 0, then the vector must be all 0s. fn validate_l2_denorm_children( - normalized: ArrayRef, - norms: ArrayRef, + normalized: &ArrayRef, + norms: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { validate_l2_normalized_rows_impl(normalized, Some(norms), ctx) } fn validate_l2_normalized_rows_impl( - normalized: ArrayRef, - norms: Option, + normalized: &ArrayRef, + norms: Option<&ArrayRef>, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let row_count = normalized.len(); @@ -379,11 +393,19 @@ fn validate_l2_normalized_rows_impl( let tolerance = unit_norm_tolerance(element_ptype); let tensor_flat_size = tensor_match.list_size(); - let normalized: ExtensionArray = normalized.execute(ctx)?; + if let Some(norms) = norms { + vortex_ensure_eq!( + norms.dtype().as_ptype(), + element_ptype, + "L2Denorm norms ptype must match normalized element ptype" + ); + } + + let normalized: ExtensionArray = normalized.clone().execute(ctx)?; let normalized_validity = normalized.as_ref().validity()?; let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?; let norms = norms - .map(|norms| norms.execute::(ctx)) + .map(|norms| norms.clone().execute::(ctx)) .transpose()?; let combined_validity = match &norms { @@ -461,7 +483,6 @@ mod tests { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; - use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; @@ -478,8 +499,7 @@ mod tests { /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); - let result = - L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?; + let result = L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?; result.into_array().execute(&mut ctx) } @@ -585,7 +605,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); } @@ -595,7 +615,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -606,7 +626,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -617,7 +637,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f32, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -626,8 +646,8 @@ mod tests { fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { let input = f16_vector_array(2, &[3.0, 4.0, 0.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input, &mut ctx)?; - validate_l2_normalized_rows(roundtrip.child_at(0).clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; + validate_l2_normalized_rows(&roundtrip.child_at(0).clone(), &mut ctx)?; Ok(()) } @@ -635,7 +655,7 @@ mod tests { fn validate_l2_normalized_rows_rejects_unnormalized_input() -> VortexResult<()> { let input = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let result = validate_l2_normalized_rows(input, &mut ctx); + let result = validate_l2_normalized_rows(&input, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -646,7 +666,7 @@ mod tests { let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -657,7 +677,7 @@ mod tests { let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -668,7 +688,7 @@ mod tests { let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -678,8 +698,7 @@ mod tests { let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); - let result = - unsafe { L2Denorm::new_array_unchecked(&ApproxOptions::Exact, normalized, norms, 2) }; + let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 2) }; assert!(result.is_ok()); Ok(()) } @@ -688,7 +707,7 @@ mod tests { fn normalize_as_l2_denorm_roundtrips_vectors() -> VortexResult<()> { let input = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input.clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; assert_tensor_arrays_eq(actual, input)?; @@ -699,7 +718,7 @@ mod tests { fn normalize_as_l2_denorm_roundtrips_fixed_shape_tensors() -> VortexResult<()> { let input = tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input.clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; assert_tensor_arrays_eq(actual, input)?; @@ -710,7 +729,7 @@ mod tests { fn normalize_as_l2_denorm_supports_constant_tensors() -> VortexResult<()> { let input = constant_tensor_array(&[2], &[3.0, 4.0], 3)?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input.clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; assert_tensor_arrays_eq(actual, input)?; @@ -721,7 +740,7 @@ mod tests { fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { let input = constant_vector_array(&[3.0, 4.0], 2)?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input.clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; assert_tensor_arrays_eq(actual, input)?; @@ -732,7 +751,7 @@ mod tests { fn normalize_as_l2_denorm_uses_zero_rows_for_zero_norms() -> VortexResult<()> { let input = vector_array(2, &[0.0, 0.0, 3.0, 4.0])?; let mut ctx = SESSION.create_execution_ctx(); - let roundtrip = normalize_as_l2_denorm(&ApproxOptions::Exact, input.clone(), &mut ctx)?; + let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let normalized: ExtensionArray = roundtrip.child_at(0).clone().execute(&mut ctx)?; let storage: FixedSizeListArray = normalized.storage_array().clone().execute(&mut ctx)?; let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 2632c07cffd..170ac800d10 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -22,6 +22,7 @@ use vortex_array::expr::Expression; use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::EmptyOptions; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; @@ -32,7 +33,6 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; use crate::matcher::AnyTensor; -use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::extract_flat_elements; use crate::utils::validate_tensor_float_input; @@ -43,14 +43,18 @@ use crate::utils::validate_tensor_float_input; /// /// The input must be a tensor-like extension array with a float element type. The output is a float /// column of the same float type. +/// +/// When the input is wrapped in [`L2Denorm`], this operator treats the stored norms as +/// authoritative. For lossy encodings such as TurboQuant, that means `L2Norm` may intentionally +/// read the stored norms instead of re-deriving them from fully decoded coordinates. That behavior +/// is part of the lossy storage contract, not a separate lossy-compute mode. #[derive(Clone)] pub struct L2Norm; impl L2Norm { - /// Creates a new [`ScalarFn`] wrapping the L2 norm operation with the given [`ApproxOptions`] - /// controlling approximation behavior. - pub fn new(options: &ApproxOptions) -> ScalarFn { - ScalarFn::new(L2Norm, options.clone()) + /// Creates a new [`ScalarFn`] wrapping the L2 norm operation. + pub fn new() -> ScalarFn { + ScalarFn::new(L2Norm, EmptyOptions) } /// Constructs a [`ScalarFnArray`] that lazily computes the L2 norm over `child`. @@ -59,17 +63,13 @@ impl L2Norm { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array( - options: &ApproxOptions, - child: ArrayRef, - len: usize, - ) -> VortexResult { - ScalarFnArray::try_new(L2Norm::new(options).erased(), vec![child], len) + pub fn try_new_array(child: ArrayRef, len: usize) -> VortexResult { + ScalarFnArray::try_new(L2Norm::new().erased(), vec![child], len) } } impl ScalarFnVTable for L2Norm { - type Options = ApproxOptions; + type Options = EmptyOptions; fn id(&self) -> ScalarFnId { ScalarFnId::from("vortex.tensor.l2_norm") @@ -122,9 +122,9 @@ impl ScalarFnVTable for L2Norm { let tensor_flat_size = tensor_match.list_size(); let element_ptype = tensor_match.element_ptype(); - // L2Norm(L2Denorm(normalized, norms)) == norms, since normalized vectors have unit norm - // and L2 norms are non-negative. This avoids decompressing the TQ child just to recompute - // norms that are already stored. + // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored + // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics + // instead of forcing a decode-and-recompute path here. if let Some(sfn) = input_ref.as_opt::() && sfn.scalar_fn().as_opt::().is_some() { @@ -195,13 +195,11 @@ mod tests { use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; - use vortex_array::scalar_fn::ScalarFn; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_session::VortexSession; - use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::tensor_array; @@ -212,7 +210,7 @@ mod tests { /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { - let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(); + let scalar_fn = L2Norm::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -267,7 +265,7 @@ mod tests { let arr = tensor_array(&[2], &[3.0, 4.0, 0.0, 0.0])?; let arr = MaskedArray::try_new(arr, Validity::from_iter([true, false]))?.into_array(); - let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(); + let scalar_fn = L2Norm::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![arr], 2)?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 32db845a6e2..1d1a362c8af 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -3,40 +3,8 @@ //! Scalar function expressions defined on tensor and tensor-like extension types. -use std::fmt; - pub mod cosine_similarity; pub mod inner_product; pub mod l2_denorm; pub mod l2_norm; - -/// Options for tensor-related expressions that might have error. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] -pub enum ApproxOptions { - /// Computes the exact result. - #[default] - Exact, - /// Allows approximate results. - Approximate, -} - -impl ApproxOptions { - /// Returns `true` if the option is [`Exact`](Self::Exact). - pub fn is_exact(&self) -> bool { - matches!(self, Self::Exact) - } - - /// Returns `true` if the option is [`Approximate`](Self::Approximate). - pub fn is_approx(&self) -> bool { - matches!(self, Self::Approximate) - } -} - -impl fmt::Display for ApproxOptions { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Exact => write!(f, "Exact"), - Self::Approximate => write!(f, "Approximate"), - } - } -} +pub mod sorf_transform; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs new file mode 100644 index 00000000000..96e6af414ca --- /dev/null +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! SORF inverse transform scalar function. +//! +//! SORF (Structured Orthogonal Random Features, [Yu et al. 2016][sorf-paper]) is a fast structured +//! approximation to a random orthogonal matrix. It composes random sign diagonals with the +//! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost +//! of a dense orthogonal matrix. +//! +//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension +//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform +//! at execution time, producing a [`Vector`] extension array with the original (pre-padding) +//! dimensionality. +//! +//! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the +//! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's +//! frozen local SplitMix64 stream contract rather than by an external RNG crate. +//! +//! # Input element type: `f32` only (TODO(connor): for now...) +//! +//! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is +//! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need +//! to cast to `f32` before wrapping in a [`Vector`] and handing it to SorfTransform. +//! +//! The reason for this constraint is that TurboQuant (the only production caller today) stores its +//! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`. +//! +//! Supporting other float storage types would require an implicit up-/down-cast that we do not yet +//! want to bake into SorfTransform. This restriction is intentional and may be relaxed in the +//! future, but today it is load-bearing. +//! +//! # Output element type +//! +//! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It +//! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast +//! while materializing the output. This lets SorfTransform hand its result directly to a +//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose +//! element-type expectation may differ from the `f32` the transform operated on internally. +//! +//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf +//! [`Vector`]: crate::vector::Vector + +use std::fmt; +use std::fmt::Formatter; + +use vortex_array::ArrayRef; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::dtype::PType; +use vortex_array::scalar_fn::ScalarFn; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +mod rotation; +mod splitmix64; +pub use rotation::SorfMatrix; + +mod vtable; + +/// Inverse SORF orthogonal transform scalar function. +/// +/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32` +/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the +/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and +/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array. +/// +/// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the +/// `f32`-only input constraint. +#[derive(Clone)] +pub struct SorfTransform; + +/// Options for the SORF inverse transform scalar function. +/// +/// Stored in the [`ScalarFnArray`] and used to deterministically reconstruct the +/// [`SorfMatrix`] at decode time. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct SorfOptions { + /// Seed used to generate the structured sign diagonals via Vortex's frozen SplitMix64 stream. + pub seed: u64, + /// Number of sign-diagonal + WHT rounds in the structured orthogonal transform. + pub num_rounds: u8, + /// Original vector dimension (before power-of-2 padding). The output + /// [`Vector`](crate::vector::Vector) has this dimension. + pub dimension: u32, + /// Element type of the output [`Vector`](crate::vector::Vector). The child input must always + /// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final + /// `f32 -> element_ptype` cast happens while building the output. + pub element_ptype: PType, +} + +impl SorfTransform { + /// Creates a new [`ScalarFn`] wrapping the SORF inverse transform with the given options. + pub fn new(options: &SorfOptions) -> ScalarFn { + ScalarFn::new(SorfTransform, options.clone()) + } + + /// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform. + /// + /// The `child` must be a [`Vector`] extension array (or an array that executes to one) with: + /// + /// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and + /// - `f32` storage elements. This is a hard requirement today; see the + /// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale. + /// + /// The output [`Vector`] has dimension `options.dimension` and element type + /// `options.element_ptype`. + /// + /// [`Vector`]: crate::vector::Vector + pub fn try_new_array( + options: &SorfOptions, + child: ArrayRef, + len: usize, + ) -> VortexResult { + validate_sorf_options(options)?; + + ScalarFnArray::try_new(SorfTransform::new(options).erased(), vec![child], len) + } +} + +/// Checks that the SORF configuration is valid. +pub(crate) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { + vortex_ensure!( + options.num_rounds >= 1, + "SorfTransform num_rounds must be >= 1, got {}", + options.num_rounds + ); + vortex_ensure!( + options.element_ptype.is_float(), + "SorfTransform element_ptype must be a float type, got {}", + options.element_ptype + ); + Ok(()) +} + +impl fmt::Display for SorfOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "SorfOptions(seed={}, rounds={}, dim={}, ptype={})", + self.seed, self.num_rounds, self.dimension, self.element_ptype + ) + } +} + +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/array/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs similarity index 65% rename from vortex-tensor/src/encodings/turboquant/array/rotation.rs rename to vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs index 57c236c821a..9279a35259d 100644 --- a/vortex-tensor/src/encodings/turboquant/array/rotation.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs @@ -1,15 +1,31 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Deterministic random rotation for TurboQuant. +//! SORF (Structured Orthogonal Random Features) orthogonal transform. //! -//! The TurboQuant paper analyzes a full random orthogonal rotation. The current implementation -//! uses a cheaper structured Walsh-Hadamard-based surrogate instead of a dense d x d matrix. +//! Implements the SORF construction from [Yu et al. 2016][sorf-paper]: a fast structured +//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the +//! Fast Walsh-Hadamard Transform (FWHT). //! -//! Concretely, this applies three rounds of random sign diagonals interleaved with the Fast -//! Walsh-Hadamard Transform (FWHT): `D3 * H * D2 * H * D1 * H`, followed by normalization. This is -//! a SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d) -//! encode/decode cost and compact serialized parameters. +//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the +//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round +//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost +//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix. +//! +//! Vortex defines those sign diagonals using a frozen local SplitMix64 stream rather than an +//! external RNG crate. The contract is: +//! +//! - state is a single `u64` seed, +//! - each `next_u64()` call uses the SplitMix64 reference algorithm with wrapping `u64` +//! arithmetic, +//! - signs are generated in round-major, block-major order, +//! - each generated `u64` contributes 64 signs in least-significant-bit-first order, +//! - bit `1` means `+1` and bit `0` means `-1`. +//! +//! This makes SORF sign generation stable as a Vortex format contract even if external RNG +//! implementations change. +//! +//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf //! //! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2 //! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n) @@ -26,18 +42,22 @@ //! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into //! `vpxor`/`veor`. -use rand::RngExt; -use rand::SeedableRng; -use rand::rngs::StdRng; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use super::splitmix64::SplitMix64; + /// IEEE 754 sign bit mask for f32. const F32_SIGN_BIT: u32 = 0x8000_0000; -/// A Walsh-Hadamard-based structured surrogate for a random orthogonal rotation. -pub struct RotationMatrix { - /// Flat XOR masks for all `num_rounds` diagonal matrices, total length `num_rounds * padded_dim`. +/// A Walsh-Hadamard-based structured orthogonal transform matrix. +/// +/// All computation is done in f32. The sign diagonals are stored as IEEE 754 XOR masks on +/// f32 bit patterns, and the Walsh-Hadamard butterfly operates on `&mut [f32]` slices. +pub struct SorfMatrix { + /// Flat XOR masks for all `num_rounds` diagonal matrices, total length + /// `num_rounds * padded_dim`. + /// /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = /// multiply by -1 (flip sign bit). sign_masks: Vec, @@ -49,22 +69,25 @@ pub struct RotationMatrix { norm_factor: f32, } -impl RotationMatrix { - /// Create a new structured Walsh-Hadamard-based rotation from a deterministic seed. +impl SorfMatrix { + /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic + /// seed. + /// + /// The seed is expanded using Vortex's frozen local SplitMix64 stream. Signs are generated in + /// round-major, block-major order, with each `u64` contributing 64 sign bits in + /// least-significant-bit-first order. pub fn try_new(seed: u64, dimension: usize, num_rounds: usize) -> VortexResult { vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); let padded_dim = dimension.next_power_of_two(); - let mut rng = StdRng::seed_from_u64(seed); - - let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); - for _ in 0..num_rounds { - sign_masks.extend(gen_random_sign_masks(&mut rng, padded_dim)); - } + let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. + // The result is always in (0, 1] for any valid padded_dim >= 2 and num_rounds >= 1, so + // the f64 -> f32 cast is a precision loss only -- it cannot overflow to infinity. #[expect( clippy::cast_possible_truncation, - reason = "Intentional f64 -> f32 truncation for normalization factor." + reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" )] let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; @@ -76,7 +99,14 @@ impl RotationMatrix { }) } - /// Apply forward rotation: `output = R(input)`. + /// Returns the padded dimension (next power of 2 >= dim). + /// + /// All `rotate`/`inverse_rotate` buffers must be this length. + pub fn padded_dim(&self) -> usize { + self.padded_dim + } + + /// Apply the forward orthogonal transform: `output = R(input)`. /// /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is /// responsible for zero-padding input beyond `dim` positions. @@ -88,7 +118,7 @@ impl RotationMatrix { self.apply_srht(output); } - /// Apply inverse rotation: `output = R⁻¹(input)`. + /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. /// /// Both `input` and `output` must have length `padded_dim()`. pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) { @@ -99,19 +129,7 @@ impl RotationMatrix { self.apply_inverse_srht(output); } - /// Returns the number of sign-diagonal + WHT rounds. - pub fn num_rounds(&self) -> usize { - self.num_rounds - } - - /// Returns the padded dimension (next power of 2 >= dim). - /// - /// All rotate/inverse_rotate buffers must be this length. - pub fn padded_dim(&self) -> usize { - self.padded_dim - } - - /// Apply the structured rotation: `D_k · H · ... · D₁ · H · x`, with normalization. + /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. fn apply_srht(&self, buf: &mut [f32]) { for round in 0..self.num_rounds { let offset = round * self.padded_dim; @@ -123,9 +141,9 @@ impl RotationMatrix { buf.iter_mut().for_each(|val| *val *= norm); } - /// Apply the inverse structured rotation. + /// Apply the inverse structured transform. /// - /// Forward is: `norm · H · D_k · H · ... · D₁ · H`. + /// Forward is: `norm · H · D_k · ... · H · D₁`. /// Inverse is: `norm · D₁ · H · ... · D_k · H`. fn apply_inverse_srht(&self, buf: &mut [f32]) { for round in (0..self.num_rounds).rev() { @@ -144,6 +162,7 @@ impl RotationMatrix { /// Convention: `1` = positive (+1), `0` = negative (-1). The output has length /// `num_rounds * padded_dim` and is suitable for bitpacking via FastLanes /// `bitpack_encode(..., 1, None)`. + #[cfg(test)] pub fn export_inverse_signs_u8(&self) -> Vec { let total = self.num_rounds * self.padded_dim; let mut out = Vec::with_capacity(total); @@ -158,7 +177,7 @@ impl RotationMatrix { out } - /// Reconstruct a [`RotationMatrix`] from unpacked `u8` 0/1 values. + /// Reconstruct a [`SorfMatrix`] from unpacked `u8` 0/1 values. /// /// The input must have length `num_rounds * padded_dim` with signs in inverse application /// order `[D_k | ... | D₁]` (as produced by [`export_inverse_signs_u8`]). Convention: @@ -166,6 +185,7 @@ impl RotationMatrix { /// /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the stored /// [`BitPackedArray`] into `&[u8]`, which is passed here. + #[cfg(test)] pub fn from_u8_slice( signs_u8: &[u8], dimension: usize, @@ -196,9 +216,11 @@ impl RotationMatrix { } } + // Same norm factor computation as `try_new`. See the comment there for why this cast + // cannot overflow. #[expect( clippy::cast_possible_truncation, - reason = "Intentional f64 -> f32 truncation for normalization factor." + reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" )] let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; @@ -211,24 +233,40 @@ impl RotationMatrix { } } -/// Generate a vector of random XOR sign masks. -fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec { - (0..len) - .map(|_| { - if rng.random_bool(0.5) { - 0u32 // +1: no-op - } else { - F32_SIGN_BIT // -1: flip sign bit - } - }) - .collect() +/// Generate XOR sign masks from the frozen local SplitMix64 stream. +/// +/// Signs are produced in round-major, block-major order. For each block we call +/// [`SplitMix64::next_u64`] exactly once and unpack its bits from least significant to most +/// significant. Bit `1` means positive sign / `0x00000000`; bit `0` means negative sign / +/// [`F32_SIGN_BIT`]. +fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec { + let mut rng = SplitMix64::new(seed); + let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); + + for _round in 0..num_rounds { + for base_idx in (0..padded_dim).step_by(64) { + let word = rng.next_u64(); + let bits_in_block = (padded_dim - base_idx).min(64); + sign_masks.extend((0..bits_in_block).map(|bit_idx| sign_mask_from_word(word, bit_idx))); + } + } + + sign_masks +} + +/// Convert one bit from a SplitMix64 output word into an XOR sign mask. +fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { + if ((word >> bit_idx) & 1) != 0 { + 0u32 + } else { + F32_SIGN_BIT + } } /// Apply sign masks via XOR on the IEEE 754 sign bit. /// /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to /// multiplying each element by +/-1.0, but avoids FP dependency chains. -#[inline] fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) { for (val, &mask) in buf.iter_mut().zip(masks.iter()) { *val = f32::from_bits(val.to_bits() ^ mask); @@ -265,7 +303,6 @@ fn walsh_hadamard_transform(buf: &mut [f32]) { /// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element /// pair. Factored into a separate function so LLVM can see the slice lengths match and /// auto-vectorize. -#[inline(always)] fn butterfly(lo: &mut [f32], hi: &mut [f32]) { debug_assert_eq!(lo.len(), hi.len()); for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { @@ -282,11 +319,18 @@ mod tests { use vortex_error::VortexResult; use super::*; + use crate::scalar_fns::sorf_transform::splitmix64::SplitMix64; + + fn unpack_sign_bits(word: u64, count: usize) -> Vec { + (0..count) + .map(|bit_idx| u8::from(((word >> bit_idx) & 1) != 0)) + .collect() + } #[test] fn deterministic_from_seed() -> VortexResult<()> { - let r1 = RotationMatrix::try_new(42, 64, 3)?; - let r2 = RotationMatrix::try_new(42, 64, 3)?; + let r1 = SorfMatrix::try_new(42, 64, 3)?; + let r2 = SorfMatrix::try_new(42, 64, 3)?; let pd = r1.padded_dim(); let mut input = vec![0.0f32; pd]; @@ -303,6 +347,48 @@ mod tests { Ok(()) } + #[test] + fn export_inverse_signs_matches_golden_words() -> VortexResult<()> { + let rot = SorfMatrix::try_new(42, 64, 2)?; + let actual = rot.export_inverse_signs_u8(); + let mut rng = SplitMix64::new(42); + let round0_word = rng.next_u64(); + let round1_word = rng.next_u64(); + + let mut expected = Vec::with_capacity(128); + expected.extend(unpack_sign_bits(round1_word, 64)); + expected.extend(unpack_sign_bits(round0_word, 64)); + + assert_eq!(actual, expected); + Ok(()) + } + + #[test] + fn one_word_generates_64_signs_lsb_first() { + let masks = gen_sign_masks_from_seed(42, 64, 1); + assert_eq!(masks.len(), 64); + + let mut rng = SplitMix64::new(42); + let word = rng.next_u64(); + let expected: Vec<_> = (0..64) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + + #[test] + fn tail_block_uses_only_required_bits() { + let masks = gen_sign_masks_from_seed(42, 32, 1); + assert_eq!(masks.len(), 32); + + let mut rng = SplitMix64::new(42); + let word = rng.next_u64(); + let expected: Vec<_> = (0..32) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, /// including non-power-of-two dimensions that require padding. #[rstest] @@ -318,7 +404,7 @@ mod tests { #[case(768, 3)] #[case(1024, 3)] fn roundtrip_exact(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = RotationMatrix::try_new(42, dim, num_rounds)?; + let rot = SorfMatrix::try_new(42, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -354,7 +440,7 @@ mod tests { #[case(128, 5)] #[case(768, 3)] fn preserves_norm(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = RotationMatrix::try_new(7, dim, num_rounds)?; + let rot = SorfMatrix::try_new(7, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -377,7 +463,7 @@ mod tests { Ok(()) } - /// Verify that export -> [`from_u8_slice`] produces identical rotation output. + /// Verify that export -> [`from_u8_slice`] produces identical transform output. #[rstest] #[case(64, 3)] #[case(128, 1)] @@ -388,11 +474,11 @@ mod tests { #[case] dim: usize, #[case] num_rounds: usize, ) -> VortexResult<()> { - let rot = RotationMatrix::try_new(42, dim, num_rounds)?; + let rot = SorfMatrix::try_new(42, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let signs_u8 = rot.export_inverse_signs_u8(); - let rot2 = RotationMatrix::from_u8_slice(&signs_u8, dim, num_rounds)?; + let rot2 = SorfMatrix::from_u8_slice(&signs_u8, dim, num_rounds)?; let mut input = vec![0.0f32; padded_dim]; for i in 0..dim { @@ -403,12 +489,12 @@ mod tests { let mut out2 = vec![0.0f32; padded_dim]; rot.rotate(&input, &mut out1); rot2.rotate(&input, &mut out2); - assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + assert_eq!(out1, out2, "Forward transform mismatch after export/import"); rot.inverse_rotate(&out1, &mut out2); let mut out3 = vec![0.0f32; padded_dim]; rot2.inverse_rotate(&out1, &mut out3); - assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + assert_eq!(out2, out3, "Inverse transform mismatch after export/import"); Ok(()) } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs b/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs new file mode 100644 index 00000000000..23345cbcb7d --- /dev/null +++ b/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Frozen local SplitMix64 stream used to define SORF sign diagonals. +//! +//! This is a direct translation of the `splitmix64.c` reference implementation. The state is a +//! single `u64`, and `next_u64()` first adds [`SPLITMIX64_INCREMENT`] with wrapping arithmetic, +//! then applies the two reference mixing steps and final xor-shift. + +/// SplitMix64 additive constant from the reference implementation. +const SPLITMIX64_INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15; + +/// First SplitMix64 mixing multiplier from the reference implementation. +const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9; + +/// Second SplitMix64 mixing multiplier from the reference implementation. +const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB; + +/// Frozen local SplitMix64 stream used to define SORF sign diagonals. +pub(crate) struct SplitMix64 { + state: u64, +} + +impl SplitMix64 { + pub(crate) fn new(seed: u64) -> Self { + Self { state: seed } + } + + pub(crate) fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(SPLITMIX64_INCREMENT); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(SPLITMIX64_MUL1); + z = (z ^ (z >> 27)).wrapping_mul(SPLITMIX64_MUL2); + z ^ (z >> 31) + } +} + +#[cfg(test)] +mod tests { + use super::SplitMix64; + + const SPLITMIX64_SEED0_GOLDEN: [u64; 4] = [ + 0xE220_A839_7B1D_CDAF, + 0x6E78_9E6A_A1B9_65F4, + 0x06C4_5D18_8009_454F, + 0xF88B_B8A8_724C_81EC, + ]; + + const SPLITMIX64_SEED42_GOLDEN: [u64; 4] = [ + 0xBDD7_3226_2FEB_6E95, + 0x28EF_E333_B266_F103, + 0x4752_6757_130F_9F52, + 0x581C_E1FF_0E4A_E394, + ]; + + #[test] + fn splitmix64_seed0_matches_golden_outputs() { + let mut rng = SplitMix64::new(0); + let actual: Vec<_> = (0..SPLITMIX64_SEED0_GOLDEN.len()) + .map(|_| rng.next_u64()) + .collect(); + assert_eq!(actual, SPLITMIX64_SEED0_GOLDEN); + } + + #[test] + fn splitmix64_seed42_matches_golden_outputs() { + let mut rng = SplitMix64::new(42); + let actual: Vec<_> = (0..SPLITMIX64_SEED42_GOLDEN.len()) + .map(|_| rng.next_u64()) + .collect(); + assert_eq!(actual, SPLITMIX64_SEED42_GOLDEN); + } +} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs new file mode 100644 index 00000000000..99d97fbe87f --- /dev/null +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -0,0 +1,438 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Unit tests for the [`SorfTransform`] scalar function. + +#![allow(clippy::cast_possible_truncation)] + +use std::sync::Arc; +use std::sync::LazyLock; + +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::dict::DictArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_session::VortexSession; + +use super::SorfOptions; +use super::SorfTransform; +use super::rotation::SorfMatrix; +use crate::encodings::turboquant::centroids::compute_centroid_boundaries; +use crate::encodings::turboquant::centroids::find_nearest_centroid; +use crate::encodings::turboquant::centroids::get_centroids; +use crate::vector::Vector; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +/// Build a unit-normalized input vector array and forward-transform + quantize it, returning +/// `(input_f32, Vector(FSL(Dict(codes, centroids))), padded_dim)`. +/// +/// This mimics what the TurboQuant compression pipeline does, but directly, so we can test +/// `SorfTransform` in isolation. +fn forward_rotate_and_quantize( + dim: usize, + num_rows: usize, + seed: u64, + num_rounds: usize, + bit_width: u8, +) -> VortexResult<(Vec, ExtensionArray, usize)> { + // Build simple unit-normalized input vectors. + let mut input_f32 = vec![0.0f32; num_rows * dim]; + for row in 0..num_rows { + let mut norm_sq = 0.0f32; + for i in 0..dim { + let val = ((row * dim + i) as f32 + 1.0) * 0.01; + input_f32[row * dim + i] = val; + norm_sq += val * val; + } + let norm = norm_sq.sqrt(); + for i in 0..dim { + input_f32[row * dim + i] /= norm; + } + } + + let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; + let padded_dim = rotation.padded_dim(); + let centroids = get_centroids(padded_dim as u32, bit_width)?; + let boundaries = compute_centroid_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + padded[..dim].copy_from_slice(&input_f32[row * dim..(row + 1) * dim]); + padded[dim..].fill(0.0); + rotation.rotate(&padded, &mut rotated); + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); + let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); + centroids_buf.extend_from_slice(¢roids); + let centroids_arr = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + let dict = DictArray::try_new(codes.into_array(), centroids_arr.into_array())?; + let fsl = FixedSizeListArray::try_new( + dict.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )?; + let padded_vector = wrap_as_vector(fsl, Validity::NonNullable)?; + + Ok((input_f32, padded_vector, padded_dim)) +} + +/// Wrap an FSL in a Vector extension, optionally re-tagging its validity. This is used by tests +/// that need to adjust top-level nullability of a padded vector child. +fn wrap_as_vector(fsl: FixedSizeListArray, validity: Validity) -> VortexResult { + let list_size = fsl.list_size(); + let num_rows = fsl.len(); + let elements = fsl.elements().clone(); + let fsl = FixedSizeListArray::try_new(elements, list_size, validity, num_rows)?; + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array())) +} + +/// Helper to build `SorfOptions` with common defaults. +fn default_options(dim: u32, seed: u64) -> SorfOptions { + SorfOptions { + seed, + num_rounds: 3, + dimension: dim, + element_ptype: PType::F32, + } +} + +/// Execute a `SorfTransform` array and return the decoded flat f32 elements. +fn execute_sorf( + options: &SorfOptions, + child: ExtensionArray, + num_rows: usize, +) -> VortexResult> { + let sorf = SorfTransform::try_new_array(options, child.into_array(), num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; + Ok(result_prim.as_slice::().to_vec()) +} + +/// Build an empty `Vector` extension array wrapping an empty FSL. +fn empty_padded_vector(padded_dim: u32, validity: Validity) -> VortexResult { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, 0)?; + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array())) +} + +#[test] +fn roundtrip_recovery() -> VortexResult<()> { + let dim = 128; + let num_rows = 10; + let seed = 42u64; + let (input_f32, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; + let options = default_options(dim as u32, seed); + let result = execute_sorf(&options, padded_vector, num_rows)?; + + assert_eq!(result.len(), num_rows * dim); + + // At 8-bit quantization, the reconstruction should be very close to the input. + for row in 0..num_rows { + let orig = &input_f32[row * dim..(row + 1) * dim]; + let recon = &result[row * dim..(row + 1) * dim]; + let err_sq: f32 = orig + .iter() + .zip(recon) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + assert!( + err_sq / norm_sq < 1e-3, + "row {row} MSE too high: {:.6}", + err_sq / norm_sq + ); + } + Ok(()) +} + +#[test] +fn empty_array_non_nullable() -> VortexResult<()> { + let dim = 128u32; + let padded_dim = dim.next_power_of_two(); + let options = default_options(dim, 42); + + // Build an empty Vector child. + let child = empty_padded_vector(padded_dim, Validity::NonNullable)?; + + let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + + assert_eq!(result.len(), 0); + + // Output should be non-nullable. + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + assert!(!result_fsl.dtype().is_nullable()); + + Ok(()) +} + +#[test] +fn empty_array_nullable() -> VortexResult<()> { + let dim = 128u32; + let padded_dim = dim.next_power_of_two(); + let options = default_options(dim, 42); + + // Build an empty but nullable Vector child. + let child = empty_padded_vector(padded_dim, Validity::from(Nullability::Nullable))?; + + let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + + assert_eq!(result.len(), 0); + + // Output should be nullable (matching the child). + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + assert!(result_fsl.dtype().is_nullable()); + + Ok(()) +} + +#[test] +fn nullable_validity_propagation() -> VortexResult<()> { + let dim = 128; + let num_rows = 4; + let seed = 42u64; + let (_, non_nullable_vector, padded_dim) = + forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; + + // Re-wrap the underlying FSL with a validity mask: rows 0 and 2 are valid, rows 1 and 3 + // are null. + let validity = Validity::from_iter([true, false, true, false]); + let fsl_non_nullable: FixedSizeListArray = non_nullable_vector + .storage_array() + .clone() + .execute(&mut SESSION.create_execution_ctx())?; + let fsl_nullable = FixedSizeListArray::try_new( + fsl_non_nullable.elements().clone(), + padded_dim as u32, + validity.clone(), + num_rows, + )?; + let nullable_vector = wrap_as_vector(fsl_nullable, validity.clone())?; + + let options = default_options(dim as u32, seed); + let sorf = SorfTransform::try_new_array(&options, nullable_vector.into_array(), num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + + // The output FSL validity should match the input. + let output_validity = result_fsl.validity()?; + for row in 0..num_rows { + assert_eq!( + output_validity.is_valid(row)?, + validity.is_valid(row)?, + "validity mismatch at row {row}" + ); + } + + Ok(()) +} + +#[test] +fn dimension_truncation() -> VortexResult<()> { + // Use a non-power-of-2 dimension (padded 200 -> 256). + let dim = 200; + let num_rows = 3; + let seed = 42u64; + let (_, padded_vector, padded_dim) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; + + assert_eq!(padded_dim, 256, "200 should pad to 256"); + + let options = default_options(dim as u32, seed); + let result = execute_sorf(&options, padded_vector, num_rows)?; + + // Output should have original dimension, not padded. + assert_eq!(result.len(), num_rows * dim); + + Ok(()) +} + +#[test] +fn return_dtype_is_vector_extension() -> VortexResult<()> { + let dim = 128u32; + let padded_dim = dim.next_power_of_two(); + let options = default_options(dim, 42); + + // Input must be a Vector extension dtype. + let child_elem_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let child_storage_dtype = DType::FixedSizeList( + Arc::new(child_elem_dtype), + padded_dim, + Nullability::NonNullable, + ); + let child_ext_dtype = ExtDType::::try_new(EmptyMetadata, child_storage_dtype)?.erased(); + let child_dtype = DType::Extension(child_ext_dtype); + + use vortex_array::scalar_fn::ScalarFnVTable; + let return_dtype = SorfTransform.return_dtype(&options, &[child_dtype])?; + + // Should be a Vector extension type. + let ext = return_dtype + .as_extension_opt() + .expect("return dtype should be an extension type"); + assert!(ext.metadata_opt::().is_some()); + + // Inner FSL should have the original (unpadded) dimension. + let DType::FixedSizeList(_, inner_dim, _) = ext.storage_dtype() else { + panic!("expected storage dtype to be FSL"); + }; + assert_eq!(*inner_dim, dim); + + Ok(()) +} + +#[test] +fn rejects_zero_rounds_at_construction() { + let options = SorfOptions { + seed: 42, + num_rounds: 0, + dimension: 128, + element_ptype: PType::F32, + }; + let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); + let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) + .expect("test child should be valid"); + + let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + .expect_err("zero rounds should be rejected at construction time"); + assert!(err.to_string().contains("num_rounds")); +} + +#[test] +fn rejects_non_float_output_ptype_at_construction() { + let options = SorfOptions { + seed: 42, + num_rounds: 3, + dimension: 128, + element_ptype: PType::U8, + }; + let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); + let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) + .expect("test child should be valid"); + + let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + .expect_err("non-float output ptypes should be rejected at construction time"); + assert!(err.to_string().contains("element_ptype")); +} + +#[test] +fn rejects_non_vector_extension_child_at_construction() { + let options = default_options(128, 42); + // A bare FSL child (not wrapped in a Vector extension) should be rejected. + let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); + let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) + .expect("test child should be valid"); + + let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + .expect_err("non-Vector-extension children should be rejected at construction time"); + assert!(err.to_string().contains("Vector extension")); +} + +#[test] +fn rejects_wrong_padded_dimension_at_construction() { + // Options say dimension=128 so padded_dim should be 128. Pass a Vector<256> instead. + let options = default_options(128, 42); + let elements = PrimitiveArray::from_iter([0.0f32; 256]).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 256, Validity::NonNullable, 1) + .expect("test child should be valid"); + let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); + + let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + .expect_err("mismatched padded dimension should be rejected at construction time"); + assert!(err.to_string().contains("dimension")); +} + +#[test] +fn rejects_non_f32_child_storage_at_construction() { + // Options are valid and target f32 output. Pass a Vector<128> whose storage is f16 instead + // of f32 -- SorfTransform's f32-only input constraint should reject this. + let options = default_options(128, 42); + let elements = PrimitiveArray::from_iter([half::f16::from_f32(0.0); 128]).into_array(); + let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) + .expect("test child should be valid"); + let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); + + let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + .expect_err("non-f32 Vector storage should be rejected at construction time"); + assert!(err.to_string().contains("f32")); +} + +#[test] +fn f16_output_type() -> VortexResult<()> { + let dim = 128; + let num_rows = 3; + let seed = 42u64; + let (_, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; + + let options = SorfOptions { + seed, + num_rounds: 3, + dimension: dim as u32, + element_ptype: PType::F16, + }; + let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; + + assert_eq!(result_prim.ptype(), PType::F16); + assert_eq!(result_prim.as_slice::().len(), num_rows * dim); + + Ok(()) +} + +#[test] +fn f64_output_type() -> VortexResult<()> { + let dim = 128; + let num_rows = 3; + let seed = 42u64; + let (_, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; + + let options = SorfOptions { + seed, + num_rounds: 3, + dimension: dim as u32, + element_ptype: PType::F64, + }; + let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; + let mut ctx = SESSION.create_execution_ctx(); + let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; + let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; + let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; + + assert_eq!(result_prim.ptype(), PType::F64); + assert_eq!(result_prim.as_slice::().len(), num_rows * dim); + + Ok(()) +} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs new file mode 100644 index 00000000000..69ae671bc85 --- /dev/null +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! [`ScalarFnVTable`] implementation for [`SorfTransform`]. + +use std::fmt; +use std::fmt::Formatter; +use std::sync::Arc; + +use num_traits::Float; +use num_traits::FromPrimitive; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::expr::Expression; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; + +use super::SorfOptions; +use super::SorfTransform; +use super::rotation::SorfMatrix; +use super::validate_sorf_options; +use crate::vector::AnyVector; +use crate::vector::Vector; + +impl ScalarFnVTable for SorfTransform { + type Options = SorfOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.tensor.sorf_transform") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("rotated"), + _ => unreachable!("SorfTransform must have exactly one child"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "sorf_transform(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + validate_sorf_options(options)?; + + let child_dtype = &arg_dtypes[0]; + let vector_metadata = child_dtype + .as_extension_opt() + .and_then(|ext| ext.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("SorfTransform child must be a Vector extension, got {child_dtype}") + })?; + + let expected_padded = options.dimension.next_power_of_two(); + vortex_ensure_eq!( + vector_metadata.dimensions(), + expected_padded, + "SorfTransform child Vector must have dimension {expected_padded} (next power of two \ + for dimension {})", + options.dimension, + ); + + // For now, the child Vector storage must be f32. TurboQuant stores its centroids as f32, + // and the SORF transform itself operates in f32, so any other input type would require an + // implicit cast that we do not yet support. The output element type is independently + // specified via `options.element_ptype` and is built below. + vortex_ensure_eq!( + vector_metadata.element_ptype(), + PType::F32, + "SorfTransform child Vector storage must be f32 (for now), got {}", + vector_metadata.element_ptype(), + ); + + let output_elem_dtype = DType::Primitive(options.element_ptype, Nullability::NonNullable); + let storage_dtype = DType::FixedSizeList( + Arc::new(output_elem_dtype), + options.dimension, + child_dtype.nullability(), + ); + + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); + Ok(DType::Extension(ext_dtype)) + } + + fn execute( + &self, + options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + validate_sorf_options(options)?; + let dim = options.dimension as usize; + let num_rows = args.row_count(); + + if num_rows == 0 { + let child_nullability = args.get(0)?.dtype().nullability(); + let validity = Validity::from(child_nullability); + + return match_each_float_ptype!(options.element_ptype, |T| { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + options.dimension, + validity, + 0, + )?; + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }); + } + + // Execute the child to get the Vector extension wrapping an FSL of f32 coordinates. The + // `return_dtype` check guarantees the child is a `Vector`, so the + // materialized FSL elements are always f32. + let child_ext: ExtensionArray = args.get(0)?.execute(ctx)?; + let child_validity = child_ext.as_ref().validity()?; + let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; + let padded_dim = + usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); + + let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; + let f32_elements = elements_prim.into_buffer::(); + + // Reconstruct the orthogonal transform matrix from the seed. + let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; + + // Inverse transform each row, truncate to original dimension, cast to target type. + match_each_float_ptype!(options.element_ptype, |T| { + inverse_rotate_typed::( + &f32_elements, + &rotation, + dim, + padded_dim, + num_rows, + child_validity, + ) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Convert an f32 value to a float type `T`. +/// +/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the +/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening. +fn float_from_f32(v: f32) -> T { + FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible") +} + +/// Apply the inverse SORF transform on f32 data, truncate to the original dimension, cast each +/// element to `T`, and build the output [`Vector`] extension array. +fn inverse_rotate_typed( + f32_elements: &[f32], + rotation: &SorfMatrix, + dim: usize, + padded_dim: usize, + num_rows: usize, + validity: Validity, +) -> VortexResult { + let dim_u32 = u32::try_from(dim).vortex_expect("dimension fits u32"); + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_data = &f32_elements[row * padded_dim..(row + 1) * padded_dim]; + + rotation.inverse_rotate(row_data, &mut unrotated); + + for idx in 0..dim { + // SAFETY: We allocated enough memory above. + unsafe { output.push_unchecked(float_from_f32::(unrotated[idx])) }; + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; + + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) +} diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 4d78597c962..325a530ee46 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -13,8 +13,10 @@ use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::PType; +use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; @@ -41,6 +43,40 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult`. +/// +/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively +/// in f32. This function handles the cast from any float ptype: +/// +/// - f16: losslessly widened to f32. +/// - f32: zero-copy buffer extraction. +/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. This is +/// acceptable because callers of this function operate in f32 and document this constraint. +pub fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { + match prim.ptype() { + PType::F16 => Ok(prim + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(prim.into_buffer()), + PType::F64 => Ok(prim + .as_slice::() + .iter() + .map(|&v| { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 values outside f32 range become infinity, which is acceptable \ + because callers operate in f32 and document this constraint" + )] + let v = v as f32; + v + }) + .collect()), + other => vortex_bail!("expected float elements, got {other:?}"), + } +} + /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 091b9a6747f..564d3dd35cd 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -451,7 +451,6 @@ mod turboquant_benches { use vortex_buffer::BufferMut; use vortex_tensor::encodings::turboquant::TurboQuantConfig; use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked; - use vortex_tensor::scalar_fns::ApproxOptions; use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; use vortex_tensor::vector::Vector; @@ -499,7 +498,7 @@ mod turboquant_benches { fn setup_normalized_vector_ext(dim: usize) -> ExtensionArray { let ext = setup_vector_ext(dim); let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(&ApproxOptions::Exact, ext.into_array(), &mut ctx) + let normalized = normalize_as_l2_denorm(ext.into_array(), &mut ctx) .unwrap() .child_at(0) .clone();