Skip to content

Commit ed2d3fd

Browse files
committed
Refactor min/max logic for shared row-wise handling
Consolidate the row-wise min/max scan logic into a single helper in min_max.rs so that the dictionary path and the generic complex-type path behave consistently. Add a regression test for float dictionaries containing NaN and -inf values, which validates the ordering semantics across batches.
1 parent 9240400 commit ed2d3fd

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,7 @@ macro_rules! min_max {
455455
}};
456456
}
457457

458-
fn dictionary_batch_extreme(
459-
values: &ArrayRef,
460-
ordering: Ordering,
461-
) -> Result<ScalarValue> {
458+
fn scalar_batch_extreme(values: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
462459
let mut extreme: Option<ScalarValue> = None;
463460

464461
for i in 0..values.len() {
@@ -823,32 +820,14 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
823820
DataType::FixedSizeList(_, _) => {
824821
min_max_batch_generic(values, Ordering::Greater)?
825822
}
826-
DataType::Dictionary(_, _) => {
827-
dictionary_batch_extreme(values, Ordering::Greater)?
828-
}
823+
DataType::Dictionary(_, _) => scalar_batch_extreme(values, Ordering::Greater)?,
829824
_ => min_max_batch!(values, min),
830825
})
831826
}
832827

833828
/// Generic min/max implementation for complex types
834829
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
835-
let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i));
836-
let Some(first_idx) = non_null_indices.next() else {
837-
return ScalarValue::try_from(array.data_type());
838-
};
839-
840-
let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
841-
for i in non_null_indices {
842-
let current = ScalarValue::try_from_array(array, i)?;
843-
if current.is_null() {
844-
continue;
845-
}
846-
if extreme.is_null() || extreme.try_cmp(&current)? == ordering {
847-
extreme = current;
848-
}
849-
}
850-
851-
Ok(extreme)
830+
scalar_batch_extreme(array, ordering)
852831
}
853832

854833
/// dynamically-typed max(array) -> ScalarValue
@@ -900,7 +879,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
900879
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
901880
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
902881
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
903-
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?,
882+
DataType::Dictionary(_, _) => scalar_batch_extreme(values, Ordering::Less)?,
904883
_ => min_max_batch!(values, max),
905884
})
906885
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,13 @@ mod tests {
12881288
) as ArrayRef
12891289
}
12901290

1291+
fn float_dictionary_batch(values: &[f32], keys: &[Option<i32>]) -> ArrayRef {
1292+
let values = Arc::new(Float32Array::from(values.to_vec())) as ArrayRef;
1293+
Arc::new(
1294+
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
1295+
) as ArrayRef
1296+
}
1297+
12911298
fn evaluate_dictionary_accumulator(
12921299
mut acc: impl Accumulator,
12931300
batches: &[ArrayRef],
@@ -1377,4 +1384,35 @@ mod tests {
13771384

13781385
assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d")
13791386
}
1387+
1388+
#[test]
1389+
fn test_min_max_dictionary_float_with_nans() -> Result<()> {
1390+
let dict_type =
1391+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Float32));
1392+
let batch1 = float_dictionary_batch(&[0.0, f32::NAN], &[Some(0), Some(1)]);
1393+
let batch2 = float_dictionary_batch(&[f32::NEG_INFINITY], &[Some(0)]);
1394+
1395+
let min_result = evaluate_dictionary_accumulator(
1396+
MinAccumulator::try_new(&dict_type)?,
1397+
&[Arc::clone(&batch1), Arc::clone(&batch2)],
1398+
)?;
1399+
assert_eq!(
1400+
min_result,
1401+
dict_scalar(
1402+
DataType::Int32,
1403+
ScalarValue::Float32(Some(f32::NEG_INFINITY)),
1404+
)
1405+
);
1406+
1407+
let max_result = evaluate_dictionary_accumulator(
1408+
MaxAccumulator::try_new(&dict_type)?,
1409+
&[batch1, batch2],
1410+
)?;
1411+
assert_eq!(
1412+
max_result,
1413+
dict_scalar(DataType::Int32, ScalarValue::Float32(Some(f32::NAN)))
1414+
);
1415+
1416+
Ok(())
1417+
}
13801418
}

0 commit comments

Comments
 (0)