Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions datafusion/core/tests/sql/aggregates/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use super::*;
use datafusion::common::test_util::batches_to_string;
use datafusion_catalog::MemTable;
use datafusion_common::ScalarValue;
use datafusion_physical_plan::displayable;
use insta::assert_snapshot;

#[tokio::test]
Expand Down Expand Up @@ -442,6 +443,66 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> {
Ok(())
}

/// End-to-end check that MIN/MAX over a dictionary column runs through the
/// planned two-phase aggregation and produces value-typed (Utf8) results.
///
/// Partition 1 (`batch1`) references only "z" (key 1 twice, plus a null key),
/// so the unreferenced dictionary entries "a"/"zz_unused" must not leak into
/// the result; partition 2 (`batch2`) contributes "a" and "d".
/// Expected: min = "a", max = "z".
#[tokio::test]
async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
    // Two target partitions force a Partial + Final/FinalPartitioned plan.
    let ctx =
        SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2));

    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "dict",
        dict_type.clone(),
        true,
    )]));

    // Keys [1, 1, None] over ["a", "z", "zz_unused"]: only "z" is referenced.
    let batch1 = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(DictionaryArray::new(
            Int32Array::from(vec![Some(1), Some(1), None]),
            Arc::new(StringArray::from(vec!["a", "z", "zz_unused"])),
        ))],
    )?;
    // Keys [0, 1] over ["a", "d"]: contributes the global minimum "a".
    let batch2 = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(DictionaryArray::new(
            Int32Array::from(vec![Some(0), Some(1)]),
            Arc::new(StringArray::from(vec!["a", "d"])),
        ))],
    )?;
    // Each inner Vec is its own partition of the MemTable.
    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
    ctx.register_table("t", Arc::new(provider))?;

    let df = ctx
        .sql("SELECT min(dict) AS min_dict, max(dict) AS max_dict FROM t")
        .await?;
    // Confirm the physical plan actually uses two-phase aggregation before
    // checking results.
    let physical_plan = df.clone().create_physical_plan().await?;
    let formatted_plan = format!("{}", displayable(physical_plan.as_ref()).indent(true));
    assert!(formatted_plan.contains("AggregateExec: mode=Partial, gby=[]"));
    assert!(
        formatted_plan.contains("AggregateExec: mode=Final, gby=[]")
            || formatted_plan.contains("AggregateExec: mode=FinalPartitioned, gby=[]")
    );

    let results = df.collect().await?;

    // MIN/MAX over a dictionary column is coerced to the value type.
    assert_eq!(results[0].schema().field(0).data_type(), &DataType::Utf8);
    assert_eq!(results[0].schema().field(1).data_type(), &DataType::Utf8);

    assert_snapshot!(
        batches_to_string(&results),
        @r"
    +----------+----------+
    | min_dict | max_dict |
    +----------+----------+
    | a        | z        |
    +----------+----------+
    "
    );

    Ok(())
}

#[tokio::test]
async fn group_by_ree_dict_column() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
78 changes: 60 additions & 18 deletions datafusion/functions-aggregate-common/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Basic min/max functionality shared across DataFusion aggregate functions

use arrow::array::{
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray,
Expand Down Expand Up @@ -413,6 +413,30 @@ macro_rules! min_max {
min_max_generic!(lhs, rhs, $OP)
}

(
ScalarValue::Dictionary(key_type, lhs_inner),
ScalarValue::Dictionary(_, rhs_inner),
) => {
wrap_dictionary_scalar(
key_type.as_ref(),
min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP),
)
}

(
ScalarValue::Dictionary(_, lhs_inner),
rhs,
) => {
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
}

(
lhs,
ScalarValue::Dictionary(_, rhs_inner),
) => {
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
}

e => {
return internal_err!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Expand All @@ -423,6 +447,31 @@ macro_rules! min_max {
}};
}

