Skip to content

Commit caafe1c

Browse files
committed
Refactor dictionary min/max logic and tests
Centralize dictionary batch handling for min/max operations. Streamline min_max_batch_generic to initialize from the first non-null element. Implement shared setup/assert helpers in dictionary tests to reduce repetition while preserving test coverage.
1 parent fe226dd commit caafe1c

File tree

2 files changed

+93
-97
lines changed

2 files changed

+93
-97
lines changed

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ macro_rules! min_max_generic {
141141
}};
142142
}
143143

144+
macro_rules! min_max_dictionary {
145+
($VALUE:expr, $DELTA:expr, wrap $KEY_TYPE:expr, $OP:ident) => {{
146+
let winner = min_max_generic!($VALUE, $DELTA, $OP);
147+
ScalarValue::Dictionary($KEY_TYPE.clone(), Box::new(winner))
148+
}};
149+
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
150+
min_max_generic!($VALUE, $DELTA, $OP)
151+
}};
152+
}
153+
144154
// min/max of two scalar values of the same type
145155
macro_rules! min_max {
146156
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
@@ -417,22 +427,26 @@ macro_rules! min_max {
417427
ScalarValue::Dictionary(key_type, lhs_inner),
418428
ScalarValue::Dictionary(_, rhs_inner),
419429
) => {
420-
let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP);
421-
ScalarValue::Dictionary(key_type.clone(), Box::new(winner))
430+
min_max_dictionary!(
431+
lhs_inner.as_ref(),
432+
rhs_inner.as_ref(),
433+
wrap key_type,
434+
$OP
435+
)
422436
}
423437

424438
(
425439
ScalarValue::Dictionary(_, lhs_inner),
426440
rhs,
427441
) => {
428-
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
442+
min_max_dictionary!(lhs_inner.as_ref(), rhs, $OP)
429443
}
430444

431445
(
432446
lhs,
433447
ScalarValue::Dictionary(_, rhs_inner),
434448
) => {
435-
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
449+
min_max_dictionary!(lhs, rhs_inner.as_ref(), $OP)
436450
}
437451

438452
e => {
@@ -445,6 +459,17 @@ macro_rules! min_max {
445459
}};
446460
}
447461

462+
fn dictionary_batch_extreme(
463+
values: &ArrayRef,
464+
extreme_fn: fn(&ArrayRef) -> Result<ScalarValue>,
465+
) -> Result<ScalarValue> {
466+
let DataType::Dictionary(key_type, _) = values.data_type() else {
467+
unreachable!("dictionary_batch_extreme requires dictionary arrays")
468+
};
469+
let inner = extreme_fn(values.as_any_dictionary().values())?;
470+
Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(inner)))
471+
}
472+
448473
/// An accumulator to compute the maximum value
449474
#[derive(Debug, Clone)]
450475
pub struct MaxAccumulator {
@@ -788,32 +813,22 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
788813
DataType::FixedSizeList(_, _) => {
789814
min_max_batch_generic(values, Ordering::Greater)?
790815
}
791-
DataType::Dictionary(key_type, _) => {
792-
let dict_values = values.as_any_dictionary().values();
793-
let inner = min_batch(dict_values)?;
794-
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
795-
}
816+
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, min_batch)?,
796817
_ => min_max_batch!(values, min),
797818
})
798819
}
799820

800821
/// Generic min/max implementation for complex types
801822
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
802-
if array.len() == array.null_count() {
823+
let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i));
824+
let Some(first_idx) = non_null_indices.next() else {
803825
return ScalarValue::try_from(array.data_type());
804-
}
805-
let mut extreme = ScalarValue::try_from_array(array, 0)?;
806-
for i in 1..array.len() {
826+
};
827+
828+
let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
829+
for i in non_null_indices {
807830
let current = ScalarValue::try_from_array(array, i)?;
808-
if current.is_null() {
809-
continue;
810-
}
811-
if extreme.is_null() {
812-
extreme = current;
813-
continue;
814-
}
815-
let cmp = extreme.try_cmp(&current)?;
816-
if cmp == ordering {
831+
if extreme.try_cmp(&current)? == ordering {
817832
extreme = current;
818833
}
819834
}
@@ -870,11 +885,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
870885
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
871886
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
872887
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
873-
DataType::Dictionary(key_type, _) => {
874-
let dict_values = values.as_any_dictionary().values();
875-
let inner = max_batch(dict_values)?;
876-
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
877-
}
888+
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, max_batch)?,
878889
_ => min_max_batch!(values, max),
879890
})
880891
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 55 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,94 +1267,79 @@ mod tests {
12671267
ScalarValue::Dictionary(Box::new(key_type), Box::new(inner))
12681268
}
12691269

