Skip to content

Commit 1a71735

Browse files
robert3005gatesn
authored andcommitted
Sum aggregate doesn't include nan values (#7009)
NaN will cause sum to be NaN. Instead we have NaNCount already that users can use to handle columns with NaNs. fix #5152 Signed-off-by: Robert Kruszewski <github@robertk.io> --------- Signed-off-by: Robert Kruszewski <github@robertk.io> Signed-off-by: Nicholas Gates <nick@nickgates.com> Co-authored-by: Nicholas Gates <nick@nickgates.com>
1 parent a824faa commit 1a71735

2 files changed

Lines changed: 69 additions & 3 deletions

File tree

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,11 @@ impl AggregateFnVTable for Sum {
203203

204204
#[inline]
205205
fn is_saturated(&self, partial: &Self::Partial) -> bool {
206-
partial.current.is_none()
206+
match partial.current.as_ref() {
207+
None => true,
208+
Some(SumState::Float(v)) => v.is_nan(),
209+
Some(_) => false,
210+
}
207211
}
208212

209213
fn accumulate(

vortex-array/src/aggregate_fn/fns/sum/primitive.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR
5454
signed: |_T| { vortex_panic!("float sum state with signed input") },
5555
floating: |T| {
5656
for &v in p.as_slice::<T>() {
57-
*acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
57+
if !v.is_nan() {
58+
*acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
59+
}
5860
}
5961
Ok(false)
6062
}
@@ -98,7 +100,7 @@ fn accumulate_primitive_valid(
98100
signed: |_T| { vortex_panic!("float sum state with signed input") },
99101
floating: |T| {
100102
for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
101-
if valid {
103+
if valid && !v.is_nan() {
102104
*acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
103105
}
104106
}
@@ -213,6 +215,66 @@ mod tests {
213215
Ok(())
214216
}
215217

218+
#[test]
219+
fn sum_f64_with_nan() -> VortexResult<()> {
220+
let arr = PrimitiveArray::new(
221+
buffer![1.0f64, f64::NAN, 2.0, f64::NAN, 3.0],
222+
Validity::NonNullable,
223+
)
224+
.into_array();
225+
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
226+
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(6.0));
227+
Ok(())
228+
}
229+
230+
#[test]
231+
fn sum_f32_with_nan() -> VortexResult<()> {
232+
let arr =
233+
PrimitiveArray::new(buffer![1.0f32, f32::NAN, 4.0], Validity::NonNullable).into_array();
234+
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
235+
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(5.0));
236+
Ok(())
237+
}
238+
239+
#[test]
240+
fn sum_f64_with_nan_and_nulls() -> VortexResult<()> {
241+
let arr = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(f64::NAN), Some(3.0)])
242+
.into_array();
243+
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
244+
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(4.0));
245+
Ok(())
246+
}
247+
248+
#[test]
249+
fn sum_all_nan() -> VortexResult<()> {
250+
let arr =
251+
PrimitiveArray::new(buffer![f64::NAN, f64::NAN], Validity::NonNullable).into_array();
252+
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
253+
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(0.0));
254+
Ok(())
255+
}
256+
257+
#[test]
258+
fn sum_f64_with_infinity() -> VortexResult<()> {
259+
let batch = PrimitiveArray::new(
260+
buffer![1.0f64, f64::INFINITY, f64::NEG_INFINITY, 2.0],
261+
Validity::NonNullable,
262+
)
263+
.into_array();
264+
let acc = sum(&batch, &mut LEGACY_SESSION.create_execution_ctx())?;
265+
// INFINITY + NEG_INFINITY = NaN, which is treated as saturated
266+
assert!(acc.as_primitive().typed_value::<f64>().unwrap().is_nan());
267+
268+
let mut acc = Accumulator::try_new(
269+
Sum,
270+
EmptyOptions,
271+
DType::Primitive(PType::F64, Nullability::NonNullable),
272+
)?;
273+
acc.accumulate(&batch, &mut LEGACY_SESSION.create_execution_ctx())?;
274+
assert!(acc.is_saturated());
275+
Ok(())
276+
}
277+
216278
#[test]
217279
fn sum_checked_overflow() -> VortexResult<()> {
218280
let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();

0 commit comments

Comments
 (0)