/// Row-wise min/max over a dictionary-encoded array.
///
/// Reads each row through `ScalarValue::try_from_array` and skips nulls, so
/// only values actually referenced by the keys participate (unreferenced
/// dictionary entries are ignored). `ordering` picks the direction:
/// `Ordering::Greater` yields the minimum, `Ordering::Less` the maximum.
/// Returns a null scalar of the array's type when every row is null.
fn dictionary_batch_extreme(
    values: &ArrayRef,
    ordering: Ordering,
) -> Result<ScalarValue> {
    let mut best: Option<ScalarValue> = None;

    for row in 0..values.len() {
        let candidate = ScalarValue::try_from_array(values, row)?;
        if candidate.is_null() {
            continue;
        }
        // Take the candidate when there is no extreme yet, or when the
        // current extreme compares as `ordering` against it.
        let take_candidate = match best.as_ref() {
            Some(current) => current.try_cmp(&candidate)? == ordering,
            None => true,
        };
        if take_candidate {
            best = Some(candidate);
        }
    }

    match best {
        Some(extreme) => Ok(extreme),
        None => ScalarValue::try_from(values.data_type()),
    }
}

/// Re-wraps a plain scalar as a `ScalarValue::Dictionary` keyed by `key_type`.
fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue {
    let keys = Box::new(key_type.clone());
    let inner = Box::new(value);
    ScalarValue::Dictionary(keys, inner)
}

/// An accumulator to compute the maximum value
#[derive(Debug, Clone)]
pub struct MaxAccumulator {
Expand Down Expand Up @@ -767,30 +816,26 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
min_max_batch_generic(values, Ordering::Greater)?
}
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
min_batch(values)?
dictionary_batch_extreme(values, Ordering::Greater)?
}
_ => min_max_batch!(values, min),
})
}

/// Generic min/max implementation for complex types
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
if array.len() == array.null_count() {
let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i));
let Some(first_idx) = non_null_indices.next() else {
return ScalarValue::try_from(array.data_type());
}
let mut extreme = ScalarValue::try_from_array(array, 0)?;
for i in 1..array.len() {
};

