Skip to content

Commit 0bbc56e

Browse files
committed
Simplify min/max flow in dictionary handling
In `datafusion/functions-aggregate-common/src/min_max.rs`, refactor the dictionary min/max flow by removing the `wrap` arm of the `min_max_dictionary!` macro and making re-wrapping explicit through a private `wrap_dictionary_scalar` helper. This separates the "choose inner winner" step from the "wrap as dictionary" step for easier auditing. In the test module of `datafusion/functions-aggregate/src/min_max.rs`, update `string_dictionary_batch` to accept slices instead of owned `Vec`s, and introduce a small `evaluate_dictionary_accumulator` helper so that the min and max assertions share a single accumulator execution path, reducing repeated setup.
1 parent caafe1c commit 0bbc56e

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,7 @@ macro_rules! min_max_generic {
142142
}
143143

144144
macro_rules! min_max_dictionary {
145-
($VALUE:expr, $DELTA:expr, wrap $KEY_TYPE:expr, $OP:ident) => {{
146-
let winner = min_max_generic!($VALUE, $DELTA, $OP);
147-
ScalarValue::Dictionary($KEY_TYPE.clone(), Box::new(winner))
148-
}};
149-
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
150-
min_max_generic!($VALUE, $DELTA, $OP)
151-
}};
145+
($VALUE:expr, $DELTA:expr, $OP:ident) => {{ min_max_generic!($VALUE, $DELTA, $OP) }};
152146
}
153147

154148
// min/max of two scalar values of the same type
@@ -427,11 +421,13 @@ macro_rules! min_max {
427421
ScalarValue::Dictionary(key_type, lhs_inner),
428422
ScalarValue::Dictionary(_, rhs_inner),
429423
) => {
430-
min_max_dictionary!(
424+
wrap_dictionary_scalar(
425+
key_type.as_ref(),
426+
min_max_dictionary!(
431427
lhs_inner.as_ref(),
432428
rhs_inner.as_ref(),
433-
wrap key_type,
434429
$OP
430+
),
435431
)
436432
}
437433

@@ -467,7 +463,11 @@ fn dictionary_batch_extreme(
467463
unreachable!("dictionary_batch_extreme requires dictionary arrays")
468464
};
469465
let inner = extreme_fn(values.as_any_dictionary().values())?;
470-
Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(inner)))
466+
Ok(wrap_dictionary_scalar(key_type.as_ref(), inner))
467+
}
468+
469+
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
470+
ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(value))
471471
}
472472

473473
/// An accumulator to compute the maximum value

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,13 +1271,21 @@ mod tests {
12711271
dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string())))
12721272
}
12731273

1274-
fn string_dictionary_batch(
1275-
values: Vec<&str>,
1276-
keys: Vec<Option<i32>>,
1277-
) -> ArrayRef {
1278-
let values = Arc::new(StringArray::from(values)) as ArrayRef;
1279-
Arc::new(DictionaryArray::try_new(Int32Array::from(keys), values).unwrap())
1280-
as ArrayRef
1274+
fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef {
1275+
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
1276+
Arc::new(
1277+
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
1278+
) as ArrayRef
1279+
}
1280+
1281+
fn evaluate_dictionary_accumulator(
1282+
mut acc: impl Accumulator,
1283+
batches: &[ArrayRef],
1284+
) -> Result<ScalarValue> {
1285+
for batch in batches {
1286+
acc.update_batch(&[Arc::clone(batch)])?;
1287+
}
1288+
acc.evaluate()
12811289
}
12821290

12831291
fn assert_dictionary_min_max(
@@ -1291,29 +1299,26 @@ mod tests {
12911299
other => panic!("expected dictionary type, got {other:?}"),
12921300
};
12931301

1294-
let mut min_acc = MinAccumulator::try_new(dict_type)?;
1295-
for batch in batches {
1296-
min_acc.update_batch(&[Arc::clone(batch)])?;
1297-
}
1298-
assert_eq!(
1299-
min_acc.evaluate()?,
1300-
utf8_dict_scalar(key_type.clone(), expected_min)
1301-
);
1302+
let min_result = evaluate_dictionary_accumulator(
1303+
MinAccumulator::try_new(dict_type)?,
1304+
batches,
1305+
)?;
1306+
assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min));
13021307

1303-
let mut max_acc = MaxAccumulator::try_new(dict_type)?;
1304-
for batch in batches {
1305-
max_acc.update_batch(&[Arc::clone(batch)])?;
1306-
}
1307-
assert_eq!(max_acc.evaluate()?, utf8_dict_scalar(key_type, expected_max));
1308+
let max_result = evaluate_dictionary_accumulator(
1309+
MaxAccumulator::try_new(dict_type)?,
1310+
batches,
1311+
)?;
1312+
assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max));
13081313

13091314
Ok(())
13101315
}
13111316

13121317
#[test]
13131318
fn test_min_max_dictionary_without_coercion() -> Result<()> {
13141319
let dict_array_ref = string_dictionary_batch(
1315-
vec!["b", "c", "a", "d"],
1316-
vec![Some(0), Some(1), Some(2), Some(3)],
1320+
&["b", "c", "a", "d"],
1321+
&[Some(0), Some(1), Some(2), Some(3)],
13171322
);
13181323
let dict_type = dict_array_ref.data_type().clone();
13191324

@@ -1323,8 +1328,8 @@ mod tests {
13231328
#[test]
13241329
fn test_min_max_dictionary_with_nulls() -> Result<()> {
13251330
let dict_array_ref = string_dictionary_batch(
1326-
vec!["b", "c", "a"],
1327-
vec![None, Some(0), None, Some(1), Some(2)],
1331+
&["b", "c", "a"],
1332+
&[None, Some(0), None, Some(1), Some(2)],
13281333
);
13291334
let dict_type = dict_array_ref.data_type().clone();
13301335

@@ -1335,10 +1340,8 @@ mod tests {
13351340
fn test_min_max_dictionary_multi_batch() -> Result<()> {
13361341
let dict_type =
13371342
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1338-
let batch1 =
1339-
string_dictionary_batch(vec!["b", "c"], vec![Some(0), Some(1)]);
1340-
let batch2 =
1341-
string_dictionary_batch(vec!["a", "d"], vec![Some(0), Some(1)]);
1343+
let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]);
1344+
let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]);
13421345

13431346
assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d")
13441347
}

0 commit comments

Comments
 (0)