Skip to content

Commit f2330b9

Browse files
committed
Refactor min_max handling and reduce test duplication
Normalize dictionary scalars before the min_max type match. Replace the separate dictionary batch scan with the shared generic path. Consolidate the repeated extreme-tracking logic into a single update_extreme helper. Reduce test duplication in min_max.rs by reusing accumulators and consolidating dictionary builders. Clean up tests in basic.rs by extracting RecordBatch construction and compressing final-plan assertions.
1 parent 5002677 commit f2330b9

File tree

3 files changed

+100
-109
lines changed

3 files changed

+100
-109
lines changed

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,20 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> {
445445

446446
#[tokio::test]
447447
async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
448+
fn dictionary_batch(
449+
schema: SchemaRef,
450+
keys: Vec<Option<i32>>,
451+
values: Vec<&str>,
452+
) -> Result<RecordBatch> {
453+
Ok(RecordBatch::try_new(
454+
schema,
455+
vec![Arc::new(DictionaryArray::new(
456+
Int32Array::from(keys),
457+
Arc::new(StringArray::from(values)),
458+
))],
459+
)?)
460+
}
461+
448462
let ctx =
449463
SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2));
450464

@@ -456,20 +470,13 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
456470
true,
457471
)]));
458472

459-
let batch1 = RecordBatch::try_new(
473+
let batch1 = dictionary_batch(
460474
schema.clone(),
461-
vec![Arc::new(DictionaryArray::new(
462-
Int32Array::from(vec![Some(1), Some(1), None]),
463-
Arc::new(StringArray::from(vec!["a", "z", "zz_unused"])),
464-
))],
465-
)?;
466-
let batch2 = RecordBatch::try_new(
467-
schema.clone(),
468-
vec![Arc::new(DictionaryArray::new(
469-
Int32Array::from(vec![Some(0), Some(1)]),
470-
Arc::new(StringArray::from(vec!["a", "d"])),
471-
))],
475+
vec![Some(1), Some(1), None],
476+
vec!["a", "z", "zz_unused"],
472477
)?;
478+
let batch2 =
479+
dictionary_batch(schema.clone(), vec![Some(0), Some(1)], vec!["a", "d"])?;
473480
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
474481
ctx.register_table("t", Arc::new(provider))?;
475482

@@ -479,10 +486,9 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
479486
let physical_plan = df.clone().create_physical_plan().await?;
480487
let formatted_plan = format!("{}", displayable(physical_plan.as_ref()).indent(true));
481488
assert!(formatted_plan.contains("AggregateExec: mode=Partial, gby=[]"));
482-
assert!(
483-
formatted_plan.contains("AggregateExec: mode=Final, gby=[]")
484-
|| formatted_plan.contains("AggregateExec: mode=FinalPartitioned, gby=[]")
485-
);
489+
assert!(["Final", "FinalPartitioned"].iter().any(|mode| {
490+
formatted_plan.contains(&format!("AggregateExec: mode={mode}, gby=[]"))
491+
}));
486492

487493
let results = df.collect().await?;
488494

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 51 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,18 @@ macro_rules! min_max_generic {
144144
// min/max of two scalar values of the same type
145145
macro_rules! min_max {
146146
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
147-
Ok(match ($VALUE, $DELTA) {
147+
let value = $VALUE;
148+
let delta = $DELTA;
149+
let wrap_key_type = match (value, delta) {
150+
(ScalarValue::Dictionary(key_type, _), ScalarValue::Dictionary(_, _)) => {
151+
Some(key_type.as_ref())
152+
}
153+
_ => None,
154+
};
155+
let value = unwrap_dictionary_scalar(value);
156+
let delta = unwrap_dictionary_scalar(delta);
157+
158+
let result = match (value, delta) {
148159
(ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
149160
(
150161
lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss),
@@ -413,65 +424,51 @@ macro_rules! min_max {
413424
min_max_generic!(lhs, rhs, $OP)
414425
}
415426

416-
(
417-
ScalarValue::Dictionary(key_type, lhs_inner),
418-
ScalarValue::Dictionary(_, rhs_inner),
419-
) => {
420-
wrap_dictionary_scalar(
421-
key_type.as_ref(),
422-
min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP),
423-
)
424-
}
425-
426-
(
427-
ScalarValue::Dictionary(_, lhs_inner),
428-
rhs,
429-
) => {
430-
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
431-
}
432-
433-
(
434-
lhs,
435-
ScalarValue::Dictionary(_, rhs_inner),
436-
) => {
437-
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
438-
}
439-
440427
e => {
441428
return internal_err!(
442429
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
443430
e
444431
)
445432
}
433+
};
434+
435+
Ok(match wrap_key_type {
436+
Some(key_type) => wrap_dictionary_scalar(key_type, result),
437+
None => result,
446438
})
447439
}};
448440
}
449441

450-
fn dictionary_batch_extreme(
451-
values: &ArrayRef,
452-
ordering: Ordering,
453-
) -> Result<ScalarValue> {
454-
let mut extreme: Option<ScalarValue> = None;
455-
456-
for i in 0..values.len() {
457-
let current = ScalarValue::try_from_array(values, i)?;
458-
if current.is_null() {
459-
continue;
460-
}
461-
462-
match &extreme {
463-
Some(existing) if existing.try_cmp(&current)? != ordering => {}
464-
_ => extreme = Some(current),
465-
}
442+
fn unwrap_dictionary_scalar(value: &ScalarValue) -> &ScalarValue {
443+
match value {
444+
ScalarValue::Dictionary(_, inner) => inner.as_ref(),
445+
_ => value,
466446
}
467-
468-
extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok)
469447
}
470448

471449
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
472450
ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(value))
473451
}
474452

453+
fn update_extreme(
454+
extreme: &mut Option<ScalarValue>,
455+
current: ScalarValue,
456+
ordering: Ordering,
457+
) -> Result<()> {
458+
if current.is_null() {
459+
return Ok(());
460+
}
461+
462+
if !matches!(
463+
extreme,
464+
Some(existing) if existing.try_cmp(&current)? != ordering
465+
) {
466+
*extreme = Some(current);
467+
}
468+
469+
Ok(())
470+
}
471+
475472
/// An accumulator to compute the maximum value
476473
#[derive(Debug, Clone)]
477474
pub struct MaxAccumulator {
@@ -815,32 +812,24 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
815812
DataType::FixedSizeList(_, _) => {
816813
min_max_batch_generic(values, Ordering::Greater)?
817814
}
818-
DataType::Dictionary(_, _) => {
819-
dictionary_batch_extreme(values, Ordering::Greater)?
820-
}
815+
DataType::Dictionary(_, _) => min_max_batch_generic(values, Ordering::Greater)?,
821816
_ => min_max_batch!(values, min),
822817
})
823818
}
824819

