Skip to content

Commit 5df0f94

Browse files
authored
Numerical aggregate functions have an option to skip or include nans in calculation, skip by default (#8457)
Almost all of the time you want to skip nans but for rare cases when you don't we need to be able to configure it
1 parent de60638 commit 5df0f94

43 files changed

Lines changed: 1434 additions & 275 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

encodings/runend/src/compute/min_max.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ impl DynAggregateKernel for RunEndMinMaxKernel {
2828
batch: &ArrayRef,
2929
ctx: &mut ExecutionCtx,
3030
) -> VortexResult<Option<Scalar>> {
31-
if !aggregate_fn.is::<MinMax>() {
31+
let Some(options) = aggregate_fn.as_opt::<MinMax>() else {
3232
return Ok(None);
33-
}
33+
};
3434

3535
let Some(run_end) = batch.as_opt::<RunEnd>() else {
3636
return Ok(None);
3737
};
3838

3939
let struct_dtype = make_minmax_dtype(batch.dtype());
40-
match min_max(run_end.values(), ctx)? {
40+
match min_max(run_end.values(), ctx, *options)? {
4141
Some(result) => Ok(Some(Scalar::struct_(
4242
struct_dtype,
4343
vec![result.min, result.max],

encodings/sparse/benches/sparse_pushdown.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use vortex_array::Canonical;
1919
use vortex_array::ExecutionCtx;
2020
use vortex_array::IntoArray;
2121
use vortex_array::VortexSessionExecute;
22+
use vortex_array::aggregate_fn::NumericalAggregateOpts;
2223
use vortex_array::aggregate_fn::fns::is_constant::is_constant;
2324
use vortex_array::aggregate_fn::fns::min_max::min_max;
2425
use vortex_array::aggregate_fn::fns::null_count::null_count;
@@ -106,7 +107,10 @@ fn sparse_min_max(bencher: Bencher) {
106107
bencher
107108
.with_inputs(|| (make_sparse(40_000, false), SESSION.create_execution_ctx()))
108109
.bench_values(|(array, mut ctx)| {
109-
divan::black_box(min_max(&array, &mut ctx).vortex_expect("min_max"))
110+
divan::black_box(
111+
min_max(&array, &mut ctx, NumericalAggregateOpts::default())
112+
.vortex_expect("min_max"),
113+
)
110114
});
111115
}
112116

encodings/sparse/src/compute/min_max.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use vortex_array::IntoArray;
77
use vortex_array::aggregate_fn::Accumulator;
88
use vortex_array::aggregate_fn::AggregateFnRef;
99
use vortex_array::aggregate_fn::DynAccumulator;
10-
use vortex_array::aggregate_fn::EmptyOptions;
1110
use vortex_array::aggregate_fn::fns::min_max::MinMax;
1211
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
1312
use vortex_array::arrays::ConstantArray;
@@ -32,17 +31,17 @@ impl DynAggregateKernel for SparseMinMaxKernel {
3231
batch: &ArrayRef,
3332
ctx: &mut ExecutionCtx,
3433
) -> VortexResult<Option<Scalar>> {
35-
if !aggregate_fn.is::<MinMax>() {
34+
let Some(options) = aggregate_fn.as_opt::<MinMax>() else {
3635
return Ok(None);
37-
}
36+
};
3837

3938
let Some(sparse) = batch.as_opt::<Sparse>() else {
4039
return Ok(None);
4140
};
4241

4342
let patches = sparse.patches();
4443

45-
let mut acc = Accumulator::try_new(MinMax, EmptyOptions, batch.dtype().clone())?;
44+
let mut acc = Accumulator::try_new(MinMax, *options, batch.dtype().clone())?;
4645

4746
if !patches.values().is_empty() {
4847
acc.accumulate(patches.values(), ctx)?;
@@ -66,6 +65,7 @@ mod tests {
6665
use rstest::rstest;
6766
use vortex_array::IntoArray;
6867
use vortex_array::VortexSessionExecute;
68+
use vortex_array::aggregate_fn::NumericalAggregateOpts;
6969
use vortex_array::aggregate_fn::fns::min_max::MinMaxResult;
7070
use vortex_array::aggregate_fn::fns::min_max::min_max;
7171
use vortex_array::scalar::Scalar;
@@ -100,10 +100,18 @@ mod tests {
100100
#[case(Sparse::try_new(buffer![0u64, 1, 2].into_array(), buffer![7i32, 3, 9].into_array(), 3, Scalar::from(99i32)).unwrap())]
101101
fn min_max_matches_canonical(#[case] array: SparseArray) {
102102
let arr = array.into_array();
103-
let kernel: Option<MinMaxResult> =
104-
min_max(&arr, &mut SESSION.create_execution_ctx()).unwrap();
105-
let canonical: Option<MinMaxResult> =
106-
min_max(&arr, &mut CANONICAL_SESSION.create_execution_ctx()).unwrap();
103+
let kernel: Option<MinMaxResult> = min_max(
104+
&arr,
105+
&mut SESSION.create_execution_ctx(),
106+
NumericalAggregateOpts::default(),
107+
)
108+
.unwrap();
109+
let canonical: Option<MinMaxResult> = min_max(
110+
&arr,
111+
&mut CANONICAL_SESSION.create_execution_ctx(),
112+
NumericalAggregateOpts::default(),
113+
)
114+
.unwrap();
107115
assert_eq!(kernel, canonical);
108116
}
109117
}

encodings/sparse/src/compute/sum.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use vortex_array::IntoArray;
77
use vortex_array::aggregate_fn::Accumulator;
88
use vortex_array::aggregate_fn::AggregateFnRef;
99
use vortex_array::aggregate_fn::DynAccumulator;
10-
use vortex_array::aggregate_fn::EmptyOptions;
1110
use vortex_array::aggregate_fn::fns::sum::Sum;
1211
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
1312
use vortex_array::arrays::ConstantArray;
@@ -34,9 +33,9 @@ impl DynAggregateKernel for SparseSumKernel {
3433
batch: &ArrayRef,
3534
ctx: &mut ExecutionCtx,
3635
) -> VortexResult<Option<Scalar>> {
37-
if !aggregate_fn.is::<Sum>() {
36+
let Some(options) = aggregate_fn.as_opt::<Sum>() else {
3837
return Ok(None);
39-
}
38+
};
4039

4140
let Some(sparse) = batch.as_opt::<Sparse>() else {
4241
return Ok(None);
@@ -47,8 +46,8 @@ impl DynAggregateKernel for SparseSumKernel {
4746

4847
// Build a fresh Sum accumulator over the array dtype and fold in the fill and patch
4948
// contributions. The accumulator's existing semantics (checked overflow → null
50-
// partial) are preserved.
51-
let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone())?;
49+
// partial, NaN handling per the options) are preserved.
50+
let mut acc = Accumulator::try_new(Sum, *options, batch.dtype().clone())?;
5251

5352
if n_fill > 0 {
5453
let fill_array = ConstantArray::new(sparse.fill_scalar().clone(), n_fill).into_array();

fuzz/src/array/min_max.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use vortex_array::Canonical;
55
use vortex_array::ExecutionCtx;
66
use vortex_array::IntoArray as _;
7+
use vortex_array::aggregate_fn::NumericalAggregateOpts;
78
use vortex_array::aggregate_fn::fns::min_max::MinMaxResult;
89
use vortex_array::aggregate_fn::fns::min_max::min_max;
910
use vortex_error::VortexResult;
@@ -13,5 +14,9 @@ pub fn min_max_canonical_array(
1314
canonical: Canonical,
1415
ctx: &mut ExecutionCtx,
1516
) -> VortexResult<Option<MinMaxResult>> {
16-
min_max(&canonical.into_array(), ctx)
17+
min_max(
18+
&canonical.into_array(),
19+
ctx,
20+
NumericalAggregateOpts::default(),
21+
)
1722
}

fuzz/src/array/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use vortex_array::ArrayRef;
4343
use vortex_array::Canonical;
4444
use vortex_array::IntoArray;
4545
use vortex_array::VortexSessionExecute;
46+
use vortex_array::aggregate_fn::NumericalAggregateOpts;
4647
use vortex_array::aggregate_fn::fns::all_non_distinct::all_non_distinct;
4748
use vortex_array::aggregate_fn::fns::min_max::MinMaxResult;
4849
use vortex_array::aggregate_fn::fns::min_max::min_max;
@@ -667,8 +668,9 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult<bool> {
667668
assert_scalar_eq(&expected.scalar(), &sum_result, i)?;
668669
}
669670
Action::MinMax => {
670-
let min_max_result = min_max(&current_array, &mut ctx)
671-
.vortex_expect("min_max operation should succeed in fuzz test");
671+
let min_max_result =
672+
min_max(&current_array, &mut ctx, NumericalAggregateOpts::default())
673+
.vortex_expect("min_max operation should succeed in fuzz test");
672674
assert_min_max_eq(expected.min_max().as_ref(), min_max_result.as_ref(), i)?;
673675
}
674676
Action::FillNull(fill_value) => {

vortex-array/benches/aggregate_grouped.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use vortex_array::LEGACY_SESSION;
1414
use vortex_array::VortexSessionExecute;
1515
use vortex_array::aggregate_fn::AggregateFnVTable;
1616
use vortex_array::aggregate_fn::DynGroupedAccumulator;
17-
use vortex_array::aggregate_fn::EmptyOptions;
1817
use vortex_array::aggregate_fn::GroupedAccumulator;
18+
use vortex_array::aggregate_fn::NumericalAggregateOpts;
1919
use vortex_array::aggregate_fn::fns::count::Count;
2020
use vortex_array::aggregate_fn::fns::sum::Sum;
2121
use vortex_array::arrays::ListViewArray;
@@ -149,10 +149,14 @@ fn list_element_dtype(list_view: &ArrayRef) -> DType {
149149

150150
fn grouped_accumulator<V>(list_view: &ArrayRef, vtable: V) -> ArrayRef
151151
where
152-
V: AggregateFnVTable<Options = EmptyOptions> + Clone,
152+
V: AggregateFnVTable<Options = NumericalAggregateOpts> + Clone,
153153
{
154-
let mut acc =
155-
GroupedAccumulator::try_new(vtable, EmptyOptions, list_element_dtype(list_view)).unwrap();
154+
let mut acc = GroupedAccumulator::try_new(
155+
vtable,
156+
NumericalAggregateOpts::default(),
157+
list_element_dtype(list_view),
158+
)
159+
.unwrap();
156160
acc.accumulate_list(list_view, &mut LEGACY_SESSION.create_execution_ctx())
157161
.unwrap();
158162
divan::black_box(acc.finish().unwrap())

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ mod tests {
274274
use crate::aggregate_fn::AggregateFnRef;
275275
use crate::aggregate_fn::AggregateFnVTable;
276276
use crate::aggregate_fn::DynAccumulator;
277-
use crate::aggregate_fn::EmptyOptions;
277+
use crate::aggregate_fn::NumericalAggregateOpts;
278278
use crate::aggregate_fn::combined::Combined;
279279
use crate::aggregate_fn::combined::PairOptions;
280280
use crate::aggregate_fn::fns::mean::Mean;
@@ -348,7 +348,10 @@ mod tests {
348348
let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
349349
Accumulator::try_new(
350350
Mean::combined(),
351-
PairOptions(EmptyOptions, EmptyOptions),
351+
PairOptions(
352+
NumericalAggregateOpts::default(),
353+
NumericalAggregateOpts::default(),
354+
),
352355
dtype,
353356
)
354357
}

vortex-array/src/aggregate_fn/fns/bounded_max/mod.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::aggregate_fn::AggregateFnId;
2222
use crate::aggregate_fn::AggregateFnRef;
2323
use crate::aggregate_fn::AggregateFnSatisfaction;
2424
use crate::aggregate_fn::AggregateFnVTable;
25-
use crate::aggregate_fn::EmptyOptions;
25+
use crate::aggregate_fn::NumericalAggregateOpts;
2626
use crate::aggregate_fn::fns::max::Max;
2727
use crate::aggregate_fn::fns::min_max::MinMax;
2828
use crate::aggregate_fn::fns::min_max::min_max;
@@ -172,7 +172,11 @@ impl AggregateFnVTable for BoundedMax {
172172
};
173173
}
174174

175-
if requested.is::<Max>() {
175+
// The stored bound skips NaNs, so it cannot stand in for a NaN-including maximum.
176+
if requested
177+
.as_opt::<Max>()
178+
.is_some_and(|options| options.skip_nans)
179+
{
176180
AggregateFnSatisfaction::Approximate
177181
} else {
178182
AggregateFnSatisfaction::No
@@ -263,7 +267,7 @@ impl AggregateFnVTable for BoundedMax {
263267
Columnar::Canonical(canonical) => canonical.clone().into_array(),
264268
Columnar::Constant(constant) => constant.clone().into_array(),
265269
};
266-
let Some(result) = min_max(&array, ctx)? else {
270+
let Some(result) = min_max(&array, ctx, NumericalAggregateOpts::default())? else {
267271
return Ok(());
268272
};
269273
match truncate_max(result.max, partial.max_bytes.get())? {
@@ -284,7 +288,7 @@ impl AggregateFnVTable for BoundedMax {
284288

285289
fn supported_dtype<'a>(_options: &BoundedMaxOptions, input_dtype: &'a DType) -> Option<&'a DType> {
286290
MinMax
287-
.return_dtype(&EmptyOptions, input_dtype)
291+
.return_dtype(&NumericalAggregateOpts::default(), input_dtype)
288292
.map(|_| input_dtype)
289293
}
290294

@@ -324,7 +328,7 @@ mod tests {
324328
use crate::aggregate_fn::AggregateFnVTable;
325329
use crate::aggregate_fn::AggregateFnVTableExt;
326330
use crate::aggregate_fn::DynAccumulator;
327-
use crate::aggregate_fn::EmptyOptions;
331+
use crate::aggregate_fn::NumericalAggregateOpts;
328332
use crate::aggregate_fn::fns::bounded_max::BoundedMax;
329333
use crate::aggregate_fn::fns::bounded_max::BoundedMaxOptions;
330334
use crate::aggregate_fn::fns::bounded_max::make_bounded_max_partial_dtype;
@@ -519,15 +523,25 @@ mod tests {
519523
AggregateFnSatisfaction::No
520524
);
521525
assert_eq!(
522-
stored.can_satisfy(&Max.bind(EmptyOptions)),
526+
stored.can_satisfy(&Max.bind(NumericalAggregateOpts::default())),
523527
AggregateFnSatisfaction::Approximate
524528
);
525529
assert_eq!(
526-
Max.bind(EmptyOptions).can_satisfy(&stored),
530+
stored.can_satisfy(&Max.bind(NumericalAggregateOpts::include_nans())),
531+
AggregateFnSatisfaction::No
532+
);
533+
assert_eq!(
534+
Max.bind(NumericalAggregateOpts::include_nans())
535+
.can_satisfy(&stored),
536+
AggregateFnSatisfaction::No
537+
);
538+
assert_eq!(
539+
Max.bind(NumericalAggregateOpts::default())
540+
.can_satisfy(&stored),
527541
AggregateFnSatisfaction::Approximate
528542
);
529543
assert_eq!(
530-
stored.can_satisfy(&Min.bind(EmptyOptions)),
544+
stored.can_satisfy(&Min.bind(NumericalAggregateOpts::default())),
531545
AggregateFnSatisfaction::No
532546
);
533547
}

vortex-array/src/aggregate_fn/fns/bounded_min/mod.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::aggregate_fn::AggregateFnId;
2020
use crate::aggregate_fn::AggregateFnRef;
2121
use crate::aggregate_fn::AggregateFnSatisfaction;
2222
use crate::aggregate_fn::AggregateFnVTable;
23-
use crate::aggregate_fn::EmptyOptions;
23+
use crate::aggregate_fn::NumericalAggregateOpts;
2424
use crate::aggregate_fn::fns::min::Min;
2525
use crate::aggregate_fn::fns::min_max::MinMax;
2626
use crate::aggregate_fn::fns::min_max::min_max;
@@ -126,7 +126,11 @@ impl AggregateFnVTable for BoundedMin {
126126
};
127127
}
128128

129-
if requested.is::<Min>() {
129+
// The stored bound skips NaNs, so it cannot stand in for a NaN-including minimum.
130+
if requested
131+
.as_opt::<Min>()
132+
.is_some_and(|options| options.skip_nans)
133+
{
130134
AggregateFnSatisfaction::Approximate
131135
} else {
132136
AggregateFnSatisfaction::No
@@ -182,7 +186,7 @@ impl AggregateFnVTable for BoundedMin {
182186
Columnar::Canonical(canonical) => canonical.clone().into_array(),
183187
Columnar::Constant(constant) => constant.clone().into_array(),
184188
};
185-
let Some(result) = min_max(&array, ctx)? else {
189+
let Some(result) = min_max(&array, ctx, NumericalAggregateOpts::default())? else {
186190
return Ok(());
187191
};
188192
if let Some(bound) = truncate_min(result.min, partial.max_bytes.get())? {
@@ -202,7 +206,7 @@ impl AggregateFnVTable for BoundedMin {
202206

203207
fn supported_dtype<'a>(_options: &BoundedMinOptions, input_dtype: &'a DType) -> Option<&'a DType> {
204208
MinMax
205-
.return_dtype(&EmptyOptions, input_dtype)
209+
.return_dtype(&NumericalAggregateOpts::default(), input_dtype)
206210
.map(|_| input_dtype)
207211
}
208212

@@ -241,7 +245,7 @@ mod tests {
241245
use crate::aggregate_fn::AggregateFnVTable;
242246
use crate::aggregate_fn::AggregateFnVTableExt;
243247
use crate::aggregate_fn::DynAccumulator;
244-
use crate::aggregate_fn::EmptyOptions;
248+
use crate::aggregate_fn::NumericalAggregateOpts;
245249
use crate::aggregate_fn::fns::bounded_min::BoundedMin;
246250
use crate::aggregate_fn::fns::bounded_min::BoundedMinOptions;
247251
use crate::aggregate_fn::fns::max::Max;
@@ -350,15 +354,25 @@ mod tests {
350354
AggregateFnSatisfaction::No
351355
);
352356
assert_eq!(
353-
stored.can_satisfy(&Min.bind(EmptyOptions)),
357+
stored.can_satisfy(&Min.bind(NumericalAggregateOpts::default())),
354358
AggregateFnSatisfaction::Approximate
355359
);
356360
assert_eq!(
357-
Min.bind(EmptyOptions).can_satisfy(&stored),
361+
stored.can_satisfy(&Min.bind(NumericalAggregateOpts::include_nans())),
362+
AggregateFnSatisfaction::No
363+
);
364+
assert_eq!(
365+
Min.bind(NumericalAggregateOpts::include_nans())
366+
.can_satisfy(&stored),
367+
AggregateFnSatisfaction::No
368+
);
369+
assert_eq!(
370+
Min.bind(NumericalAggregateOpts::default())
371+
.can_satisfy(&stored),
358372
AggregateFnSatisfaction::Approximate
359373
);
360374
assert_eq!(
361-
stored.can_satisfy(&Max.bind(EmptyOptions)),
375+
stored.can_satisfy(&Max.bind(NumericalAggregateOpts::default())),
362376
AggregateFnSatisfaction::No
363377
);
364378
}

0 commit comments

Comments
 (0)