From 61ae20199132581e3e60a0ead5e2d9245ead4ccd Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 14 Apr 2026 17:45:35 -0400 Subject: [PATCH 1/2] support serializing tensor scalar fns Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 30 ++++ vortex-tensor/src/lib.rs | 21 ++- .../src/scalar_fns/cosine_similarity.rs | 85 +++++++++- vortex-tensor/src/scalar_fns/inner_product.rs | 148 +++++++++++++++++- vortex-tensor/src/scalar_fns/l2_denorm.rs | 112 ++++++++++++- vortex-tensor/src/scalar_fns/l2_norm.rs | 89 ++++++++++- .../src/scalar_fns/sorf_transform/tests.rs | 66 +++++++- .../src/scalar_fns/sorf_transform/vtable.rs | 93 +++++++++++ 8 files changed, 629 insertions(+), 15 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 4f6c2dddbfb..01ff0a6f477 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -248,6 +248,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::Cosine pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> + impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity pub type 
vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions @@ -284,6 +290,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProdu pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> vortex_tensor::scalar_fns::inner_product::InnerProduct +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> + impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_array::scalar_fn::vtable::EmptyOptions @@ -322,6 +334,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_tensor::scalar_fns::l2_denorm::L2Denorm +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm + +pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: 
&vortex_session::VortexSession) -> vortex_error::VortexResult>> + impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_array::scalar_fn::vtable::EmptyOptions @@ -362,6 +380,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> + impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions @@ -444,6 +468,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::sorf_transform::SorfTrans pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> vortex_tensor::scalar_fns::sorf_transform::SorfTransform +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform + +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn 
vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult>> + impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform pub type vortex_tensor::scalar_fns::sorf_transform::SorfTransform::Options = vortex_tensor::scalar_fns::sorf_transform::SorfOptions diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index b3cf6c21695..171a0d817f0 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,10 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; +use vortex_array::session::ArraySessionExt; use vortex_session::VortexSession; use crate::fixed_shape::FixedShapeTensor; @@ -34,9 +36,18 @@ pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); session.dtypes().register(FixedShapeTensor); - session.scalar_fns().register(CosineSimilarity); - session.scalar_fns().register(InnerProduct); - session.scalar_fns().register(L2Denorm); - session.scalar_fns().register(L2Norm); - session.scalar_fns().register(SorfTransform); + let session_fns = session.scalar_fns(); + let session_arrays = session.arrays(); + + session_fns.register(CosineSimilarity); + session_fns.register(InnerProduct); + session_fns.register(L2Denorm); + session_fns.register(L2Norm); + session_fns.register(SorfTransform); + + session_arrays.register(ScalarFnArrayPlugin::new(CosineSimilarity)); + session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct)); + session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm)); + session_arrays.register(ScalarFnArrayPlugin::new(L2Norm)); + 
session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform)); } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 2a717b90b7b..12e06701418 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -12,6 +12,9 @@ use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; @@ -25,11 +28,14 @@ use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; +use vortex_session::VortexSession; +use crate::scalar_fns::inner_product::BinaryTensorOpMetadata; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; @@ -221,6 +227,37 @@ impl ScalarFnVTable for CosineSimilarity { } } +impl ScalarFnArrayVTable for CosineSimilarity { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + Ok(Some(BinaryTensorOpMetadata::encode_from_view(view)?)) + } + + fn deserialize( + &self, + _dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + session: &VortexSession, + ) -> VortexResult> { + let reconstructed = BinaryTensorOpMetadata::decode_children( + metadata, + len, + children, + session, + 
"CosineSimilarity", + )?; + Ok(ScalarFnArrayParts { + options: EmptyOptions, + children: reconstructed, + }) + } +} + impl CosineSimilarity { /// Both sides are `L2Denorm`: treat the normalized children as authoritative, so /// `cosine_similarity = dot(n_l, n_r)`. @@ -295,12 +332,14 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; + use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; + use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; @@ -314,8 +353,11 @@ mod tests { use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + }); /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { @@ -693,4 +735,43 @@ mod tests { ); Ok(()) } + + #[rstest] + #[case::vector( + vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).unwrap(), + 2, + )] + fn serde_round_trip( + #[case] lhs: ArrayRef, + #[case] rhs: ArrayRef, + #[case] len: usize, + ) -> VortexResult<()> { + let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(CosineSimilarity); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("CosineSimilarity serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } } diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 0d6b2bf76aa..3d856921f91 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -7,6 +7,7 @@ use std::fmt::Formatter; use std::sync::Arc; use num_traits::Float; +use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -18,15 +19,21 @@ use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeList; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::dict::DictArraySlotsExt; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; use vortex_array::extension::EmptyMetadata; @@ -39,12 +46,14 @@ use vortex_array::scalar_fn::ExecutionArgs; use 
vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::serde::ArrayChildren; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::L2Denorm; @@ -245,6 +254,97 @@ impl ScalarFnVTable for InnerProduct { } } +/// Metadata for a serialized binary tensor-op array (shared by [`InnerProduct`] and +/// [`CosineSimilarity`]). Both operands share the same extension dtype up to nullability +/// (enforced by their `return_dtype` checks), but their individual nullabilities are lost in the +/// parent's unioned output, so both are persisted. +/// +/// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity +#[derive(Clone, prost::Message)] +pub(crate) struct BinaryTensorOpMetadata { + #[prost(message, optional, tag = "1")] + pub(crate) lhs_dtype: Option, + #[prost(message, optional, tag = "2")] + pub(crate) rhs_dtype: Option, +} + +impl BinaryTensorOpMetadata { + /// Encodes the two children of `view` into a [`BinaryTensorOpMetadata`] byte blob. + pub(crate) fn encode_from_view( + view: &ScalarFnArrayView, + ) -> VortexResult> { + let scalar_fn_array = view.as_::(); + let lhs_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); + let rhs_dtype = Some(scalar_fn_array.child_at(1).dtype().try_into()?); + Ok(Self { + lhs_dtype, + rhs_dtype, + } + .encode_to_vec()) + } + + /// Decodes `metadata` and fetches both children from `children` using the decoded dtypes, + /// validating that `lhs` and `rhs` agree modulo nullability. 
+ pub(crate) fn decode_children( + metadata: &[u8], + len: usize, + children: &dyn ArrayChildren, + session: &VortexSession, + scalar_fn_name: &str, + ) -> VortexResult> { + let metadata = Self::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode BinaryTensorOpMetadata: {e}"))?; + let lhs_pb = metadata + .lhs_dtype + .as_ref() + .ok_or_else(|| vortex_err!("{scalar_fn_name} metadata missing lhs_dtype"))?; + let rhs_pb = metadata + .rhs_dtype + .as_ref() + .ok_or_else(|| vortex_err!("{scalar_fn_name} metadata missing rhs_dtype"))?; + let lhs_dtype = DType::from_proto(lhs_pb, session)?; + let rhs_dtype = DType::from_proto(rhs_pb, session)?; + vortex_ensure!( + lhs_dtype.eq_ignore_nullability(&rhs_dtype), + "{scalar_fn_name} operand dtype mismatch: {lhs_dtype} vs {rhs_dtype}" + ); + let lhs = children.get(0, &lhs_dtype, len)?; + let rhs = children.get(1, &rhs_dtype, len)?; + Ok(vec![lhs, rhs]) + } +} + +impl ScalarFnArrayVTable for InnerProduct { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + Ok(Some(BinaryTensorOpMetadata::encode_from_view(view)?)) + } + + fn deserialize( + &self, + _dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + session: &VortexSession, + ) -> VortexResult> { + let reconstructed = BinaryTensorOpMetadata::decode_children( + metadata, + len, + children, + session, + "InnerProduct", + )?; + Ok(ScalarFnArrayParts { + options: EmptyOptions, + children: reconstructed, + }) + } +} + impl InnerProduct { /// Both sides are `L2Denorm`: `inner_product = s_l * s_r * dot(n_l, n_r)`. 
fn execute_both_denorm( @@ -595,12 +695,14 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; + use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; + use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; @@ -612,8 +714,11 @@ mod tests { use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + }); /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { @@ -810,6 +915,45 @@ mod tests { Ok(()) } + #[rstest] + #[case::vector( + vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), + vector_array(3, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap(), + 2, + )] + #[case::fixed_shape_tensor( + tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0]).unwrap(), + tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).unwrap(), + 2, + )] + fn serde_round_trip( + #[case] lhs: ArrayRef, + #[case] rhs: ArrayRef, + #[case] len: usize, + ) -> VortexResult<()> { + let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(InnerProduct); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("InnerProduct serialize must produce metadata"); + + let children = vec![lhs, rhs]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } + // ---- Tests for the `SorfTransform + constant` and `Dict + constant` fast paths ---- #[allow( diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 181322855cf..f907a2b7c06 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -8,6 +8,7 @@ use std::fmt::Formatter; use num_traits::Float; use num_traits::ToPrimitive; use num_traits::Zero; +use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -18,13 +19,19 @@ use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; use vortex_array::match_each_float_ptype; @@ -38,6 +45,7 @@ use vortex_array::scalar_fn::ScalarFn; use 
vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::scalar_fn::fns::operators::Operator; +use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; @@ -46,6 +54,8 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_norm::L2Norm; @@ -272,6 +282,64 @@ impl ScalarFnVTable for L2Denorm { } } +/// Metadata for a serialized [`L2Denorm`] array: both children's full [`DType`]s. The parent's +/// dtype is `normalized.union_nullability(norms.nullability())`, which loses both children's +/// individual nullabilities, so we persist them directly. +#[derive(Clone, prost::Message)] +pub(super) struct L2DenormMetadata { + #[prost(message, optional, tag = "1")] + normalized_dtype: Option, + #[prost(message, optional, tag = "2")] + norms_dtype: Option, +} + +impl ScalarFnArrayVTable for L2Denorm { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + let scalar_fn_array = view.as_::(); + let normalized_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); + let norms_dtype = Some(scalar_fn_array.child_at(1).dtype().try_into()?); + Ok(Some( + L2DenormMetadata { + normalized_dtype, + norms_dtype, + } + .encode_to_vec(), + )) + } + + fn deserialize( + &self, + _dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + session: &VortexSession, + ) -> VortexResult> { + let metadata = L2DenormMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode L2DenormMetadata: {e}"))?; + let normalized_pb = metadata + .normalized_dtype + .as_ref() + .ok_or_else(|| vortex_err!("L2DenormMetadata missing normalized_dtype"))?; + let norms_pb = metadata + .norms_dtype + 
.as_ref() + .ok_or_else(|| vortex_err!("L2DenormMetadata missing norms_dtype"))?; + let normalized_dtype = DType::from_proto(normalized_pb, session)?; + let norms_dtype = DType::from_proto(norms_pb, session)?; + let normalized = children.get(0, &normalized_dtype, len)?; + let norms = children.get(1, &norms_dtype, len)?; + Ok(ScalarFnArrayParts { + options: EmptyOptions, + children: vec![normalized, norms], + }) + } +} + /// Optimized execution when the norms array is constant. fn execute_l2_denorm_constant_norms( normalized_ref: ArrayRef, @@ -633,6 +701,8 @@ fn validate_l2_normalized_rows_impl( mod tests { use std::sync::LazyLock; + use rstest::rstest; + use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; @@ -646,6 +716,7 @@ mod tests { use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; + use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; @@ -671,8 +742,11 @@ mod tests { use crate::utils::test_helpers::vector_array; use crate::vector::Vector; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + }); /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { @@ -1032,4 +1106,38 @@ mod tests { assert_tensor_arrays_eq(actual, expected)?; Ok(()) } + + /// Build an `L2Denorm` array from a raw input (which may have nullable storage) by running + /// `normalize_as_l2_denorm`. 
The normalized child ends up non-nullable, and the norms child + /// inherits the input's nullability, giving us two different per-child nullabilities to + /// round-trip. + #[rstest] + #[case::vector(vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).unwrap())] + #[case::fixed_shape_tensor(tensor_array(&[2, 2], &[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]).unwrap())] + fn serde_round_trip(#[case] input: ArrayRef) -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let original = normalize_as_l2_denorm(input, &mut ctx)?.into_array(); + + let scalar_fn_array = original.as_::(); + let children = scalar_fn_array.children(); + + let plugin = ScalarFnArrayPlugin::new(L2Denorm); + let metadata = plugin + .serialize(&original, &SESSION)? + .expect("L2Denorm serialize must produce metadata"); + + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } } diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59bd9d77009..c318a78eb06 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -6,17 +6,24 @@ use std::fmt::Formatter; use num_traits::Float; +use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; 
+use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; +use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; @@ -26,10 +33,13 @@ use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::serde::ArrayChildren; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::L2Denorm; @@ -172,6 +182,49 @@ impl ScalarFnVTable for L2Norm { } } +/// Metadata for a serialized [`L2Norm`] array: the single `input` child's [`DType`], which carries +/// the extension type (`FixedShapeTensor` vs `Vector`), dimension, and nullability that are not +/// recoverable from the parent's primitive-float output. 
+#[derive(Clone, prost::Message)] +pub(super) struct L2NormMetadata { + #[prost(message, optional, tag = "1")] + input_dtype: Option, +} + +impl ScalarFnArrayVTable for L2Norm { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + let scalar_fn_array = view.as_::(); + let input_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); + Ok(Some(L2NormMetadata { input_dtype }.encode_to_vec())) + } + + fn deserialize( + &self, + _dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + session: &VortexSession, + ) -> VortexResult> { + let metadata = L2NormMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode L2NormMetadata: {e}"))?; + let input_pb = metadata + .input_dtype + .as_ref() + .ok_or_else(|| vortex_err!("L2NormMetadata missing input_dtype"))?; + let input_dtype = DType::from_proto(input_pb, session)?; + let child = children.get(0, &input_dtype, len)?; + Ok(ScalarFnArrayParts { + options: EmptyOptions, + children: vec![child], + }) + } +} + /// Computes the L2 norm (Euclidean norm) of a float slice. /// /// Returns `sqrt(sum(v_i^2))`. A zero-length or all-zero input produces `0.0`. 
@@ -188,12 +241,14 @@ mod tests { use std::sync::LazyLock; use rstest::rstest; + use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::MaskedArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; + use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; @@ -204,8 +259,11 @@ mod tests { use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); + static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + }); /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { @@ -275,4 +333,31 @@ mod tests { assert_close(&[prim.as_slice::()[0]], &[5.0]); Ok(()) } + + #[rstest] + #[case::fixed_shape_tensor(tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] + #[case::vector(vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)] + fn serde_round_trip(#[case] child: ArrayRef, #[case] len: usize) -> VortexResult<()> { + let original = L2Norm::try_new_array(child.clone(), len)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(L2Norm); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("L2Norm serialize must produce metadata"); + + let children = vec![child]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) + } } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 99d97fbe87f..782c9f90cd5 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -8,6 +8,8 @@ use std::sync::Arc; use std::sync::LazyLock; +use vortex_array::ArrayPlugin; +use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::ExtensionArray; @@ -16,6 +18,7 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -23,7 +26,9 @@ use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_session::VortexSession; @@ -35,8 +40,11 @@ use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; use crate::vector::Vector; -static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); +static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + 
crate::initialize(&session); + session +}); /// Build a unit-normalized input vector array and forward-transform + quantize it, returning /// `(input_f32, Vector(FSL(Dict(codes, centroids))), padded_dim)`. @@ -436,3 +444,57 @@ fn f64_output_type() -> VortexResult<()> { Ok(()) } + +/// Build a trivial `Vector>` child populated with zeroes. The values +/// are irrelevant for the serde round-trip test; only the dtype shape matters. +fn trivial_padded_vector(padded_dim: u32, num_rows: usize, validity: Validity) -> ArrayRef { + let elements = PrimitiveArray::new( + Buffer::::zeroed(num_rows * padded_dim as usize), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, num_rows) + .vortex_expect("fsl must build"); + let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) + .vortex_expect("ext dtype must build") + .erased(); + ExtensionArray::new(ext_dtype, fsl.into_array()).into_array() +} + +#[rstest::rstest] +// Non-power-of-two dimension to exercise `padded_dim = dim.next_power_of_two()`. +#[case::power_of_two_dim(128, Validity::NonNullable)] +#[case::non_power_of_two_dim(100, Validity::NonNullable)] +// Nullable top-level Vector to verify child nullability is reconstructed from the parent output. +#[case::nullable_child(100, Validity::AllValid)] +fn serde_round_trip(#[case] dimension: u32, #[case] validity: Validity) -> VortexResult<()> { + let padded_dim = dimension.next_power_of_two(); + let num_rows = 4; + let options = SorfOptions { + seed: 42, + num_rounds: 3, + dimension, + element_ptype: PType::F32, + }; + let child = trivial_padded_vector(padded_dim, num_rows, validity); + let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); + + let plugin = ScalarFnArrayPlugin::new(SorfTransform); + let metadata = plugin + .serialize(&original, &SESSION)? 
+ .expect("SorfTransform serialize must produce metadata"); + + let children = vec![child]; + let recovered = plugin.deserialize( + original.dtype(), + original.len(), + &metadata, + &[], + &children, + &SESSION, + )?; + + assert_eq!(recovered.dtype(), original.dtype()); + assert_eq!(recovered.len(), original.len()); + assert_eq!(recovered.encoding_id(), original.encoding_id()); + Ok(()) +} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 5c77b48eb87..64f92da384f 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use num_traits::Float; use num_traits::FromPrimitive; +use prost::Message; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -17,6 +18,9 @@ use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; @@ -30,12 +34,14 @@ use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; +use vortex_session::VortexSession; use super::SorfOptions; use super::SorfTransform; @@ -188,6 +194,93 @@ impl ScalarFnVTable for SorfTransform { } } +/// Metadata for 
a serialized [`SorfTransform`] array. +/// +/// Stores the full [`SorfOptions`] inline. The child [`DType`] is not serialized because it is +/// fully determined by the options: the child is always a [`Vector`] extension wrapping +/// `FSL`. The child's nullability is recovered from the +/// parent output dtype at deserialize time, since `SorfTransform::return_dtype` propagates child +/// nullability into the output FSL (see `return_dtype` above). +#[derive(Clone, prost::Message)] +pub(super) struct SorfTransformMetadata { + #[prost(uint64, tag = "1")] + seed: u64, + /// Rust `u8` widened to `u32` for protobuf (no `u8` on the wire). + #[prost(uint32, tag = "2")] + num_rounds: u32, + #[prost(uint32, tag = "3")] + dimension: u32, + #[prost(enumeration = "PType", tag = "4")] + element_ptype: i32, +} + +impl ScalarFnArrayVTable for SorfTransform { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + let options = view.options; + Ok(Some( + SorfTransformMetadata { + seed: options.seed, + num_rounds: u32::from(options.num_rounds), + dimension: options.dimension, + element_ptype: options.element_ptype as i32, + } + .encode_to_vec(), + )) + } + + fn deserialize( + &self, + dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + _session: &VortexSession, + ) -> VortexResult> { + let metadata = SorfTransformMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; + let options = SorfOptions { + seed: metadata.seed, + num_rounds: u8::try_from(metadata.num_rounds).map_err(|_| { + vortex_err!( + "SorfTransform num_rounds {} does not fit in u8", + metadata.num_rounds + ) + })?, + dimension: metadata.dimension, + element_ptype: metadata.element_ptype(), + }; + validate_sorf_options(&options)?; + + // `return_dtype` sets the output FSL's nullability to the child's nullability (see + // `return_dtype` above), so we read the child nullability back from the 
parent dtype. + let child_nullability = dtype + .as_extension_opt() + .ok_or_else(|| { + vortex_err!("SorfTransform parent dtype must be a Vector extension, got {dtype}") + })? + .storage_dtype() + .nullability(); + let padded_dim = options.dimension.next_power_of_two(); + let child_storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + padded_dim, + child_nullability, + ); + let child_ext = ExtDType::::try_new(EmptyMetadata, child_storage)?.erased(); + let child_dtype = DType::Extension(child_ext); + let child = children.get(0, &child_dtype, len)?; + + Ok(ScalarFnArrayParts { + options, + children: vec![child], + }) + } +} + /// Convert an f32 value to a float type `T`. /// /// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the From 38e71419b50dc017769fb79246850da69a741b60 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 14 Apr 2026 21:37:23 -0400 Subject: [PATCH 2/2] single SESSION for tests Signed-off-by: Connor Tsui --- .../src/encodings/turboquant/tests/mod.rs | 8 +------- vortex-tensor/src/lib.rs | 14 ++++++++++++++ vortex-tensor/src/scalar_fns/cosine_similarity.rs | 10 +--------- vortex-tensor/src/scalar_fns/inner_product.rs | 10 +--------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 10 +--------- vortex-tensor/src/scalar_fns/l2_norm.rs | 10 +--------- .../src/scalar_fns/sorf_transform/tests.rs | 10 +--------- 7 files changed, 20 insertions(+), 52 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index 4c01affa27d..b111c6e28ba 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -8,8 +8,6 @@ mod nullable; mod roundtrip; mod structural; -use std::sync::LazyLock; - use rand::SeedableRng; use rand::rngs::StdRng; use rand_distr::Distribution; @@ -29,22 +27,18 @@ use 
vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; -use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_session::VortexSession; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::turboquant_encode_unchecked; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::tests::SESSION; use crate::vector::Vector; -static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - /// Create a FixedSizeListArray of random f32 vectors with the given validity. fn make_fsl_with_validity( num_rows: usize, diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 171a0d817f0..d2fb02a6c7f 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -51,3 +51,17 @@ pub fn initialize(session: &VortexSession) { session_arrays.register(ScalarFnArrayPlugin::new(L2Norm)); session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform)); } + +#[cfg(test)] +mod tests { + use std::sync::LazyLock; + + use vortex_array::session::ArraySession; + use vortex_session::VortexSession; + + pub static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + }); +} diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 12e06701418..6c06d165a0e 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -329,7 +329,6 @@ impl CosineSimilarity { #[cfg(test)] mod tests { - use std::sync::LazyLock; use rstest::rstest; use vortex_array::ArrayPlugin; @@ -340,25 +339,18 @@ mod tests { use 
vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::l2_denorm::L2Denorm; + use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - crate::initialize(&session); - session - }); - /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { let scalar_fn = CosineSimilarity::new().erased(); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 3d856921f91..3f4678626bd 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -692,7 +692,6 @@ fn execute_dict_constant_inner_product( #[cfg(test)] mod tests { - use std::sync::LazyLock; use rstest::rstest; use vortex_array::ArrayPlugin; @@ -703,23 +702,16 @@ mod tests { use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; + use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use 
crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - crate::initialize(&session); - session - }); - /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { let scalar_fn = InnerProduct::new().erased(); diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index f907a2b7c06..01164a71dfe 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -699,7 +699,6 @@ fn validate_l2_normalized_rows_impl( #[cfg(test)] mod tests { - use std::sync::LazyLock; use rstest::rstest; use vortex_array::ArrayPlugin; @@ -724,17 +723,16 @@ mod tests { use vortex_array::extension::datetime::Date; use vortex_array::extension::datetime::TimeUnit; use vortex_array::scalar::Scalar; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; + use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; use crate::utils::test_helpers::constant_vector_array; @@ -742,12 +740,6 @@ mod tests { use crate::utils::test_helpers::vector_array; use crate::vector::Vector; - static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - crate::initialize(&session); - session - }); - /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. 
fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index c318a78eb06..13245dd880b 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -238,7 +238,6 @@ fn l2_norm_row(v: &[T]) -> T { #[cfg(test)] mod tests { - use std::sync::LazyLock; use rstest::rstest; use vortex_array::ArrayPlugin; @@ -249,22 +248,15 @@ mod tests { use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::l2_norm::L2Norm; + use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - crate::initialize(&session); - session - }); - /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. 
fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { let scalar_fn = L2Norm::new().erased(); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 782c9f90cd5..64308c39562 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -6,7 +6,6 @@ #![allow(clippy::cast_possible_truncation)] use std::sync::Arc; -use std::sync::LazyLock; use vortex_array::ArrayPlugin; use vortex_array::ArrayRef; @@ -24,13 +23,11 @@ use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; -use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_session::VortexSession; use super::SorfOptions; use super::SorfTransform; @@ -38,14 +35,9 @@ use super::rotation::SorfMatrix; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; +use crate::tests::SESSION; use crate::vector::Vector; -static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - crate::initialize(&session); - session -}); - /// Build a unit-normalized input vector array and forward-transform + quantize it, returning /// `(input_f32, Vector(FSL(Dict(codes, centroids))), padded_dim)`. ///