diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock
index 0c1ae57a82e..8a14e2204f0 100644
--- a/vortex-tensor/public-api.lock
+++ b/vortex-tensor/public-api.lock
@@ -412,7 +412,7 @@ pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _opt
 pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
-pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult
+pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult
 pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
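The `_options` to `options` rename above exists because `execute` now propagates the options into new fast paths. The substance of the patch: `CosineSimilarity` and `InnerProduct` learn to short-circuit when an input is an `L2Denorm` array, i.e. a vector column stored factored as a unit-length `normalized` child plus a per-row `norms` child. The key algebraic fact is that when both sides are pre-normalized, the norms cancel and cosine similarity collapses to a dot product of the normalized children. A minimal standalone sketch of that identity over `f64` slices (illustrative only; the `dot` helper is not part of the crate's API):

    /// Dot product of two equal-length `f64` slices.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn main() {
        // The L2Denorm factoring: v = 5.0 * [0.6, 0.8], w = 1.0 * [1.0, 0.0].
        let (n_v, s_v) = ([0.6, 0.8], 5.0);
        let (n_w, s_w) = ([1.0, 0.0], 1.0);

        // Reassemble the full vectors and compute cosine similarity the long way.
        let v: Vec<f64> = n_v.iter().map(|x| x * s_v).collect();
        let w: Vec<f64> = n_w.iter().map(|x| x * s_w).collect();
        let cos_full = dot(&v, &w) / (s_v * s_w);

        // Norms cancel: cos(v, w) = dot(n_v, n_w) when both children are unit length.
        assert!((cos_full - dot(&n_v, &n_w)).abs() < 1e-12);
    }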
diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs
index 391ca57bc19..39c8b135e43 100644
--- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs
+++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs
@@ -9,9 +9,10 @@ use num_traits::Zero;
 use vortex_array::ArrayRef;
 use vortex_array::ExecutionCtx;
 use vortex_array::IntoArray;
-use vortex_array::arrays::ExtensionArray;
 use vortex_array::arrays::PrimitiveArray;
 use vortex_array::arrays::ScalarFnArray;
+use vortex_array::arrays::scalar_fn::ExactScalarFn;
+use vortex_array::builtins::ArrayBuiltins;
 use vortex_array::dtype::DType;
 use vortex_array::dtype::Nullability;
 use vortex_array::expr::Expression;
@@ -23,13 +24,16 @@ use vortex_array::scalar_fn::ExecutionArgs;
 use vortex_array::scalar_fn::ScalarFn;
 use vortex_array::scalar_fn::ScalarFnId;
 use vortex_array::scalar_fn::ScalarFnVTable;
+use vortex_array::validity::Validity;
 use vortex_buffer::Buffer;
 use vortex_error::VortexResult;
 use vortex_error::vortex_ensure;
 
 use crate::scalar_fns::ApproxOptions;
 use crate::scalar_fns::inner_product::InnerProduct;
+use crate::scalar_fns::l2_denorm::L2Denorm;
 use crate::scalar_fns::l2_norm::L2Norm;
+use crate::utils::extract_l2_denorm_children;
 use crate::utils::validate_tensor_float_input;
 
 /// Cosine similarity between two columns.
@@ -126,35 +130,47 @@ impl ScalarFnVTable for CosineSimilarity {
         args: &dyn ExecutionArgs,
         ctx: &mut ExecutionCtx,
     ) -> VortexResult<ArrayRef> {
-        let lhs = args.get(0)?.execute::<ExtensionArray>(ctx)?;
-        let rhs = args.get(1)?.execute::<ExtensionArray>(ctx)?;
+        let mut lhs_ref = args.get(0)?;
+        let mut rhs_ref = args.get(1)?;
         let len = args.row_count();
 
-        // Compute combined validity.
-        let validity = lhs.as_ref().validity()?.and(rhs.as_ref().validity()?)?;
+        // Check whether any of our children have already been normalized.
+        {
+            let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
+            let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
+
+            if lhs_is_denorm && rhs_is_denorm {
+                return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
+            } else if lhs_is_denorm || rhs_is_denorm {
+                if rhs_is_denorm {
+                    (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
+                }
+                return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
+            }
+        }
 
-        let lhs = lhs.into_array();
-        let rhs = rhs.into_array();
+        // Compute combined validity.
+        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
 
         // Compute inner product and norms as columnar operations, and propagate the options.
-        let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?;
-        let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?;
-        let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?;
+        let norm_lhs_arr = L2Norm::try_new_array(options, lhs_ref.clone(), len)?;
+        let norm_rhs_arr = L2Norm::try_new_array(options, rhs_ref.clone(), len)?;
+        let dot_arr = InnerProduct::try_new_array(options, lhs_ref, rhs_ref, len)?;
 
-        // Execute to get PrimitiveArrays.
+        // Execute to get the inner product and norms of the arrays. We only fully decompress
+        // because we need to perform special logic (guard against 0) during division.
         let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
         let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?;
         let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?;
 
-        // Divide element-wise, guarding against zero norms.
+        // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
+        // TODO(connor): This can be written in a more SIMD-friendly manner.
         match_each_float_ptype!(dot.ptype(), |T| {
             let dots = dot.as_slice::<T>();
             let norms_l = norm_l.as_slice::<T>();
             let norms_r = norm_r.as_slice::<T>();
             let buffer: Buffer<T> = (0..len)
                 .map(|i| {
-                    // TODO(connor): Would it be better to make this a binary multiply?
-                    // What happens when this overflows???
                     let denom = norms_l[i] * norms_r[i];
 
                     if denom == T::zero() {
@@ -191,6 +207,74 @@ impl ScalarFnVTable for CosineSimilarity {
     }
 }
 
+impl CosineSimilarity {
+    /// Both sides are `L2Denorm`: norms cancel, so `cosine_similarity = dot(n_l, n_r)`.
+    fn execute_both_denorm(
+        &self,
+        options: &ApproxOptions,
+        lhs_ref: &ArrayRef,
+        rhs_ref: &ArrayRef,
+        len: usize,
+        _ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
+
+        let (normalized_l, _) = extract_l2_denorm_children(lhs_ref);
+        let (normalized_r, _) = extract_l2_denorm_children(rhs_ref);
+
+        // Dot product of already-normalized children IS the cosine similarity.
+        let dot =
+            InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?.into_array();
+
+        if !matches!(validity, Validity::NonNullable) {
+            // Masking always changes the nullability to nullable.
+            dot.mask(validity.to_array(len))
+        } else {
+            Ok(dot)
+        }
+    }
+
+    /// One side is `L2Denorm`: `cosine_similarity = dot(n, b) / ||b||`.
+    ///
+    /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
+    fn execute_one_denorm(
+        &self,
+        options: &ApproxOptions,
+        denorm_ref: &ArrayRef,
+        plain_ref: &ArrayRef,
+        len: usize,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
+
+        let (normalized, _) = extract_l2_denorm_children(denorm_ref);
+
+        let dot_arr = InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)?;
+        let norm_arr = L2Norm::try_new_array(options, plain_ref.clone(), len)?;
+        let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?;
+        let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?;
+
+        // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation.
+        // TODO(connor): This can be written in a more SIMD-friendly manner.
+        match_each_float_ptype!(dot.ptype(), |T| {
+            let dots = dot.as_slice::<T>();
+            let norms = plain_norm.as_slice::<T>();
+            let buffer: Buffer<T> = (0..len)
+                .map(|i| {
+                    if norms[i] == T::zero() {
+                        T::zero()
+                    } else {
+                        dots[i] / norms[i]
+                    }
+                })
+                .collect();
+
+            // SAFETY: The buffer length equals `len`, which matches the source validity length.
+            Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
+        })
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::LazyLock;
@@ -210,6 +294,7 @@ mod tests {
 
     use crate::scalar_fns::ApproxOptions;
     use crate::scalar_fns::cosine_similarity::CosineSimilarity;
+    use crate::scalar_fns::l2_denorm::L2Denorm;
     use crate::utils::test_helpers::assert_close;
     use crate::utils::test_helpers::constant_tensor_array;
     use crate::utils::test_helpers::constant_vector_array;
@@ -403,4 +488,99 @@ mod tests {
         assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
         Ok(())
     }
+
+    /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
+    fn l2_denorm_array(
+        shape: &[usize],
+        normalized_elements: &[f64],
+        norms: &[f64],
+    ) -> VortexResult<ArrayRef> {
+        let len = norms.len();
+        let normalized = tensor_array(shape, normalized_elements)?;
+        let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+        Ok(
+            L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?
+                .into_array(),
+        )
+    }
+
+    #[test]
+    fn both_denorm_self_similarity() -> VortexResult<()> {
+        // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8].
+        // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0].
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+
+        // Self-similarity should always be 1.0.
+        assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn both_denorm_orthogonal() -> VortexResult<()> {
+        // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0.
+        // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0.
+        let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?;
+
+        assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn both_denorm_zero_norm() -> VortexResult<()> {
+        // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0.
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+
+        // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0, 0], [1, 0]) = 0.0.
+        assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn one_side_denorm_lhs() -> VortexResult<()> {
+        // LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
+        // RHS is plain [3.0, 4.0].
+        // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0.
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
+        let rhs = tensor_array(&[2], &[3.0, 4.0])?;
+
+        assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn one_side_denorm_rhs() -> VortexResult<()> {
+        // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
+        // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6.
+        let lhs = tensor_array(&[2], &[1.0, 0.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
+
+        assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]);
+        Ok(())
+    }
+
+    #[test]
+    fn both_denorm_null_norms() -> VortexResult<()> {
+        // Row 0: valid, row 1: null (via nullable norms on rhs).
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+
+        let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
+        let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+        let rhs =
+            L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_r, norms_r, 2, &mut ctx)?
+                .into_array();
+
+        let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
+        let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
+        let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
+
+        assert!(prim.is_valid(0)?);
+        assert!(!prim.is_valid(1)?);
+        assert_close(&[prim.as_slice::<f64>()[0]], &[1.0]);
+        Ok(())
+    }
 }
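When only one input is factored, the cosine kernel keeps the denorm side's normalized child and divides by the plain side's norm: `cos(a, b) = dot(n_a, b) / ||b||`, guarding against a zero denominator exactly as `execute_one_denorm` does above. A hedged sketch of that arithmetic on plain slices (`dot` and `l2_norm` are illustrative helpers, not crate API):

    /// Dot product and L2 norm over `f64` slices.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn l2_norm(a: &[f64]) -> f64 {
        dot(a, a).sqrt()
    }

    fn main() {
        // LHS is stored factored: [3.0, 4.0] = 5.0 * [0.6, 0.8]; RHS is plain.
        let n_lhs = [0.6, 0.8];
        let rhs = [3.0, 4.0];

        // cos(lhs, rhs) = dot(n_lhs, rhs) / ||rhs||, with the same zero-norm guard
        // the kernel uses.
        let denom = l2_norm(&rhs);
        let cos = if denom == 0.0 { 0.0 } else { dot(&n_lhs, &rhs) / denom };
        assert!((cos - 1.0).abs() < 1e-12);
    }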
diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs
index 7d07af09404..da1b62e6ca1 100644
--- a/vortex-tensor/src/scalar_fns/inner_product.rs
+++ b/vortex-tensor/src/scalar_fns/inner_product.rs
@@ -13,6 +13,7 @@ use vortex_array::arrays::ExtensionArray;
 use vortex_array::arrays::PrimitiveArray;
 use vortex_array::arrays::ScalarFnArray;
 use vortex_array::arrays::extension::ExtensionArrayExt;
+use vortex_array::arrays::scalar_fn::ExactScalarFn;
 use vortex_array::dtype::DType;
 use vortex_array::dtype::NativePType;
 use vortex_array::dtype::Nullability;
@@ -33,7 +34,9 @@ use vortex_error::vortex_err;
 
 use crate::matcher::AnyTensor;
 use crate::scalar_fns::ApproxOptions;
+use crate::scalar_fns::l2_denorm::L2Denorm;
 use crate::utils::extract_flat_elements;
+use crate::utils::extract_l2_denorm_children;
 
 /// Inner product (dot product) between two columns.
 ///
@@ -141,14 +144,35 @@ impl ScalarFnVTable for InnerProduct {
     fn execute(
         &self,
-        _options: &Self::Options,
+        options: &Self::Options,
         args: &dyn ExecutionArgs,
         ctx: &mut ExecutionCtx,
     ) -> VortexResult<ArrayRef> {
-        let lhs: ExtensionArray = args.get(0)?.execute(ctx)?;
-        let rhs: ExtensionArray = args.get(1)?.execute(ctx)?;
+        let mut lhs_ref = args.get(0)?;
+        let mut rhs_ref = args.get(1)?;
+        let len = args.row_count();
+
+        // Check whether any of our children have already been normalized.
+        {
+            let lhs_is_denorm = lhs_ref.is::<ExactScalarFn<L2Denorm>>();
+            let rhs_is_denorm = rhs_ref.is::<ExactScalarFn<L2Denorm>>();
+
+            if lhs_is_denorm && rhs_is_denorm {
+                return self.execute_both_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
+            } else if lhs_is_denorm || rhs_is_denorm {
+                if rhs_is_denorm {
+                    (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref);
+                }
+                return self.execute_one_denorm(options, &lhs_ref, &rhs_ref, len, ctx);
+            }
+        }
 
-        let row_count = args.row_count();
+        // Compute combined validity.
+        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
+
+        // Canonicalize so we can perform the math directly.
+        let lhs: ExtensionArray = lhs_ref.execute(ctx)?;
+        let rhs: ExtensionArray = rhs_ref.execute(ctx)?;
 
         // We validated that both inputs have the same type.
         let ext = lhs.dtype().as_extension();
@@ -157,10 +181,6 @@
             .vortex_expect("we already validated this in `return_dtype`");
         let dimensions = tensor_match.list_size();
 
-        // Compute combined validity.
-        let rhs_validity = rhs.as_ref().validity()?;
-        let validity = lhs.as_ref().validity()?.and(rhs_validity)?;
-
         // Extract the storage array from each extension input. We pass the storage (FSL) rather
         // than the extension array to avoid canonicalizing the extension wrapper.
         let lhs_storage = lhs.storage_array();
@@ -170,7 +190,7 @@
         let rhs_flat = extract_flat_elements(rhs_storage, dimensions, ctx)?;
 
         match_each_float_ptype!(lhs_flat.ptype(), |T| {
-            let buffer: Buffer<T> = (0..row_count)
+            let buffer: Buffer<T> = (0..len)
                 .map(|i| inner_product_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i)))
                 .collect();
@@ -201,6 +221,72 @@
     }
 }
 
+impl InnerProduct {
+    /// Both sides are `L2Denorm`: `inner_product = s_l * s_r * dot(n_l, n_r)`.
+    fn execute_both_denorm(
+        &self,
+        options: &ApproxOptions,
+        lhs_ref: &ArrayRef,
+        rhs_ref: &ArrayRef,
+        len: usize,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?;
+
+        let (normalized_l, norms_l) = extract_l2_denorm_children(lhs_ref);
+        let (normalized_r, norms_r) = extract_l2_denorm_children(rhs_ref);
+
+        let norms_l: PrimitiveArray = norms_l.execute(ctx)?;
+        let norms_r: PrimitiveArray = norms_r.execute(ctx)?;
+
+        let dot: PrimitiveArray =
+            InnerProduct::try_new_array(options, normalized_l, normalized_r, len)?
+                .into_array()
+                .execute(ctx)?;
+
+        match_each_float_ptype!(dot.ptype(), |T| {
+            let dots = dot.as_slice::<T>();
+            let nl = norms_l.as_slice::<T>();
+            let nr = norms_r.as_slice::<T>();
+            let buffer: Buffer<T> = (0..len).map(|i| nl[i] * nr[i] * dots[i]).collect();
+
+            // SAFETY: The buffer length equals `len`, which matches the source validity length.
+            Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
+        })
+    }
+
+    /// One side is `L2Denorm`: `inner_product = s * dot(n, other)`.
+    ///
+    /// The caller must pass the denorm array as `denorm_ref` and the plain array as `plain_ref`.
+    fn execute_one_denorm(
+        &self,
+        options: &ApproxOptions,
+        denorm_ref: &ArrayRef,
+        plain_ref: &ArrayRef,
+        len: usize,
+        ctx: &mut ExecutionCtx,
+    ) -> VortexResult<ArrayRef> {
+        let validity = denorm_ref.validity()?.and(plain_ref.validity()?)?;
+
+        let (normalized, norms) = extract_l2_denorm_children(denorm_ref);
+        let denorm_norms: PrimitiveArray = norms.execute(ctx)?;
+
+        let dot: PrimitiveArray =
+            InnerProduct::try_new_array(options, normalized, plain_ref.clone(), len)?
+                .into_array()
+                .execute(ctx)?;
+
+        match_each_float_ptype!(dot.ptype(), |T| {
+            let dots = dot.as_slice::<T>();
+            let ns = denorm_norms.as_slice::<T>();
+            let buffer: Buffer<T> = (0..len).map(|i| ns[i] * dots[i]).collect();
+
+            // SAFETY: The buffer length equals `len`, which matches the source validity length.
+            Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array())
+        })
+    }
+}
+
 /// Computes the inner product (dot product) of two equal-length float slices.
 ///
 /// Returns `sum(a_i * b_i)`.
@@ -230,6 +316,7 @@ mod tests {
 
     use crate::scalar_fns::ApproxOptions;
     use crate::scalar_fns::inner_product::InnerProduct;
+    use crate::scalar_fns::l2_denorm::L2Denorm;
     use crate::utils::test_helpers::assert_close;
     use crate::utils::test_helpers::tensor_array;
     use crate::utils::test_helpers::vector_array;
@@ -347,4 +434,93 @@
         assert!(result.is_err());
         Ok(())
     }
+
+    /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms.
+    fn l2_denorm_array(
+        shape: &[usize],
+        normalized_elements: &[f64],
+        norms: &[f64],
+    ) -> VortexResult<ArrayRef> {
+        use vortex_array::IntoArray;
+
+        let len = norms.len();
+        let normalized = tensor_array(shape, normalized_elements)?;
+        let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+        Ok(
+            L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?
+                .into_array(),
+        )
+    }
+
+    #[test]
+    fn both_denorm() -> VortexResult<()> {
+        // LHS: [3.0, 4.0] = L2Denorm([0.6, 0.8], 5.0).
+        // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0).
+        // dot([3.0, 4.0], [1.0, 0.0]) = 3.0.
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
+        let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0])?;
+
+        // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0.
+        assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn both_denorm_multiple_rows() -> VortexResult<()> {
+        // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0.
+        // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0.
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0])?;
+
+        assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn one_side_denorm_lhs() -> VortexResult<()> {
+        // LHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
+        // RHS: plain [1.0, 2.0].
+        // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0.
+        let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
+        let rhs = tensor_array(&[2], &[1.0, 2.0])?;
+
+        assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn one_side_denorm_rhs() -> VortexResult<()> {
+        // LHS: plain [1.0, 2.0].
+        // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0].
+        // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0.
+        let lhs = tensor_array(&[2], &[1.0, 2.0])?;
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?;
+
+        assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]);
+        Ok(())
+    }
+
+    #[test]
+    fn both_denorm_null_norms() -> VortexResult<()> {
+        // Row 0: valid, row 1: null (via nullable norms on lhs).
+        let normalized_l = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
+        let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+
+        let lhs =
+            L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_l, norms_l, 2, &mut ctx)?
+                .into_array();
+        let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
+
+        let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased();
+        let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
+        let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
+
+        // Row 0: 5.0 * 5.0 * dot([0.6, 0.8], [0.6, 0.8]) = 25.0, row 1: null.
+        assert!(prim.is_valid(0)?);
+        assert!(!prim.is_valid(1)?);
+        assert_close(&[prim.as_slice::<f64>()[0]], &[25.0]);
+        Ok(())
+    }
 }
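The `InnerProduct` fast paths rely on bilinearity rather than cancellation: `dot(s_a * n_a, s_b * n_b) = s_a * s_b * dot(n_a, n_b)`, and with one plain operand, `dot(s_a * n_a, b) = s_a * dot(n_a, b)`. A small self-contained check of both identities (illustrative code, not the crate's API):

    /// Dot product of two equal-length `f64` slices.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn main() {
        // a = 5.0 * [0.6, 0.8] = [3.0, 4.0], b = 2.0 * [0.0, 1.0] = [0.0, 2.0].
        let (n_a, s_a) = ([0.6, 0.8], 5.0);
        let (n_b, s_b) = ([0.0, 1.0], 2.0);
        let a = [3.0, 4.0];
        let b = [0.0, 2.0];

        // Both sides factored: dot(a, b) = s_a * s_b * dot(n_a, n_b).
        assert!((dot(&a, &b) - s_a * s_b * dot(&n_a, &n_b)).abs() < 1e-12);

        // One side factored: dot(a, b) = s_a * dot(n_a, b).
        assert!((dot(&a, &b) - s_a * dot(&n_a, &b)).abs() < 1e-12);
    }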
diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs
index f0a0c978cc5..04a348ff7d8 100644
--- a/vortex-tensor/src/scalar_fns/l2_denorm.rs
+++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs
@@ -69,7 +69,8 @@ impl L2Denorm {
     /// This is the correct constructor for [`L2Denorm`] arrays. In addition to the structural
     /// checks performed by [`ScalarFnArray::try_new`], it validates that every valid row of the
     /// `normalized` child has L2 norm `1.0` (or `0.0` for zero rows), within the tolerance implied
-    /// by the child element precision.
+    /// by the child element precision. It also validates that stored norms are non-negative, and
+    /// that any row with stored norm `0.0` has an all-zero normalized row.
     ///
     /// # Errors
     ///
@@ -82,10 +83,15 @@
         len: usize,
         ctx: &mut ExecutionCtx,
     ) -> VortexResult<ArrayRef> {
-        validate_l2_normalized_rows(normalized.clone(), ctx)?;
+        let result = ScalarFnArray::try_new(
+            L2Denorm::new(options).erased(),
+            vec![normalized.clone(), norms.clone()],
+            len,
+        )?;
 
-        // SAFETY: We just validated that it is normalized.
-        unsafe { Self::new_array_unchecked(options, normalized, norms, len) }
+        validate_l2_denorm_children(normalized, norms, ctx)?;
+
+        Ok(result)
     }
 
     /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually
@@ -114,49 +120,6 @@
     }
 }
 
-/// Returns the acceptable unit-norm drift for the given element precision.
-fn unit_norm_tolerance(element_ptype: PType) -> f64 {
-    match element_ptype {
-        PType::F16 => 2e-3,
-        PType::F32 => 2e-6,
-        PType::F64 => 1e-10,
-        _ => unreachable!("L2Denorm requires float elements, got {element_ptype:?}"),
-    }
-}
-
-/// Validates that every valid row of `input` is already L2-normalized.
-pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
-    let row_count = input.len();
-    if row_count == 0 {
-        return Ok(());
-    }
-
-    let tensor_match = validate_tensor_float_input(input.dtype())?;
-    let element_ptype = tensor_match.element_ptype();
-    let tolerance = unit_norm_tolerance(element_ptype);
-
-    let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, input, row_count)?;
-    let norms: PrimitiveArray = norms_sfn.into_array().execute(ctx)?;
-    let norms_validity = norms.validity()?;
-
-    match_each_float_ptype!(element_ptype, |T| {
-        for (i, &norm) in norms.as_slice::<T>().iter().enumerate() {
-            if !norms_validity.is_valid(i)? {
-                continue;
-            }
-
-            let norm_f64 = ToPrimitive::to_f64(&norm).unwrap_or(f64::NAN);
-            vortex_ensure!(
-                norm_f64 == 0.0 || (norm_f64 - 1.0).abs() <= tolerance,
-                "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
-                {norm_f64:.6}",
-            );
-        }
-    });
-
-    Ok(())
-}
-
 impl ScalarFnVTable for L2Denorm {
     type Options = ApproxOptions;
 
@@ -373,6 +336,104 @@ fn build_tensor_array(
     Ok(ExtensionArray::new(dtype.as_extension().clone(), storage.into_array()).into_array())
 }
 
+/// Returns the acceptable unit-norm drift for the given element precision.
+fn unit_norm_tolerance(element_ptype: PType) -> f64 {
+    match element_ptype {
+        PType::F16 => 2e-3,
+        PType::F32 => 2e-6,
+        PType::F64 => 1e-10,
+        _ => unreachable!("L2Denorm requires float elements, got {element_ptype:?}"),
+    }
+}
+
+/// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0).
+pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
+    validate_l2_normalized_rows_impl(input, None, ctx)
+}
+
+/// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`]
+/// invariants, which are:
+///
+/// - All vectors in the normalized array have length 1 or 0.
+/// - If the vector has a norm of 0, then the vector must be all 0s.
+fn validate_l2_denorm_children(
+    normalized: ArrayRef,
+    norms: ArrayRef,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<()> {
+    validate_l2_normalized_rows_impl(normalized, Some(norms), ctx)
+}
+
+fn validate_l2_normalized_rows_impl(
+    normalized: ArrayRef,
+    norms: Option<ArrayRef>,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<()> {
+    let row_count = normalized.len();
+    if row_count == 0 {
+        return Ok(());
+    }
+
+    let tensor_match = validate_tensor_float_input(normalized.dtype())?;
+    let element_ptype = tensor_match.element_ptype();
+    let tolerance = unit_norm_tolerance(element_ptype);
+    let tensor_flat_size = tensor_match.list_size();
+
+    let normalized: ExtensionArray = normalized.execute(ctx)?;
+    let normalized_validity = normalized.as_ref().validity()?;
+    let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?;
+    let norms = norms
+        .map(|norms| norms.execute::<PrimitiveArray>(ctx))
+        .transpose()?;
+
+    let combined_validity = match &norms {
+        Some(norms) => normalized_validity.and(norms.validity()?)?,
+        None => normalized_validity,
+    };
+
+    match_each_float_ptype!(element_ptype, |T| {
+        let stored_norms = norms.as_ref().map(|norms| norms.as_slice::<T>());
+
+        for i in 0..row_count {
+            if !combined_validity.is_valid(i)? {
+                continue;
+            }
+
+            let (row_norm_sq, is_zero_row) =
+                flat.row::<T>(i)
+                    .iter()
+                    .fold((0.0f64, true), |(sum_sq, is_zero), x| {
+                        let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN);
+                        (sum_sq + value * value, is_zero && value.abs() <= tolerance)
+                    });
+            let row_norm = row_norm_sq.sqrt();
+
+            vortex_ensure!(
+                row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance,
+                "L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
+                {row_norm:.6}",
+            );
+
+            if let Some(stored_norms) = stored_norms {
+                let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN);
+                vortex_ensure!(
+                    stored_norm_f64 >= 0.0,
+                    "L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}",
+                );
+
+                if stored_norm_f64 == 0.0 {
+                    vortex_ensure!(
+                        is_zero_row,
+                        "L2Denorm normalized child must be all zeros when norms row {i} is 0.0",
+                    );
+                }
+            }
+        }
+    });
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::LazyLock;
@@ -590,6 +651,28 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn l2_denorm_try_new_array_rejects_nonzero_row_with_zero_norm() -> VortexResult<()> {
+        let normalized = vector_array(2, &[1.0, 0.0, 0.0, 0.0])?;
+        let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+
+        let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx);
+        assert!(result.is_err());
+        Ok(())
+    }
+
+    #[test]
+    fn l2_denorm_try_new_array_rejects_negative_norms() -> VortexResult<()> {
+        let normalized = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?;
+        let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array();
+        let mut ctx = SESSION.create_execution_ctx();
+
+        let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx);
+        assert!(result.is_err());
+        Ok(())
+    }
+
     #[test]
     fn l2_denorm_new_array_unchecked_accepts_unnormalized_child() -> VortexResult<()> {
         let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?;
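`validate_l2_normalized_rows_impl` above folds three invariants into a single pass: every valid row has L2 norm approximately `1.0` (or exactly `0.0`), stored norms are non-negative, and a stored norm of `0.0` only pairs with an all-zero row. The same checks, sketched standalone on plain slices (`check_row` and `tolerance` are stand-ins for the real `validate_l2_denorm_children` and `unit_norm_tolerance`):

    /// Standalone check mirroring the three L2Denorm invariants enforced above.
    fn check_row(row: &[f64], stored_norm: f64, tolerance: f64) -> Result<(), String> {
        let row_norm = row.iter().map(|x| x * x).sum::<f64>().sqrt();
        if !(row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance) {
            return Err(format!("row norm must be 1.0 or 0.0, got {row_norm}"));
        }
        if stored_norm < 0.0 {
            return Err(format!("stored norm must be non-negative, got {stored_norm}"));
        }
        if stored_norm == 0.0 && row.iter().any(|x| x.abs() > tolerance) {
            return Err("zero stored norm requires an all-zero row".to_string());
        }
        Ok(())
    }

    fn main() {
        assert!(check_row(&[0.6, 0.8], 5.0, 1e-10).is_ok());
        assert!(check_row(&[0.0, 0.0], 0.0, 1e-10).is_ok());
        assert!(check_row(&[1.0, 0.0], 0.0, 1e-10).is_err()); // non-zero row, zero norm
        assert!(check_row(&[1.0, 0.0], -1.0, 1e-10).is_err()); // negative norm
    }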
diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs
index 9ece7125b34..4d78597c962 100644
--- a/vortex-tensor/src/utils.rs
+++ b/vortex-tensor/src/utils.rs
@@ -9,15 +9,18 @@ use vortex_array::arrays::ConstantArray;
 use vortex_array::arrays::FixedSizeListArray;
 use vortex_array::arrays::PrimitiveArray;
 use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
+use vortex_array::arrays::scalar_fn::ExactScalarFn;
 use vortex_array::dtype::DType;
 use vortex_array::dtype::NativePType;
 use vortex_array::dtype::PType;
+use vortex_error::VortexExpect;
 use vortex_error::VortexResult;
 use vortex_error::vortex_ensure;
 use vortex_error::vortex_err;
 
 use crate::matcher::AnyTensor;
 use crate::matcher::TensorMatch;
+use crate::scalar_fns::l2_denorm::L2Denorm;
 
 /// Validates that `input_dtype` is a float-valued tensor-like extension dtype.
 pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult<TensorMatch<AnyTensor>> {
@@ -96,6 +99,20 @@ pub fn extract_flat_elements(
     })
 }
 
+/// Extracts the `(normalized, norms)` children from an [`L2Denorm`] scalar function array.
+///
+/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm
+pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) {
+    let sfn = array
+        .as_opt::<ExactScalarFn<L2Denorm>>()
+        .vortex_expect("expected ScalarFnArray wrapping L2Denorm");
+    (
+        sfn.nth_child(0)
+            .vortex_expect("L2Denorm missing normalized array"),
+        sfn.nth_child(1).vortex_expect("L2Denorm missing norms"),
+    )
+}
+
 #[cfg(test)]
 pub mod test_helpers {
     use vortex_array::ArrayRef;
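For intuition about the child layout that `extract_l2_denorm_children` assumes, child 0 is the unit-length data and child 1 is the per-row scale, so the original vectors are recoverable by an element-wise multiply. A toy model of that layout (this struct is illustrative, not the crate's actual representation):

    /// Toy stand-in for the (normalized, norms) child pair of an L2Denorm array.
    struct DenormPair {
        normalized: Vec<f64>, // row-major unit-length rows
        norms: Vec<f64>,      // one scale per row
        row_len: usize,
    }

    impl DenormPair {
        /// Reassembles the original vector for row `i` by scaling the unit row.
        fn reassemble_row(&self, i: usize) -> Vec<f64> {
            let row = &self.normalized[i * self.row_len..(i + 1) * self.row_len];
            row.iter().map(|x| x * self.norms[i]).collect()
        }
    }

    fn main() {
        let pair = DenormPair {
            normalized: vec![0.6, 0.8, 1.0, 0.0],
            norms: vec![5.0, 2.0],
            row_len: 2,
        };
        assert_eq!(pair.reassemble_row(0), vec![3.0, 4.0]);
        assert_eq!(pair.reassemble_row(1), vec![2.0, 0.0]);
    }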