825820
/// Generic min/max implementation for complex types
826821
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
827-
let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i));
828-
let Some(first_idx) = non_null_indices.next() else {
829-
return ScalarValue::try_from(array.data_type());
830-
};
831-
832-
let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
833-
for i in non_null_indices {
834-
let current = ScalarValue::try_from_array(array, i)?;
835-
if current.is_null() {
836-
continue;
837-
}
838-
if extreme.is_null() || extreme.try_cmp(&current)? == ordering {
839-
extreme = current;
840-
}
822+
let mut extreme = None;
823+
824+
for i in 0..array.len() {
825+
update_extreme(
826+
&mut extreme,
827+
ScalarValue::try_from_array(array, i)?,
828+
ordering,
829+
)?;
841830
}
842831

843-
Ok(extreme)
832+
extreme.map_or_else(|| ScalarValue::try_from(array.data_type()), Ok)
844833
}
845834

846835
/// dynamically-typed max(array) -> ScalarValue
@@ -892,7 +881,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
892881
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
893882
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
894883
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
895-
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?,
884+
DataType::Dictionary(_, _) => min_max_batch_generic(values, Ordering::Less)?,
896885
_ => min_max_batch!(values, max),
897886
})
898887
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,22 +1255,23 @@ mod tests {
12551255

12561256
#[test]
12571257
fn test_min_max_dictionary_after_coercion() -> Result<()> {
1258-
let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
1259-
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
1260-
let dict_array =
1261-
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1262-
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1258+
let dict_array_ref = string_dictionary_batch(
1259+
vec!["b", "c", "a", "🦀", "d"],
1260+
&[Some(0), Some(1), Some(2), None, Some(4)],
1261+
);
12631262
let rt_type =
12641263
get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone();
12651264

1266-
let mut min_acc = MinAccumulator::try_new(&rt_type)?;
1267-
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1268-
let min_result = min_acc.evaluate()?;
1265+
let min_result = evaluate_dictionary_accumulator(
1266+
MinAccumulator::try_new(&rt_type)?,
1267+
&[Arc::clone(&dict_array_ref)],
1268+
)?;
12691269
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
12701270

1271-
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
1272-
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1273-
let max_result = max_acc.evaluate()?;
1271+
let max_result = evaluate_dictionary_accumulator(
1272+
MaxAccumulator::try_new(&rt_type)?,
1273+
&[dict_array_ref],
1274+
)?;
12741275
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
12751276
Ok(())
12761277
}
@@ -1283,18 +1284,11 @@ mod tests {
12831284
dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string())))
12841285
}
12851286

