Skip to content

Commit 62273ba

Browse files
committed
fix l2 denorm validation
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 9653628 commit 62273ba

3 files changed

Lines changed: 151 additions & 77 deletions

File tree

vortex-tensor/src/scalar_fns/cosine_similarity.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ mod tests {
287287

288288
use crate::scalar_fns::ApproxOptions;
289289
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
290+
use crate::scalar_fns::l2_denorm::L2Denorm;
290291
use crate::utils::test_helpers::assert_close;
291292
use crate::utils::test_helpers::constant_tensor_array;
292293
use crate::utils::test_helpers::constant_vector_array;
@@ -490,13 +491,11 @@ mod tests {
490491
let len = norms.len();
491492
let normalized = tensor_array(shape, normalized_elements)?;
492493
let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
493-
Ok(crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
494-
&ApproxOptions::Exact,
495-
normalized,
496-
norms,
497-
len,
498-
)?
499-
.into_array())
494+
let mut ctx = SESSION.create_execution_ctx();
495+
Ok(
496+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?
497+
.into_array(),
498+
)
500499
}
501500

502501
#[test]
@@ -563,17 +562,13 @@ mod tests {
563562

564563
let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
565564
let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
566-
let rhs = crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
567-
&ApproxOptions::Exact,
568-
normalized_r,
569-
norms_r,
570-
2,
571-
)?
572-
.into_array();
565+
let mut ctx = SESSION.create_execution_ctx();
566+
let rhs =
567+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_r, norms_r, 2, &mut ctx)?
568+
.into_array();
573569

574570
let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased();
575571
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
576-
let mut ctx = SESSION.create_execution_ctx();
577572
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
578573

579574
assert!(prim.is_valid(0)?);

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ mod tests {
316316

317317
use crate::scalar_fns::ApproxOptions;
318318
use crate::scalar_fns::inner_product::InnerProduct;
319+
use crate::scalar_fns::l2_denorm::L2Denorm;
319320
use crate::utils::test_helpers::assert_close;
320321
use crate::utils::test_helpers::tensor_array;
321322
use crate::utils::test_helpers::vector_array;
@@ -445,13 +446,11 @@ mod tests {
445446
let len = norms.len();
446447
let normalized = tensor_array(shape, normalized_elements)?;
447448
let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array();
448-
Ok(crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
449-
&ApproxOptions::Exact,
450-
normalized,
451-
norms,
452-
len,
453-
)?
454-
.into_array())
449+
let mut ctx = SESSION.create_execution_ctx();
450+
Ok(
451+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, len, &mut ctx)?
452+
.into_array(),
453+
)
455454
}
456455

