Skip to content

Commit a85804b

Browse files
committed
Restore dictionary coercion in min_max.rs
Ensure MIN/MAX(Dictionary(..., T)) returns T at the SQL boundary while retaining the new dictionary comparison logic. Update the regression tests to verify that dictionary inputs are accepted and that the result is the underlying scalar type. Adjust the planner-level regression test in basic.rs to expect the final output schema to be Utf8 instead of dictionary-typed.
1 parent 4bc0ac3 commit a85804b

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,8 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
486486

487487
let results = df.collect().await?;
488488

489-
assert_eq!(results[0].schema().field(0).data_type(), &dict_type);
490-
assert_eq!(results[0].schema().field(1).data_type(), &dict_type);
489+
assert_eq!(results[0].schema().field(0).data_type(), &DataType::Utf8);
490+
assert_eq!(results[0].schema().field(1).data_type(), &DataType::Utf8);
491491

492492
assert_snapshot!(
493493
batches_to_string(&results),

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
5353
use datafusion_macros::user_doc;
5454
use half::f16;
5555
use std::mem::size_of_val;
56+
use std::ops::Deref;
5657

5758
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
5859
// make sure that the input types have only one element.
@@ -62,12 +63,17 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
6263
input_types.len()
6364
);
6465
}
65-
// Preserve dictionary inputs so planned MIN/MAX execution uses the same
66-
// dictionary-aware accumulator/state path as direct accumulator tests.
67-
//
68-
// TODO add checker for datatype which min and max supported.
69-
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function.
70-
Ok(input_types.to_vec())
66+
// MIN and MAX support the dictionary data type:
67+
// unwrap the dictionary and use its value type as the result type
68+
match &input_types[0] {
69+
DataType::Dictionary(_, dict_value_type) => {
70+
// TODO: add a check for the case where the value type is a complex data type
71+
Ok(vec![dict_value_type.deref().clone()])
72+
}
73+
// TODO: add a check for the data types that MIN and MAX support.
74+
// For example, the `Struct` and `Map` types are not supported in the MIN and MAX functions.
75+
_ => Ok(input_types.to_vec()),
76+
}
7177
}
7278

7379
#[user_doc(
@@ -1209,23 +1215,31 @@ mod tests {
12091215

12101216
#[test]
12111217
fn test_min_max_coerce_types() {
1212-
// the coerced types is same with input types
12131218
let funs: Vec<Box<dyn AggregateUDFImpl>> =
12141219
vec![Box::new(Min::new()), Box::new(Max::new())];
1215-
let input_types = vec![
1216-
vec![DataType::Int32],
1217-
vec![DataType::Decimal128(10, 2)],
1218-
vec![DataType::Decimal256(1, 1)],
1219-
vec![DataType::Utf8],
1220-
vec![DataType::Dictionary(
1221-
Box::new(DataType::Int32),
1222-
Box::new(DataType::Utf8),
1223-
)],
1220+
let cases = vec![
1221+
(vec![DataType::Int32], vec![DataType::Int32]),
1222+
(
1223+
vec![DataType::Decimal128(10, 2)],
1224+
vec![DataType::Decimal128(10, 2)],
1225+
),
1226+
(
1227+
vec![DataType::Decimal256(1, 1)],
1228+
vec![DataType::Decimal256(1, 1)],
1229+
),
1230+
(vec![DataType::Utf8], vec![DataType::Utf8]),
1231+
(
1232+
vec![DataType::Dictionary(
1233+
Box::new(DataType::Int32),
1234+
Box::new(DataType::Utf8),
1235+
)],
1236+
vec![DataType::Utf8],
1237+
),
12241238
];
12251239
for fun in funs {
1226-
for input_type in &input_types {
1240+
for (input_type, expected_type) in &cases {
12271241
let result = fun.coerce_types(input_type);
1228-
assert_eq!(*input_type, result.unwrap());
1242+
assert_eq!(*expected_type, result.unwrap());
12291243
}
12301244
}
12311245
}
@@ -1235,18 +1249,12 @@ mod tests {
12351249
let data_type =
12361250
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
12371251
let result = get_min_max_result_type(&[data_type])?;
1238-
assert_eq!(
1239-
result,
1240-
vec![DataType::Dictionary(
1241-
Box::new(DataType::Int32),
1242-
Box::new(DataType::Utf8),
1243-
)]
1244-
);
1252+
assert_eq!(result, vec![DataType::Utf8]);
12451253
Ok(())
12461254
}
12471255

12481256
#[test]
1249-
fn test_min_max_dictionary() -> Result<()> {
1257+
fn test_min_max_dictionary_after_coercion() -> Result<()> {
12501258
let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
12511259
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
12521260
let dict_array =
@@ -1258,18 +1266,12 @@ mod tests {
12581266
let mut min_acc = MinAccumulator::try_new(&rt_type)?;
12591267
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
12601268
let min_result = min_acc.evaluate()?;
1261-
assert_eq!(
1262-
min_result,
1263-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
1264-
);
1269+
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
12651270

12661271
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
12671272
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
12681273
let max_result = max_acc.evaluate()?;
1269-
assert_eq!(
1270-
max_result,
1271-
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
1272-
);
1274+
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
12731275
Ok(())
12741276
}
12751277

0 commit comments

Comments
 (0)