1286-
fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef {
1287-
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
1288-
Arc::new(
1289-
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
1290-
) as ArrayRef
1291-
}
1292-
1293-
fn optional_string_dictionary_batch(
1294-
values: &[Option<&str>],
1295-
keys: &[Option<i32>],
1296-
) -> ArrayRef {
1297-
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
1287+
fn string_dictionary_batch<T>(values: Vec<T>, keys: &[Option<i32>]) -> ArrayRef
1288+
where
1289+
StringArray: From<Vec<T>>,
1290+
{
1291+
let values = Arc::new(StringArray::from(values)) as ArrayRef;
12981292
Arc::new(
12991293
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
13001294
) as ArrayRef
@@ -1339,7 +1333,7 @@ mod tests {
13391333
#[test]
13401334
fn test_min_max_dictionary_without_coercion() -> Result<()> {
13411335
let dict_array_ref = string_dictionary_batch(
1342-
&["b", "c", "a", "d"],
1336+
vec!["b", "c", "a", "d"],
13431337
&[Some(0), Some(1), Some(2), Some(3)],
13441338
);
13451339
let dict_type = dict_array_ref.data_type().clone();
@@ -1350,7 +1344,7 @@ mod tests {
13501344
#[test]
13511345
fn test_min_max_dictionary_with_nulls() -> Result<()> {
13521346
let dict_array_ref = string_dictionary_batch(
1353-
&["b", "c", "a"],
1347+
vec!["b", "c", "a"],
13541348
&[None, Some(0), None, Some(1), Some(2)],
13551349
);
13561350
let dict_type = dict_array_ref.data_type().clone();
@@ -1360,17 +1354,19 @@ mod tests {
13601354

13611355
#[test]
13621356
fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> {
1363-
let dict_array_ref =
1364-
string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]);
1357+
let dict_array_ref = string_dictionary_batch(
1358+
vec!["a", "z", "zz_unused"],
1359+
&[Some(1), Some(1), None],
1360+
);
13651361
let dict_type = dict_array_ref.data_type().clone();
13661362

13671363
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z")
13681364
}
13691365

13701366
#[test]
13711367
fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> {
1372-
let dict_array_ref = optional_string_dictionary_batch(
1373-
&[Some("b"), None, Some("a"), Some("d")],
1368+
let dict_array_ref = string_dictionary_batch(
1369+
vec![Some("b"), None, Some("a"), Some("d")],
13741370
&[Some(0), Some(1), Some(2), Some(3)],
13751371
);
13761372
let dict_type = dict_array_ref.data_type().clone();
@@ -1382,8 +1378,8 @@ mod tests {
13821378
fn test_min_max_dictionary_multi_batch() -> Result<()> {
13831379
let dict_type =
13841380
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1385-
let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]);
1386-
let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]);
1381+
let batch1 = string_dictionary_batch(vec!["b", "c"], &[Some(0), Some(1)]);
1382+
let batch2 = string_dictionary_batch(vec!["a", "d"], &[Some(0), Some(1)]);
13871383

13881384
assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d")
13891385
}

0 commit comments

Comments
 (0)