Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -588,4 +588,6 @@ pub fn vortex_tensor::vector_search::build_similarity_search_tree<T: vortex_arra

pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>

pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str

pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession)
26 changes: 20 additions & 6 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,39 @@ pub mod vector_search;

mod utils;

/// Environment variable that gates registration of the tensor scalar-fn array plugins (the array
/// encodings that let [`CosineSimilarity`], [`InnerProduct`], [`L2Denorm`], [`L2Norm`], and
/// [`SorfTransform`] persist in a Vortex file). When unset, only the scalar functions themselves
/// are registered; readers of files containing serialized tensor scalar-fn arrays will fail to
/// deserialize. Opt-in by setting the variable to any non-empty value.
pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_PLUGIN";

/// Initialize the Vortex tensor library with a Vortex session.
pub fn initialize(session: &VortexSession) {
session.dtypes().register(Vector);
session.dtypes().register(FixedShapeTensor);

let session_fns = session.scalar_fns();
let session_arrays = session.arrays();

session_fns.register(CosineSimilarity);
session_fns.register(InnerProduct);
session_fns.register(L2Denorm);
session_fns.register(L2Norm);
session_fns.register(SorfTransform);

session_arrays.register(ScalarFnArrayPlugin::new(CosineSimilarity));
session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct));
session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm));
session_arrays.register(ScalarFnArrayPlugin::new(L2Norm));
session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform));
// Registering the scalar-fn array plugins lets the tensor scalar fns be serialized as array
// encodings inside Vortex files. Gate this on an env var so applications that do not intend
// to persist these encodings do not pay the registry cost or widen their stable-encoding
// surface unintentionally.
if std::env::var_os(SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV).is_some_and(|v| !v.is_empty()) {
let session_arrays = session.arrays();

session_arrays.register(ScalarFnArrayPlugin::new(CosineSimilarity));
session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct));
session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm));
session_arrays.register(ScalarFnArrayPlugin::new(L2Norm));
session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform));
}
}

#[cfg(test)]
Expand Down
Loading