Skip to content

Commit dad6e02

Browse files
committed
Refactor dictionary handling and simplify batch logic
Remove the no-op dictionary macro and single-use wrapper. Collapse dictionary handling into a normalized path and seed scalar_batch_extreme from the first non-null value. Unify row-wise batch dispatch behind a shared predicate. Apply formatting adjustments in min_max.rs as per cargo fmt.
1 parent ed2d3fd commit dad6e02

File tree

2 files changed

+56
-60
lines changed

2 files changed

+56
-60
lines changed

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 54 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,6 @@ macro_rules! min_max_generic {
141141
}};
142142
}
143143

144-
macro_rules! min_max_dictionary {
145-
($VALUE:expr, $DELTA:expr, $OP:ident) => {{ min_max_generic!($VALUE, $DELTA, $OP) }};
146-
}
147-
148144
// min/max of two scalar values of the same type
149145
macro_rules! min_max {
150146
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
@@ -417,32 +413,20 @@ macro_rules! min_max {
417413
min_max_generic!(lhs, rhs, $OP)
418414
}
419415

420-
(
421-
ScalarValue::Dictionary(key_type, lhs_inner),
422-
ScalarValue::Dictionary(_, rhs_inner),
423-
) => {
424-
wrap_dictionary_scalar(
425-
key_type.as_ref(),
426-
min_max_dictionary!(
427-
lhs_inner.as_ref(),
428-
rhs_inner.as_ref(),
429-
$OP
430-
),
431-
)
432-
}
416+
(lhs, rhs)
417+
if matches!(lhs, ScalarValue::Dictionary(_, _))
418+
|| matches!(rhs, ScalarValue::Dictionary(_, _)) =>
419+
{
420+
let (lhs, lhs_key_type) = dictionary_scalar_parts(lhs);
421+
let (rhs, rhs_key_type) = dictionary_scalar_parts(rhs);
422+
let result = min_max_generic!(lhs, rhs, $OP);
433423

434-
(
435-
ScalarValue::Dictionary(_, lhs_inner),
436-
rhs,
437-
) => {
438-
min_max_dictionary!(lhs_inner.as_ref(), rhs, $OP)
439-
}
440-
441-
(
442-
lhs,
443-
ScalarValue::Dictionary(_, rhs_inner),
444-
) => {
445-
min_max_dictionary!(lhs, rhs_inner.as_ref(), $OP)
424+
match lhs_key_type.zip(rhs_key_type) {
425+
Some((key_type, _)) => {
426+
ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(result))
427+
}
428+
None => result,
429+
}
446430
}
447431

448432
e => {
@@ -456,25 +440,50 @@ macro_rules! min_max {
456440
}
457441

458442
fn scalar_batch_extreme(values: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
459-
let mut extreme: Option<ScalarValue> = None;
443+
let mut index = 0;
444+
let mut extreme = loop {
445+
if index == values.len() {
446+
return ScalarValue::try_from(values.data_type());
447+
}
448+
449+
let current = ScalarValue::try_from_array(values, index)?;
450+
index += 1;
460451

461-
for i in 0..values.len() {
462-
let current = ScalarValue::try_from_array(values, i)?;
463-
if current.is_null() {
464-
continue;
452+
if !current.is_null() {
453+
break current;
465454
}
455+
};
456+
457+
while index < values.len() {
458+
let current = ScalarValue::try_from_array(values, index)?;
459+
index += 1;
466460

467-
match &extreme {
468-
Some(existing) if existing.try_cmp(&current)? != ordering => {}
469-
_ => extreme = Some(current),
461+
if !current.is_null() && extreme.try_cmp(&current)? == ordering {
462+
extreme = current;
470463
}
471464
}
472465

473-
extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok)
466+
Ok(extreme)
467+
}
468+
469+
fn dictionary_scalar_parts(value: &ScalarValue) -> (&ScalarValue, Option<&DataType>) {
470+
match value {
471+
ScalarValue::Dictionary(key_type, inner) => {
472+
(inner.as_ref(), Some(key_type.as_ref()))
473+
}
474+
other => (other, None),
475+
}
474476
}
475477

476-
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
477-
ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(value))
478+
fn is_row_wise_batch_type(data_type: &DataType) -> bool {
479+
matches!(
480+
data_type,
481+
DataType::Struct(_)
482+
| DataType::List(_)
483+
| DataType::LargeList(_)
484+
| DataType::FixedSizeList(_, _)
485+
| DataType::Dictionary(_, _)
486+
)
478487
}
479488

480489
/// An accumulator to compute the maximum value
@@ -814,22 +823,13 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
814823
min_binary_view
815824
)
816825
}
817-
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?,
818-
DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?,
819-
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?,
820-
DataType::FixedSizeList(_, _) => {
821-
min_max_batch_generic(values, Ordering::Greater)?
826+
data_type if is_row_wise_batch_type(data_type) => {
827+
scalar_batch_extreme(values, Ordering::Greater)?
822828
}
823-
DataType::Dictionary(_, _) => scalar_batch_extreme(values, Ordering::Greater)?,
824829
_ => min_max_batch!(values, min),
825830
})
826831
}
827832

828-
/// Generic min/max implementation for complex types
829-
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
830-
scalar_batch_extreme(array, ordering)
831-
}
832-
833833
/// dynamically-typed max(array) -> ScalarValue
834834
pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
835835
Ok(match values.data_type() {
@@ -875,11 +875,9 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
875875
let value = value.map(|e| e.to_vec());
876876
ScalarValue::FixedSizeBinary(*size, value)
877877
}
878-
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?,
879-
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
880-
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
881-
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
882-
DataType::Dictionary(_, _) => scalar_batch_extreme(values, Ordering::Less)?,
878+
data_type if is_row_wise_batch_type(data_type) => {
879+
scalar_batch_extreme(values, Ordering::Less)?
880+
}
883881
_ => min_max_batch!(values, max),
884882
})
885883
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,10 +1355,8 @@ mod tests {
13551355

13561356
#[test]
13571357
fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> {
1358-
let dict_array_ref = string_dictionary_batch(
1359-
&["a", "z", "zz_unused"],
1360-
&[Some(1), Some(1), None],
1361-
);
1358+
let dict_array_ref =
1359+
string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]);
13621360
let dict_type = dict_array_ref.data_type().clone();
13631361

13641362
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z")

0 commit comments

Comments
 (0)