Skip to content

Commit 61ae201

Browse files
committed
support serializin tensor scalar fns
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent c690c2c commit 61ae201

8 files changed

Lines changed: 629 additions & 15 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::Cosine
248248

249249
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
250250

251+
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
252+
253+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
254+
255+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
256+
251257
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity
252258

253259
pub type vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::Options = vortex_array::scalar_fn::vtable::EmptyOptions
@@ -284,6 +290,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProdu
284290

285291
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> vortex_tensor::scalar_fns::inner_product::InnerProduct
286292

293+
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct
294+
295+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
296+
297+
pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
298+
287299
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct
288300

289301
pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_array::scalar_fn::vtable::EmptyOptions
@@ -322,6 +334,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_denorm::L2Denorm
322334

323335
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::clone(&self) -> vortex_tensor::scalar_fns::l2_denorm::L2Denorm
324336

337+
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm
338+
339+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
340+
341+
pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
342+
325343
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_denorm::L2Denorm
326344

327345
pub type vortex_tensor::scalar_fns::l2_denorm::L2Denorm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
@@ -362,6 +380,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm
362380

363381
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm
364382

383+
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm
384+
385+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::deserialize(&self, _dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
386+
387+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
388+
365389
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::l2_norm::L2Norm
366390

367391
pub type vortex_tensor::scalar_fns::l2_norm::L2Norm::Options = vortex_array::scalar_fn::vtable::EmptyOptions
@@ -444,6 +468,12 @@ impl core::clone::Clone for vortex_tensor::scalar_fns::sorf_transform::SorfTrans
444468

445469
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::clone(&self) -> vortex_tensor::scalar_fns::sorf_transform::SorfTransform
446470

471+
impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform
472+
473+
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts<Self>>
474+
475+
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::serialize(&self, view: &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
476+
447477
impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::sorf_transform::SorfTransform
448478

449479
pub type vortex_tensor::scalar_fns::sorf_transform::SorfTransform::Options = vortex_tensor::scalar_fns::sorf_transform::SorfOptions

vortex-tensor/src/lib.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
//! including unit vectors, spherical coordinates, and similarity measures such as cosine
66
//! similarity.
77
8+
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
89
use vortex_array::dtype::session::DTypeSessionExt;
910
use vortex_array::scalar_fn::session::ScalarFnSessionExt;
11+
use vortex_array::session::ArraySessionExt;
1012
use vortex_session::VortexSession;
1113

1214
use crate::fixed_shape::FixedShapeTensor;
@@ -34,9 +36,18 @@ pub fn initialize(session: &VortexSession) {
3436
session.dtypes().register(Vector);
3537
session.dtypes().register(FixedShapeTensor);
3638

37-
session.scalar_fns().register(CosineSimilarity);
38-
session.scalar_fns().register(InnerProduct);
39-
session.scalar_fns().register(L2Denorm);
40-
session.scalar_fns().register(L2Norm);
41-
session.scalar_fns().register(SorfTransform);
39+
let session_fns = session.scalar_fns();
40+
let session_arrays = session.arrays();
41+
42+
session_fns.register(CosineSimilarity);
43+
session_fns.register(InnerProduct);
44+
session_fns.register(L2Denorm);
45+
session_fns.register(L2Norm);
46+
session_fns.register(SorfTransform);
47+
48+
session_arrays.register(ScalarFnArrayPlugin::new(CosineSimilarity));
49+
session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct));
50+
session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm));
51+
session_arrays.register(ScalarFnArrayPlugin::new(L2Norm));
52+
session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform));
4253
}

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ use vortex_array::IntoArray;
1212
use vortex_array::arrays::PrimitiveArray;
1313
use vortex_array::arrays::ScalarFnArray;
1414
use vortex_array::arrays::scalar_fn::ExactScalarFn;
15+
use vortex_array::arrays::scalar_fn::ScalarFnArrayView;
16+
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts;
17+
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable;
1518
use vortex_array::builtins::ArrayBuiltins;
1619
use vortex_array::dtype::DType;
1720
use vortex_array::dtype::Nullability;
@@ -25,11 +28,14 @@ use vortex_array::scalar_fn::ExecutionArgs;
2528
use vortex_array::scalar_fn::ScalarFn;
2629
use vortex_array::scalar_fn::ScalarFnId;
2730
use vortex_array::scalar_fn::ScalarFnVTable;
31+
use vortex_array::serde::ArrayChildren;
2832
use vortex_array::validity::Validity;
2933
use vortex_buffer::Buffer;
3034
use vortex_error::VortexResult;
3135
use vortex_error::vortex_ensure;
36+
use vortex_session::VortexSession;
3237

