Skip to content

Commit 9181e92

Browse files
committed
fix: emit NaN for single-row correlation in legacy mode
1 parent 26c252c commit 9181e92

2 files changed

Lines changed: 46 additions & 14 deletions

File tree

native/spark-expr/src/agg_funcs/correlation.rs

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -356,25 +356,41 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
356356
}
357357

358358
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
359-
// Pull each child's per-group result, then combine with corr = c/(s1*s2).
359+
// Snapshot per-group counts BEFORE the children's evaluate() consumes
360+
// their state. This lets us apply the count==0 / count==1 branches the
361+
// way the per-row CorrelationAccumulator does.
362+
let counts: Vec<f64> = match emit_to {
363+
EmitTo::All => self.covar.counts().to_vec(),
364+
EmitTo::First(n) => self.covar.counts()[..n].to_vec(),
365+
};
366+
360367
let covar = self.covar.evaluate(emit_to)?;
361368
let var1 = self.var1.evaluate(emit_to)?;
362369
let var2 = self.var2.evaluate(emit_to)?;
363370
let covar = covar.as_primitive::<Float64Type>();
364371
let var1 = var1.as_primitive::<Float64Type>();
365372
let var2 = var2.as_primitive::<Float64Type>();
366373

367-
// The child accumulators encode count==0 => null in their null buffer.
368-
// Children use Population stats so they never trigger the count==1
369-
// sample branch. We mirror per-row CorrelationAccumulator semantics:
370-
// count == 0 => null
371-
// else if s1 == 0 || s2 == 0 => null
372-
// else c / (s1 * s2)
373-
374374
let n = covar.len();
375375
let mut values = Vec::with_capacity(n);
376376
let mut validity = Vec::with_capacity(n);
377377
for i in 0..n {
378+
let count = counts[i];
379+
if count == 0.0 {
380+
values.push(0.0);
381+
validity.push(false);
382+
continue;
383+
}
384+
if count == 1.0 {
385+
if self.null_on_divide_by_zero {
386+
values.push(0.0);
387+
validity.push(false);
388+
} else {
389+
values.push(f64::NAN);
390+
validity.push(true);
391+
}
392+
continue;
393+
}
378394
if covar.is_null(i) || var1.is_null(i) || var2.is_null(i) {
379395
values.push(0.0);
380396
validity.push(false);
@@ -392,7 +408,6 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
392408
validity.push(true);
393409
}
394410

395-
let _ = self.null_on_divide_by_zero;
396411
Ok(Arc::new(Float64Array::new(
397412
values.into(),
398413
Some(arrow::buffer::NullBuffer::from(validity)),
@@ -481,11 +496,21 @@ mod groups_tests {
481496
}
482497

483498
#[test]
484-
fn single_row_yields_null() {
485-
// Correlation always uses Population stats internally. With one row the
486-
// sub-variances are zero, so s1*s2 == 0 and the per-row impl returns
487-
// null. We mirror that here.
488-
let mut a = acc(true);
499+
fn single_row_legacy_mode_yields_nan() {
500+
// Correlation always uses Population stats internally. With one row
501+
// the per-row CorrelationAccumulator returns NaN when in legacy
502+
// (null_on_divide_by_zero=false) mode and null when the flag is set.
503+
let mut a = acc(true); // legacy
504+
let v1: ArrayRef = Arc::new(Float64Array::from(vec![42.0]));
505+
let v2: ArrayRef = Arc::new(Float64Array::from(vec![7.0]));
506+
a.update_batch(&[v1, v2], &[0], None, 1).unwrap();
507+
let r = evaluate(&mut a);
508+
assert!(r[0].unwrap().is_nan());
509+
}
510+
511+
#[test]
512+
fn single_row_ansi_mode_yields_null() {
513+
let mut a = acc(false); // null_on_divide_by_zero = true
489514
let v1: ArrayRef = Arc::new(Float64Array::from(vec![42.0]));
490515
let v2: ArrayRef = Arc::new(Float64Array::from(vec![7.0]));
491516
a.update_batch(&[v1, v2], &[0], None, 1).unwrap();

native/spark-expr/src/agg_funcs/covariance.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,13 @@ impl CovarianceGroupsAccumulator {
356356
}
357357
}
358358

359+
/// Returns a slice of the per-group counts. Used by the correlation
360+
/// grouped accumulator to apply the count==0 / count==1 branches before
361+
/// the children's `evaluate()` consumes their state.
362+
pub(crate) fn counts(&self) -> &[f64] {
363+
&self.counts
364+
}
365+
359366
fn resize(&mut self, total_num_groups: usize) {
360367
self.counts.resize(total_num_groups, 0.0);
361368
self.mean1s.resize(total_num_groups, 0.0);

0 commit comments

Comments
 (0)