1270-
#[test]
1271-
fn test_min_max_dictionary_without_coercion() -> Result<()> {
1272-
let values = StringArray::from(vec!["b", "c", "a", "d"]);
1273-
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]);
1274-
let dict_array =
1275-
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1276-
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1277-
let dict_type = dict_array_ref.data_type().clone();
1270+
fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue {
1271+
dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string())))
1272+
}
1273+
1274+
fn string_dictionary_batch(
1275+
values: Vec<&str>,
1276+
keys: Vec<Option<i32>>,
1277+
) -> ArrayRef {
1278+
let values = Arc::new(StringArray::from(values)) as ArrayRef;
1279+
Arc::new(DictionaryArray::try_new(Int32Array::from(keys), values).unwrap())
1280+
as ArrayRef
1281+
}
1282+
1283+
fn assert_dictionary_min_max(
1284+
dict_type: &DataType,
1285+
batches: &[ArrayRef],
1286+
expected_min: &str,
1287+
expected_max: &str,
1288+
) -> Result<()> {
1289+
let key_type = match dict_type {
1290+
DataType::Dictionary(key_type, _) => key_type.as_ref().clone(),
1291+
other => panic!("expected dictionary type, got {other:?}"),
1292+
};
12781293

1279-
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1280-
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1281-
let min_result = min_acc.evaluate()?;
1294+
let mut min_acc = MinAccumulator::try_new(dict_type)?;
1295+
for batch in batches {
1296+
min_acc.update_batch(&[Arc::clone(batch)])?;
1297+
}
12821298
assert_eq!(
1283-
min_result,
1284-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
1299+
min_acc.evaluate()?,
1300+
utf8_dict_scalar(key_type.clone(), expected_min)
12851301
);
12861302

1287-
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1288-
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1289-
let max_result = max_acc.evaluate()?;
1290-
assert_eq!(
1291-
max_result,
1292-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
1293-
);
1303+
let mut max_acc = MaxAccumulator::try_new(dict_type)?;
1304+
for batch in batches {
1305+
max_acc.update_batch(&[Arc::clone(batch)])?;
1306+
}
1307+
assert_eq!(max_acc.evaluate()?, utf8_dict_scalar(key_type, expected_max));
1308+
12941309
Ok(())
12951310
}
12961311

12971312
#[test]
1298-
fn test_min_max_dictionary_with_nulls() -> Result<()> {
1299-
let values = StringArray::from(vec!["b", "c", "a"]);
1300-
let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]);
1301-
let dict_array =
1302-
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1303-
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1313+
fn test_min_max_dictionary_without_coercion() -> Result<()> {
1314+
let dict_array_ref = string_dictionary_batch(
1315+
vec!["b", "c", "a", "d"],
1316+
vec![Some(0), Some(1), Some(2), Some(3)],
1317+
);
13041318
let dict_type = dict_array_ref.data_type().clone();
13051319

1306-
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1307-
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1308-
let min_result = min_acc.evaluate()?;
1309-
assert_eq!(
1310-
min_result,
1311-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
1312-
);
1320+
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d")
1321+
}
13131322

1314-
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1315-
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1316-
let max_result = max_acc.evaluate()?;
1317-
assert_eq!(
1318-
max_result,
1319-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string())))
1323+
#[test]
1324+
fn test_min_max_dictionary_with_nulls() -> Result<()> {
1325+
let dict_array_ref = string_dictionary_batch(
1326+
vec!["b", "c", "a"],
1327+
vec![None, Some(0), None, Some(1), Some(2)],
13201328
);
1321-
Ok(())
1329+
let dict_type = dict_array_ref.data_type().clone();
1330+
1331+
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c")
13221332
}
13231333

13241334
#[test]
13251335
fn test_min_max_dictionary_multi_batch() -> Result<()> {
13261336
let dict_type =
13271337
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1338+
let batch1 =
1339+
string_dictionary_batch(vec!["b", "c"], vec![Some(0), Some(1)]);
1340+
let batch2 =
1341+
string_dictionary_batch(vec!["a", "d"], vec![Some(0), Some(1)]);
13281342

1329-
let values1 = StringArray::from(vec!["b", "c"]);
1330-
let keys1 = Int32Array::from(vec![Some(0), Some(1)]);
1331-
let batch1 = Arc::new(
1332-
DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(),
1333-
) as ArrayRef;
1334-
1335-
let values2 = StringArray::from(vec!["a", "d"]);
1336-
let keys2 = Int32Array::from(vec![Some(0), Some(1)]);
1337-
let batch2 = Arc::new(
1338-
DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(),
1339-
) as ArrayRef;
1340-
1341-
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1342-
min_acc.update_batch(&[Arc::clone(&batch1)])?;
1343-
min_acc.update_batch(&[Arc::clone(&batch2)])?;
1344-
let min_result = min_acc.evaluate()?;
1345-
assert_eq!(
1346-
min_result,
1347-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
1348-
);
1349-
1350-
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1351-
max_acc.update_batch(&[Arc::clone(&batch1)])?;
1352-
max_acc.update_batch(&[Arc::clone(&batch2)])?;
1353-
let max_result = max_acc.evaluate()?;
1354-
assert_eq!(
1355-
max_result,
1356-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
1357-
);
1358-
Ok(())
1343+
assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d")
13591344
}
13601345
}

0 commit comments

Comments
 (0)