38+
use crate::scalar_fns::inner_product::BinaryTensorOpMetadata;
3339
use crate::scalar_fns::inner_product::InnerProduct;
3440
use crate::scalar_fns::l2_denorm::L2Denorm;
3541
use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm;
@@ -221,6 +227,37 @@ impl ScalarFnVTable for CosineSimilarity {
221227
}
222228
}
223229

230+
impl ScalarFnArrayVTable for CosineSimilarity {
231+
fn serialize(
232+
&self,
233+
view: &ScalarFnArrayView<Self>,
234+
_session: &VortexSession,
235+
) -> VortexResult<Option<Vec<u8>>> {
236+
Ok(Some(BinaryTensorOpMetadata::encode_from_view(view)?))
237+
}
238+
239+
fn deserialize(
240+
&self,
241+
_dtype: &DType,
242+
len: usize,
243+
metadata: &[u8],
244+
children: &dyn ArrayChildren,
245+
session: &VortexSession,
246+
) -> VortexResult<ScalarFnArrayParts<Self>> {
247+
let reconstructed = BinaryTensorOpMetadata::decode_children(
248+
metadata,
249+
len,
250+
children,
251+
session,
252+
"CosineSimilarity",
253+
)?;
254+
Ok(ScalarFnArrayParts {
255+
options: EmptyOptions,
256+
children: reconstructed,
257+
})
258+
}
259+
}
260+
224261
impl CosineSimilarity {
225262
/// Both sides are `L2Denorm`: treat the normalized children as authoritative, so
226263
/// `cosine_similarity = dot(n_l, n_r)`.
@@ -295,12 +332,14 @@ mod tests {
295332
use std::sync::LazyLock;
296333

297334
use rstest::rstest;
335+
use vortex_array::ArrayPlugin;
298336
use vortex_array::ArrayRef;
299337
use vortex_array::IntoArray;
300338
use vortex_array::VortexSessionExecute;
301339
use vortex_array::arrays::MaskedArray;
302340
use vortex_array::arrays::PrimitiveArray;
303341
use vortex_array::arrays::ScalarFnArray;
342+
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
304343
use vortex_array::session::ArraySession;
305344
use vortex_array::validity::Validity;
306345
use vortex_error::VortexResult;
@@ -314,8 +353,11 @@ mod tests {
314353
use crate::utils::test_helpers::tensor_array;
315354
use crate::utils::test_helpers::vector_array;
316355

317-
static SESSION: LazyLock<VortexSession> =
318-
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
356+
static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
357+
let session = VortexSession::empty().with::<ArraySession>();
358+
crate::initialize(&session);
359+
session
360+
});
319361

320362
/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
321363
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
@@ -693,4 +735,43 @@ mod tests {
693735
);
694736
Ok(())
695737
}
738+
739+
#[rstest]
740+
#[case::vector(
741+
vector_array(3, &[1.0, 0.0, 0.0, 3.0, 4.0, 0.0]).unwrap(),
742+
vector_array(3, &[0.0, 1.0, 0.0, 3.0, 4.0, 0.0]).unwrap(),
743+
2,
744+
)]
745+
#[case::fixed_shape_tensor(
746+
tensor_array(&[2], &[1.0, 0.0, 3.0, 4.0]).unwrap(),
747+
tensor_array(&[2], &[0.0, 1.0, 3.0, 4.0]).unwrap(),
748+
2,
749+
)]
750+
fn serde_round_trip(
751+
#[case] lhs: ArrayRef,
752+
#[case] rhs: ArrayRef,
753+
#[case] len: usize,
754+
) -> VortexResult<()> {
755+
let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array();
756+
757+
let plugin = ScalarFnArrayPlugin::new(CosineSimilarity);
758+
let metadata = plugin
759+
.serialize(&original, &SESSION)?
760+
.expect("CosineSimilarity serialize must produce metadata");
761+
762+
let children = vec![lhs, rhs];
763+
let recovered = plugin.deserialize(
764+
original.dtype(),
765+
original.len(),
766+
&metadata,
767+
&[],
768+
&children,
769+
&SESSION,
770+
)?;
771+
772+
assert_eq!(recovered.dtype(), original.dtype());
773+
assert_eq!(recovered.len(), original.len());
774+
assert_eq!(recovered.encoding_id(), original.encoding_id());
775+
Ok(())
776+
}
696777
}

0 commit comments

Comments
 (0)