457456
#[test]
@@ -507,18 +506,15 @@ mod tests {
507506
// Row 0: valid, row 1: null (via nullable norms on lhs).
508507
let normalized_l = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?;
509508
let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array();
510-
let lhs = crate::scalar_fns::l2_denorm::L2Denorm::try_new_array(
511-
&ApproxOptions::Exact,
512-
normalized_l,
513-
norms_l,
514-
2,
515-
)?
516-
.into_array();
509+
let mut ctx = SESSION.create_execution_ctx();
510+
511+
let lhs =
512+
L2Denorm::try_new_array(&ApproxOptions::Exact, normalized_l, norms_l, 2, &mut ctx)?
513+
.into_array();
517514
let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?;
518515

519516
let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased();
520517
let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?;
521-
let mut ctx = SESSION.create_execution_ctx();
522518
let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?;
523519

524520
// Row 0: 5.0 * 5.0 * dot([0.6, 0.8], [0.6, 0.8]) = 25.0, row 1: null.

vortex-tensor/src/scalar_fns/l2_denorm.rs

Lines changed: 130 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ impl L2Denorm {
6969
/// This is the correct constructor for [`L2Denorm`] arrays. In addition to the structural
7070
/// checks performed by [`ScalarFnArray::try_new`], it validates that every valid row of the
7171
/// `normalized` child has L2 norm `1.0` (or `0.0` for zero rows), within the tolerance implied
72-
/// by the child element precision.
72+
/// by the child element precision. It also validates that stored norms are non-negative, and
73+
/// that any row with stored norm `0.0` has an all-zero normalized row.
7374
///
7475
/// # Errors
7576
///
@@ -82,10 +83,15 @@ impl L2Denorm {
8283
len: usize,
8384
ctx: &mut ExecutionCtx,
8485
) -> VortexResult<ScalarFnArray> {
85-
validate_l2_normalized_rows(normalized.clone(), ctx)?;
86+
let result = ScalarFnArray::try_new(
87+
L2Denorm::new(options).erased(),
88+
vec![normalized.clone(), norms.clone()],
89+
len,
90+
)?;
8691

87-
// SAFETY: We just validated that it is normalized.
88-
unsafe { Self::new_array_unchecked(options, normalized, norms, len) }
92+
validate_l2_denorm_children(normalized, norms, ctx)?;
93+
94+
Ok(result)
8995
}
9096

9197
/// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually
@@ -114,49 +120,6 @@ impl L2Denorm {
114120
}
115121
}
116122

117-
/// Returns the acceptable unit-norm drift for the given element precision.
118-
fn unit_norm_tolerance(element_ptype: PType) -> f64 {
119-
match element_ptype {
120-
PType::F16 => 2e-3,
121-
PType::F32 => 2e-6,
122-
PType::F64 => 1e-10,
123-
_ => unreachable!("L2Denorm requires float elements, got {element_ptype:?}"),
124-
}
125-
}
126-
127-
/// Validates that every valid row of `input` is already L2-normalized.
128-
pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
129-
let row_count = input.len();
130-
if row_count == 0 {
131-
return Ok(());
132-
}
133-
134-
let tensor_match = validate_tensor_float_input(input.dtype())?;
135-
let element_ptype = tensor_match.element_ptype();
136-
let tolerance = unit_norm_tolerance(element_ptype);
137-
138-
let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, input, row_count)?;
139-
let norms: PrimitiveArray = norms_sfn.into_array().execute(ctx)?;
140-
let norms_validity = norms.validity()?;
141-
142-
match_each_float_ptype!(element_ptype, |T| {
143-
for (i, &norm) in norms.as_slice::<T>().iter().enumerate() {
144-
if !norms_validity.is_valid(i)? {
145-
continue;
146-
}
147-
148-
let norm_f64 = ToPrimitive::to_f64(&norm).unwrap_or(f64::NAN);
149-
vortex_ensure!(
150-
norm_f64 == 0.0 || (norm_f64 - 1.0).abs() <= tolerance,
151-
"L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
152-
{norm_f64:.6}",
153-
);
154-
}
155-
});
156-
157-
Ok(())
158-
}
159-
160123
impl ScalarFnVTable for L2Denorm {
161124
type Options = ApproxOptions;
162125

@@ -373,6 +336,104 @@ fn build_tensor_array<T: NativePType>(
373336
Ok(ExtensionArray::new(dtype.as_extension().clone(), storage.into_array()).into_array())
374337
}
375338

339+
/// Returns the acceptable unit-norm drift for the given element precision.
340+
fn unit_norm_tolerance(element_ptype: PType) -> f64 {
341+
match element_ptype {
342+
PType::F16 => 2e-3,
343+
PType::F32 => 2e-6,
344+
PType::F64 => 1e-10,
345+
_ => unreachable!("L2Denorm requires float elements, got {element_ptype:?}"),
346+
}
347+
}
348+
349+
/// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0).
350+
pub fn validate_l2_normalized_rows(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
351+
validate_l2_normalized_rows_impl(input, None, ctx)
352+
}
353+
354+
/// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`]
355+
/// invariants, which are:
356+
///
357+
/// - All vectors in the normalized array have length 1 or 0.
358+
/// - If the vector has a norm of 0, then the vector must be all 0s.
359+
fn validate_l2_denorm_children(
360+
normalized: ArrayRef,
361+
norms: ArrayRef,
362+
ctx: &mut ExecutionCtx,
363+
) -> VortexResult<()> {
364+
validate_l2_normalized_rows_impl(normalized, Some(norms), ctx)
365+
}
366+
367+
fn validate_l2_normalized_rows_impl(
368+
normalized: ArrayRef,
369+
norms: Option<ArrayRef>,
370+
ctx: &mut ExecutionCtx,
371+
) -> VortexResult<()> {
372+
let row_count = normalized.len();
373+
if row_count == 0 {
374+
return Ok(());
375+
}
376+
377+
let tensor_match = validate_tensor_float_input(normalized.dtype())?;
378+
let element_ptype = tensor_match.element_ptype();
379+
let tolerance = unit_norm_tolerance(element_ptype);
380+
let tensor_flat_size = tensor_match.list_size();
381+
382+
let normalized: ExtensionArray = normalized.execute(ctx)?;
383+
let normalized_validity = normalized.as_ref().validity()?;
384+
let flat = extract_flat_elements(normalized.storage_array(), tensor_flat_size, ctx)?;
385+
let norms = norms
386+
.map(|norms| norms.execute::<PrimitiveArray>(ctx))
387+
.transpose()?;
388+
389+
let combined_validity = match &norms {
390+
Some(norms) => normalized_validity.and(norms.validity()?)?,
391+
None => normalized_validity,
392+
};
393+
394+
match_each_float_ptype!(element_ptype, |T| {
395+
let stored_norms = norms.as_ref().map(|norms| norms.as_slice::<T>());
396+
397+
for i in 0..row_count {
398+
if !combined_validity.is_valid(i)? {
399+
continue;
400+
}
401+
402+
let (row_norm_sq, is_zero_row) =
403+
flat.row::<T>(i)
404+
.iter()
405+
.fold((0.0f64, true), |(sum_sq, is_zero), x| {
406+
let value = ToPrimitive::to_f64(x).unwrap_or(f64::NAN);
407+
(sum_sq + value * value, is_zero && value.abs() <= tolerance)
408+
});
409+
let row_norm = row_norm_sq.sqrt();
410+
411+
vortex_ensure!(
412+
row_norm == 0.0 || (row_norm - 1.0).abs() <= tolerance,
413+
"L2Denorm normalized child must have L2 norm 1.0 or 0.0, but row {i} has \
414+
{row_norm:.6}",
415+
);
416+
417+
if let Some(stored_norms) = stored_norms {
418+
let stored_norm_f64 = ToPrimitive::to_f64(&stored_norms[i]).unwrap_or(f64::NAN);
419+
vortex_ensure!(
420+
stored_norm_f64 >= 0.0,
421+
"L2Denorm norms must be non-negative, but row {i} has {stored_norm_f64:.6}",
422+
);
423+
424+
if stored_norm_f64 == 0.0 {
425+
vortex_ensure!(
426+
is_zero_row,
427+
"L2Denorm normalized child must be all zeros when norms row {i} is 0.0",
428+
);
429+
}
430+
}
431+
}
432+
});
433+
434+
Ok(())
435+
}
436+
376437
#[cfg(test)]
377438
mod tests {
378439
use std::sync::LazyLock;
@@ -590,6 +651,28 @@ mod tests {
590651
Ok(())
591652
}
592653

654+
#[test]
655+
fn l2_denorm_try_new_array_rejects_nonzero_row_with_zero_norm() -> VortexResult<()> {
656+
let normalized = vector_array(2, &[1.0, 0.0, 0.0, 0.0])?;
657+
let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array();
658+
let mut ctx = SESSION.create_execution_ctx();
659+
660+
let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx);
661+
assert!(result.is_err());
662+
Ok(())
663+
}
664+
665+
#[test]
666+
fn l2_denorm_try_new_array_rejects_negative_norms() -> VortexResult<()> {
667+
let normalized = vector_array(2, &[1.0, 0.0, 0.0, 1.0])?;
668+
let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array();
669+
let mut ctx = SESSION.create_execution_ctx();
670+
671+
let result = L2Denorm::try_new_array(&ApproxOptions::Exact, normalized, norms, 2, &mut ctx);
672+
assert!(result.is_err());
673+
Ok(())
674+
}
675+
593676
#[test]
594677
fn l2_denorm_new_array_unchecked_accepts_unnormalized_child() -> VortexResult<()> {
595678
let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?;

0 commit comments

Comments
 (0)