let mut extreme = ScalarValue::try_from_array(array, first_idx)?;
for i in non_null_indices {
let current = ScalarValue::try_from_array(array, i)?;
if current.is_null() {
continue;
}
if extreme.is_null() {
extreme = current;
continue;
}
let cmp = extreme.try_cmp(&current)?;
if cmp == ordering {
if extreme.is_null() || extreme.try_cmp(&current)? == ordering {
extreme = current;
}
}
Expand Down Expand Up @@ -847,10 +892,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
max_batch(values)?
}
DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?,
_ => min_max_batch!(values, max),
})
}
149 changes: 137 additions & 12 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
// TODO add checker for datatype which min and max supported
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
// TODO add checker for datatype which min and max supported.
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function.
_ => Ok(input_types.to_vec()),
}
}
Expand Down Expand Up @@ -1215,19 +1215,31 @@ mod tests {

#[test]
fn test_min_max_coerce_types() {
// the coerced types is same with input types
let funs: Vec<Box<dyn AggregateUDFImpl>> =
vec![Box::new(Min::new()), Box::new(Max::new())];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
vec![DataType::Decimal256(1, 1)],
vec![DataType::Utf8],
let cases = vec![
(vec![DataType::Int32], vec![DataType::Int32]),
(
vec![DataType::Decimal128(10, 2)],
vec![DataType::Decimal128(10, 2)],
),
(
vec![DataType::Decimal256(1, 1)],
vec![DataType::Decimal256(1, 1)],
),
(vec![DataType::Utf8], vec![DataType::Utf8]),
(
vec![DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
)],
vec![DataType::Utf8],
),
];
for fun in funs {
for input_type in &input_types {
for (input_type, expected_type) in &cases {
let result = fun.coerce_types(input_type);
assert_eq!(*input_type, result.unwrap());
assert_eq!(*expected_type, result.unwrap());
}
}
}
Expand All @@ -1242,7 +1254,7 @@ mod tests {
}

#[test]
fn test_min_max_dictionary() -> Result<()> {
fn test_min_max_dictionary_after_coercion() -> Result<()> {
let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
let dict_array =
Expand All @@ -1259,7 +1271,120 @@ mod tests {
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let max_result = max_acc.evaluate()?;
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
Ok(())
}

/// Builds a dictionary-typed scalar from a key type and an inner value.
fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue {
    let boxed_key = Box::new(key_type);
    let boxed_inner = Box::new(inner);
    ScalarValue::Dictionary(boxed_key, boxed_inner)
}

/// Builds a dictionary-typed scalar wrapping a non-null Utf8 value.
fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue {
    let inner = ScalarValue::Utf8(Some(value.to_owned()));
    dict_scalar(key_type, inner)
}

fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef {
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
Arc::new(
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
) as ArrayRef
}

/// Builds an `Int32`-keyed dictionary array whose dictionary values may be
/// null (used to test that referenced null values are ignored by MIN/MAX).
fn optional_string_dictionary_batch(
    values: &[Option<&str>],
    keys: &[Option<i32>],
) -> ArrayRef {
    let dictionary_values: ArrayRef = Arc::new(StringArray::from(values.to_vec()));
    let key_array = Int32Array::from(keys.to_vec());
    let dict = DictionaryArray::try_new(key_array, dictionary_values).unwrap();
    Arc::new(dict) as ArrayRef
}

/// Feeds every batch into `acc` in order and returns the final scalar.
fn evaluate_dictionary_accumulator(
    mut acc: impl Accumulator,
    batches: &[ArrayRef],
) -> Result<ScalarValue> {
    batches
        .iter()
        .try_for_each(|batch| acc.update_batch(std::slice::from_ref(batch)))?;
    acc.evaluate()
}

/// Runs Min and Max accumulators over `batches` and asserts each produces the
/// expected dictionary-wrapped Utf8 scalar.
///
/// Panics if `dict_type` is not a `DataType::Dictionary`.
fn assert_dictionary_min_max(
    dict_type: &DataType,
    batches: &[ArrayRef],
    expected_min: &str,
    expected_max: &str,
) -> Result<()> {
    // The expected scalars must carry the same key type as the input column.
    let key_type = match dict_type {
        DataType::Dictionary(inner_key, _) => inner_key.as_ref().clone(),
        other => panic!("expected dictionary type, got {other:?}"),
    };

    let observed_min = evaluate_dictionary_accumulator(
        MinAccumulator::try_new(dict_type)?,
        batches,
    )?;
    assert_eq!(
        observed_min,
        utf8_dict_scalar(key_type.clone(), expected_min)
    );

    let observed_max = evaluate_dictionary_accumulator(
        MaxAccumulator::try_new(dict_type)?,
        batches,
    )?;
    assert_eq!(observed_max, utf8_dict_scalar(key_type, expected_max));

    Ok(())
}

/// All keys valid and every dictionary value referenced: plain min/max.
#[test]
fn test_min_max_dictionary_without_coercion() -> Result<()> {
    let keys = [Some(0), Some(1), Some(2), Some(3)];
    let input = string_dictionary_batch(&["b", "c", "a", "d"], &keys);
    let dict_type = input.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[input], "a", "d")
}

/// Null keys are skipped, so only "b", "c", and "a" participate.
#[test]
fn test_min_max_dictionary_with_nulls() -> Result<()> {
    let keys = [None, Some(0), None, Some(1), Some(2)];
    let input = string_dictionary_batch(&["b", "c", "a"], &keys);
    let dict_type = input.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[input], "a", "c")
}

/// "a" and "zz_unused" are never referenced by any key, so both min and max
/// must be "z" — the raw dictionary values must not be scanned directly.
#[test]
fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> {
    let keys = [Some(1), Some(1), None];
    let input = string_dictionary_batch(&["a", "z", "zz_unused"], &keys);
    let dict_type = input.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[input], "z", "z")
}

/// Key 1 references a null dictionary value; it must not affect min/max.
#[test]
fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> {
    let keys = [Some(0), Some(1), Some(2), Some(3)];
    let input =
        optional_string_dictionary_batch(&[Some("b"), None, Some("a"), Some("d")], &keys);
    let dict_type = input.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[input], "a", "d")
}

/// The accumulators must carry extremes across separate `update_batch` calls.
#[test]
fn test_min_max_dictionary_multi_batch() -> Result<()> {
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let first = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]);
    let second = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]);

    assert_dictionary_min_max(&dict_type, &[first, second], "a", "d")
}
}
Loading