Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions vortex-tensor/src/scalar_fns/l2_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use prost::Message;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
Expand All @@ -26,6 +28,7 @@ use vortex_array::dtype::Nullability;
use vortex_array::dtype::proto::dtype as pb;
use vortex_array::expr::Expression;
use vortex_array::match_each_float_ptype;
use vortex_array::scalar::Scalar;
use vortex_array::scalar_fn::Arity;
use vortex_array::scalar_fn::ChildName;
use vortex_array::scalar_fn::EmptyOptions;
Expand Down Expand Up @@ -131,6 +134,8 @@ impl ScalarFnVTable for L2Norm {
let tensor_flat_size = tensor_match.list_size();
let element_ptype = tensor_match.element_ptype();

let norm_dtype = DType::Primitive(element_ptype, ext.nullability());

// L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored
// norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics
// instead of forcing a decode-and-recompute path here.
Expand All @@ -139,14 +144,37 @@ impl ScalarFnVTable for L2Norm {
.nth_child(1)
.vortex_expect("L2Denom must have at 2 children");

vortex_ensure_eq!(
norms.dtype(),
&DType::Primitive(element_ptype, input_ref.dtype().nullability())
);
vortex_ensure_eq!(norms.dtype(), &norm_dtype);

return Ok(norms);
}

// Optimize for the constant array case.
if let Some(array) = input_ref.as_opt::<Constant>() {
let scalar = array.scalar().as_extension().to_storage_scalar();

let Some(elements) = scalar.as_list().elements() else {
return Ok(ConstantArray::new(Scalar::null(norm_dtype), row_count).into_array());
};

let norm_scalar = match_each_float_ptype!(element_ptype, |T| {
let values: Vec<T> = elements
.iter()
.map(|s| {
s.as_primitive()
.as_::<T>()
.vortex_expect("element was somehow not the correct float")
})
.collect();
let norm = l2_norm_row::<T>(&values);

Scalar::try_new(norm_dtype, Some(norm.into()))
})?;

let norms = ConstantArray::new(norm_scalar, row_count).into_array();
return Ok(norms);
}

let input: ExtensionArray = input_ref.execute(ctx)?;
let validity = input.as_ref().validity()?;

Expand Down Expand Up @@ -244,10 +272,18 @@ mod tests {
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::MaskedArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::dtype::PType;
use vortex_array::dtype::extension::ExtDType;
use vortex_array::extension::EmptyMetadata;
use vortex_array::scalar::Scalar;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;

Expand All @@ -256,6 +292,7 @@ mod tests {
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;
use crate::vector::Vector;

/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
Expand Down Expand Up @@ -326,6 +363,76 @@ mod tests {
Ok(())
}

/// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a
/// fixed-size list of `elements`, broadcast to `len` rows.
fn constant_vector_extension_array(elements: &[f64], len: usize) -> ArrayRef {
let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
let children: Vec<Scalar> = elements
.iter()
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
.collect();
let storage_scalar =
Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
let ext_scalar = Scalar::extension::<Vector>(EmptyMetadata, storage_scalar);
ConstantArray::new(ext_scalar, len).into_array()
}

/// A constant input whose scalar is a non-null tensor should short-circuit to a
/// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so
/// execution stops at the [`Constant`] encoding instead of canonicalizing into a
/// [`PrimitiveArray`].
#[test]
fn constant_non_null_input_yields_constant_output() -> VortexResult<()> {
let input = constant_vector_extension_array(&[3.0, 4.0], 4);

let scalar_fn = L2Norm::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array();
let mut ctx = SESSION.create_execution_ctx();
let output = result.execute_until::<Constant>(&mut ctx)?;

let constant = output
.as_opt::<Constant>()
.expect("L2Norm over a constant input must produce a constant output");
assert_eq!(constant.len(), 4);
let norm = constant
.scalar()
.as_primitive()
.as_::<f64>()
.expect("norm scalar must be a non-null primitive");
assert_close(&[norm], &[5.0]);
Ok(())
}

/// A constant input whose scalar is null should short-circuit to a null [`ConstantArray`] of
/// the correct primitive dtype and length.
#[test]
fn constant_null_input_yields_null_constant_output() -> VortexResult<()> {
let storage_dtype = DType::FixedSizeList(
DType::Primitive(PType::F64, Nullability::NonNullable).into(),
2,
Nullability::Nullable,
);
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage_dtype)?.erased();
let null_scalar = Scalar::null(DType::Extension(ext_dtype));
let input = ConstantArray::new(null_scalar, 3).into_array();

let scalar_fn = L2Norm::new().erased();
let result = ScalarFnArray::try_new(scalar_fn, vec![input], 3)?.into_array();
let mut ctx = SESSION.create_execution_ctx();
let output = result.execute_until::<Constant>(&mut ctx)?;

let constant = output
.as_opt::<Constant>()
.expect("null constant input must produce a constant output");
assert_eq!(constant.len(), 3);
assert!(constant.scalar().is_null());
assert_eq!(
constant.dtype(),
&DType::Primitive(PType::F64, Nullability::Nullable)
);
Ok(())
}

#[rstest]
#[case::fixed_shape_tensor(tensor_array(&[3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)]
#[case::vector(vector_array(3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(), 2)]
Expand Down
Loading