Skip to content

Commit 3c53618

Browse files
neilconway, mbutrovich, Dandandan
authored
fix: Incorrect behavior for FILTER on NULLs (#22068)
## Which issue does this PR close?

- Closes #22067.

## Rationale for this change

In the grouping code, `accumulate_multiple` and `accumulate_indices` take a `BooleanArray` parameter, which holds the result of the aggregate's `FILTER` clause (if any). Both functions only consider the value bits of the array, not the NULL bitmap, which means they consider `NULL` filter results to be effectively true, not false.

## What changes are included in this PR?

* Fix NULL handling in `accumulate_multiple` and `accumulate_indices`
* Refactor `accumulate_multiple` to be more readable and make use of `NullBuffer::union_many`
* Introduce a new helper, `filter_to_validity`
* Optimize `filter_to_nulls` to use `filter_to_validity` and avoid constructing an unnecessary intermediate `NullBuffer`
* Add unit tests for NULL handling in `accumulate_multiple` and `accumulate_indices`
* Add SLT tests with SQL repros for both code paths

## Are these changes tested?

Yes, with new tests added.

## Are there any user-facing changes?

This changes query behavior for the affected (buggy) queries. This PR also changes the signature of `filter_to_nulls`, which is technically a public API.

---------

Co-authored-by: Matt Butrovich <mbutrovich@users.noreply.github.com>
Co-authored-by: Daniël Heres <danielheres@gmail.com>
1 parent fe8dbfa commit 3c53618

4 files changed

Lines changed: 125 additions & 38 deletions

File tree

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 86 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
2323
use arrow::buffer::NullBuffer;
2424
use arrow::datatypes::ArrowPrimitiveType;
2525

26+
use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
2627
use datafusion_expr_common::groups_accumulator::EmitTo;
2728

2829
/// If the input has nulls, then the accumulator must potentially
@@ -471,7 +472,7 @@ pub fn accumulate<T, F>(
471472
///
472473
/// This method assumes that for any input record index, if any of the value column
473474
/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
474-
/// (won't be accumulated by `value_fn`)
475+
/// (Won't be accumulated by `value_fn`)
475476
///
476477
/// # Arguments
477478
///
@@ -491,35 +492,28 @@ pub fn accumulate_multiple<T, F>(
491492
T: ArrowPrimitiveType + Send,
492493
F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
493494
{
494-
// Calculate `valid_indices` to accumulate, non-valid indices are ignored.
495-
// `valid_indices` is a bit mask corresponding to the `group_indices`. An index
496-
// is considered valid if:
497-
// 1. All columns are non-null at this index.
498-
// 2. Not filtered out by `opt_filter`
499-
500-
// Take AND from all null buffers of `value_columns`.
501-
let combined_nulls = value_columns
502-
.iter()
503-
.map(|arr| arr.logical_nulls())
504-
.fold(None, |acc, nulls| {
505-
NullBuffer::union(acc.as_ref(), nulls.as_ref())
506-
});
507-
508-
// Take AND from previous combined nulls and `opt_filter`.
509-
let valid_indices = match (combined_nulls, opt_filter) {
510-
(None, None) => None,
511-
(None, Some(filter)) => Some(filter.clone()),
512-
(Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
513-
(Some(nulls), Some(filter)) => {
514-
let combined = nulls.inner() & filter.values();
515-
Some(BooleanArray::new(combined, None))
516-
}
517-
};
518-
519495
for col in value_columns.iter() {
520496
debug_assert_eq!(col.len(), group_indices.len());
521497
}
522498

499+
// Start with rows where all value columns are non-null.
500+
let mut valid_indices =
501+
NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
502+
.map(NullBuffer::into_inner);
503+
504+
// Restrict to rows where the optional filter is Some(true). Keep the filter
505+
// as a raw BooleanBuffer to avoid computing a NullBuffer null_count just to
506+
// test row validity below.
507+
if let Some(filter) = opt_filter {
508+
debug_assert_eq!(filter.len(), group_indices.len());
509+
let filter_validity = filter_to_validity(filter);
510+
if let Some(valid_indices) = valid_indices.as_mut() {
511+
*valid_indices &= &filter_validity;
512+
} else {
513+
valid_indices = Some(filter_validity);
514+
}
515+
}
516+
523517
match valid_indices {
524518
None => {
525519
for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
@@ -562,7 +556,8 @@ pub fn accumulate_indices<F>(
562556
(None, Some(filter)) => {
563557
debug_assert_eq!(filter.len(), group_indices.len());
564558
let group_indices_chunks = group_indices.chunks_exact(64);
565-
let bit_chunks = filter.values().bit_chunks();
559+
let filter_validity = filter_to_validity(filter);
560+
let bit_chunks = filter_validity.bit_chunks();
566561

567562
let group_indices_remainder = group_indices_chunks.remainder();
568563

@@ -636,7 +631,8 @@ pub fn accumulate_indices<F>(
636631

637632
let group_indices_chunks = group_indices.chunks_exact(64);
638633
let valid_bit_chunks = valids.inner().bit_chunks();
639-
let filter_bit_chunks = filter.values().bit_chunks();
634+
let filter_validity = filter_to_validity(filter);
635+
let filter_bit_chunks = filter_validity.bit_chunks();
640636

641637
let group_indices_remainder = group_indices_chunks.remainder();
642638

@@ -1188,6 +1184,68 @@ mod test {
11881184
assert_eq!(accumulated, expected);
11891185
}
11901186

1187+
#[test]
1188+
fn test_accumulate_indices_with_null_filter() {
1189+
let group_indices = vec![0, 1, 0, 1];
1190+
let filter = BooleanArray::new(
1191+
BooleanBuffer::from(vec![true, true, true, false]),
1192+
Some(NullBuffer::from(vec![true, false, true, true])),
1193+
);
1194+
1195+
let mut accumulated = vec![];
1196+
accumulate_indices(&group_indices, None, Some(&filter), |group_idx| {
1197+
accumulated.push(group_idx);
1198+
});
1199+
1200+
// A NULL filter value should be treated the same as false, even if the
1201+
// underlying BooleanBuffer value is true.
1202+
let expected = vec![0, 0];
1203+
assert_eq!(accumulated, expected);
1204+
1205+
let value_validity = NullBuffer::from(vec![true, true, false, true]);
1206+
let mut accumulated = vec![];
1207+
accumulate_indices(
1208+
&group_indices,
1209+
Some(&value_validity),
1210+
Some(&filter),
1211+
|group_idx| {
1212+
accumulated.push(group_idx);
1213+
},
1214+
);
1215+
1216+
let expected = vec![0];
1217+
assert_eq!(accumulated, expected);
1218+
}
1219+
1220+
#[test]
1221+
fn test_accumulate_multiple_with_null_filter() {
1222+
let group_indices = vec![0, 1, 0, 1];
1223+
let values1 = Int32Array::from(vec![1, 2, 3, 4]);
1224+
let values2 = Int32Array::from(vec![10, 20, 30, 40]);
1225+
let value_columns = [values1, values2];
1226+
1227+
let filter = BooleanArray::new(
1228+
BooleanBuffer::from(vec![true, true, true, false]),
1229+
Some(NullBuffer::from(vec![true, false, true, true])),
1230+
);
1231+
1232+
let mut accumulated = vec![];
1233+
accumulate_multiple(
1234+
&group_indices,
1235+
&value_columns.iter().collect::<Vec<_>>(),
1236+
Some(&filter),
1237+
|group_idx, batch_idx, columns| {
1238+
let values = columns.iter().map(|col| col.value(batch_idx)).collect();
1239+
accumulated.push((group_idx, values));
1240+
},
1241+
);
1242+
1243+
// A NULL filter value should be treated the same as false, even if the
1244+
// underlying BooleanBuffer value is true.
1245+
let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
1246+
assert_eq!(accumulated, expected);
1247+
}
1248+
11911249
#[test]
11921250
fn test_accumulate_multiple_with_nulls_and_filter() {
11931251
let group_indices = vec![0, 1, 0, 1];

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arrow::array::{
2222
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray,
2323
StringViewArray, StructArray,
2424
};
25-
use arrow::buffer::NullBuffer;
25+
use arrow::buffer::{BooleanBuffer, NullBuffer};
2626
use arrow::datatypes::DataType;
2727
use datafusion_common::{Result, not_impl_err};
2828
use std::sync::Arc;
@@ -39,15 +39,24 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
3939
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
4040
}
4141

42-
/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer.
42+
/// Converts an aggregate filter expression to a validity bitmap.
43+
///
44+
/// The output is `true` for rows where the filter is `Some(true)`, and `false`
45+
/// for rows where the filter is `Some(false)` or `None`.
46+
pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
47+
let Some(filter_nulls) = filter.nulls() else {
48+
return filter.values().clone();
49+
};
50+
filter.values() & filter_nulls.inner()
51+
}
52+
53+
/// Converts an aggregate filter expression to a `NullBuffer`.
4354
///
4455
/// The `NullBuffer` is
45-
/// * `true` (representing valid) for values that were `true` in filter
46-
/// * `false` (representing null) for values that were `false` or `null` in filter
47-
pub fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
48-
let (filter_bools, filter_nulls) = filter.clone().into_parts();
49-
let filter_bools = NullBuffer::from(filter_bools);
50-
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
56+
/// * `true` (representing valid) for filter values that were `Some(true)`
57+
/// * `false` (representing null) for filter values that were `Some(false)` or `None`
58+
pub fn filter_to_nulls(filter: &BooleanArray) -> NullBuffer {
59+
NullBuffer::new(filter_to_validity(filter))
5160
}
5261

5362
/// Compute an output validity mask for an array that has been filtered
@@ -97,7 +106,7 @@ pub fn filtered_null_mask(
97106
opt_filter: Option<&BooleanArray>,
98107
input: &dyn Array,
99108
) -> Option<NullBuffer> {
100-
let opt_filter = opt_filter.and_then(filter_to_nulls);
109+
let opt_filter = opt_filter.map(filter_to_nulls);
101110
NullBuffer::union(opt_filter.as_ref(), input.nulls())
102111
}
103112

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
776776
let offsets = OffsetBuffer::from_repeated_length(1, input.len());
777777

778778
// Filtered rows become null list entries, which merge_batch will skip.
779-
let filter_nulls = opt_filter.and_then(filter_to_nulls);
779+
let filter_nulls = opt_filter.map(filter_to_nulls);
780780

781781
// With ignore_nulls, null values also become null list entries. Without
782782
// ignore_nulls, null values stay as [NULL] so merge_batch retains them.

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,18 @@ from data
693693
----
694694
1
695695

696+
# correlation_with_group_by_and_nullable_filter
697+
query IR rowsort
698+
SELECT g, corr(x, y) FILTER (WHERE b < 1) AS r
699+
FROM (VALUES
700+
(0, 1.0, 1.0, CAST(NULL AS INT)),
701+
(0, 2.0, 2.0, CAST(NULL AS INT)),
702+
(0, 3.0, 4.0, 2)
703+
) AS t(g, x, y, b)
704+
GROUP BY g
705+
----
706+
0 NULL
707+
696708
# group correlation_query_with_nans_f32
697709
query IR
698710
select id, corr(f, b)
@@ -6177,6 +6189,14 @@ FROM test_table
61776189
----
61786190
2
61796191

6192+
# count_with_group_by_and_nullable_filter
6193+
query II rowsort
6194+
SELECT g, COUNT(a) FILTER (WHERE b < 1) AS count_a
6195+
FROM (VALUES (0, 1, CAST(NULL AS INT)), (0, 2, 2)) AS t(g, a, b)
6196+
GROUP BY g
6197+
----
6198+
0 0
6199+
61806200
# query_with_and_without_filter
61816201
query III rowsort
61826202
SELECT

0 commit comments

Comments
 (0)