Skip to content

Commit 9240400

Browse files
committed
Fix dictionary min/max behavior in DataFusion
Update min_max.rs so that dictionary batches iterate over the actual array rows and compare the referenced scalar values. Unreferenced dictionary entries no longer affect MIN/MAX, and referenced null values are correctly skipped. Expand the tests to cover these changes, update the existing expectations, and add regression tests for unreferenced and referenced null dictionary values.
1 parent 0bbc56e commit 9240400

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
//! Basic min/max functionality shared across DataFusion aggregate functions
1919
2020
use arrow::array::{
21-
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
22-
Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
21+
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
22+
Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
2323
DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
2424
DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
2525
Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray,
@@ -457,13 +457,23 @@ macro_rules! min_max {
457457

458458
fn dictionary_batch_extreme(
459459
values: &ArrayRef,
460-
extreme_fn: fn(&ArrayRef) -> Result<ScalarValue>,
460+
ordering: Ordering,
461461
) -> Result<ScalarValue> {
462-
let DataType::Dictionary(key_type, _) = values.data_type() else {
463-
unreachable!("dictionary_batch_extreme requires dictionary arrays")
464-
};
465-
let inner = extreme_fn(values.as_any_dictionary().values())?;
466-
Ok(wrap_dictionary_scalar(key_type.as_ref(), inner))
462+
let mut extreme: Option<ScalarValue> = None;
463+
464+
for i in 0..values.len() {
465+
let current = ScalarValue::try_from_array(values, i)?;
466+
if current.is_null() {
467+
continue;
468+
}
469+
470+
match &extreme {
471+
Some(existing) if existing.try_cmp(&current)? != ordering => {}
472+
_ => extreme = Some(current),
473+
}
474+
}
475+
476+
extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok)
467477
}
468478

469479
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
@@ -813,7 +823,9 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
813823
DataType::FixedSizeList(_, _) => {
814824
min_max_batch_generic(values, Ordering::Greater)?
815825
}
816-
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, min_batch)?,
826+
DataType::Dictionary(_, _) => {
827+
dictionary_batch_extreme(values, Ordering::Greater)?
828+
}
817829
_ => min_max_batch!(values, min),
818830
})
819831
}
@@ -828,7 +840,10 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarV
828840
let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
829841
for i in non_null_indices {
830842
let current = ScalarValue::try_from_array(array, i)?;
831-
if extreme.try_cmp(&current)? == ordering {
843+
if current.is_null() {
844+
continue;
845+
}
846+
if extreme.is_null() || extreme.try_cmp(&current)? == ordering {
832847
extreme = current;
833848
}
834849
}
@@ -885,7 +900,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
885900
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
886901
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
887902
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
888-
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, max_batch)?,
903+
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?,
889904
_ => min_max_batch!(values, max),
890905
})
891906
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,7 @@ mod tests {
12591259
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
12601260
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
12611261
let max_result = max_acc.evaluate()?;
1262-
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
1262+
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
12631263
Ok(())
12641264
}
12651265

@@ -1278,6 +1278,16 @@ mod tests {
12781278
) as ArrayRef
12791279
}
12801280

1281+
fn optional_string_dictionary_batch(
1282+
values: &[Option<&str>],
1283+
keys: &[Option<i32>],
1284+
) -> ArrayRef {
1285+
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
1286+
Arc::new(
1287+
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
1288+
) as ArrayRef
1289+
}
1290+
12811291
fn evaluate_dictionary_accumulator(
12821292
mut acc: impl Accumulator,
12831293
batches: &[ArrayRef],
@@ -1336,6 +1346,28 @@ mod tests {
13361346
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c")
13371347
}
13381348

1349+
#[test]
1350+
fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> {
1351+
let dict_array_ref = string_dictionary_batch(
1352+
&["a", "z", "zz_unused"],
1353+
&[Some(1), Some(1), None],
1354+
);
1355+
let dict_type = dict_array_ref.data_type().clone();
1356+
1357+
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z")
1358+
}
1359+
1360+
#[test]
1361+
fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> {
1362+
let dict_array_ref = optional_string_dictionary_batch(
1363+
&[Some("b"), None, Some("a"), Some("d")],
1364+
&[Some(0), Some(1), Some(2), Some(3)],
1365+
);
1366+
let dict_type = dict_array_ref.data_type().clone();
1367+
1368+
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d")
1369+
}
1370+
13391371
#[test]
13401372
fn test_min_max_dictionary_multi_batch() -> Result<()> {
13411373
let dict_type =

0 commit comments

Comments
 (0)