Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 69 additions & 42 deletions datafusion/functions-aggregate-common/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Basic min/max functionality shared across DataFusion aggregate functions

use arrow::array::{
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array,
DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray,
Expand Down Expand Up @@ -413,6 +413,22 @@ macro_rules! min_max {
min_max_generic!(lhs, rhs, $OP)
}

(lhs, rhs)
if matches!(lhs, ScalarValue::Dictionary(_, _))
|| matches!(rhs, ScalarValue::Dictionary(_, _)) =>
{
let (lhs, lhs_key_type) = dictionary_scalar_parts(lhs);
let (rhs, rhs_key_type) = dictionary_scalar_parts(rhs);
let result = min_max_generic!(lhs, rhs, $OP);

match lhs_key_type.zip(rhs_key_type) {
Some((key_type, _)) => {
ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(result))
}
None => result,
}
}

e => {
return internal_err!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Expand All @@ -423,6 +439,53 @@ macro_rules! min_max {
}};
}

/// Computes the row-wise extreme (min or max, selected by `ordering`) of
/// `values` by materializing each row as a `ScalarValue` and comparing.
///
/// Null rows are skipped. If the array is empty or entirely null, a typed
/// null scalar for the array's data type is returned.
fn scalar_batch_extreme(values: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
    let mut best: Option<ScalarValue> = None;

    for row in 0..values.len() {
        let candidate = ScalarValue::try_from_array(values, row)?;
        if candidate.is_null() {
            continue;
        }
        best = match best {
            // Keep the current holder unless the candidate beats it
            // (i.e. `holder.try_cmp(&candidate) == ordering`).
            Some(holder) if holder.try_cmp(&candidate)? != ordering => Some(holder),
            _ => Some(candidate),
        };
    }

    match best {
        Some(value) => Ok(value),
        // No non-null rows seen: fall back to a typed null scalar
        None => ScalarValue::try_from(values.data_type()),
    }
}

/// Strips one level of dictionary encoding from a scalar.
///
/// Returns the underlying value together with the dictionary key type, or the
/// scalar unchanged (paired with `None`) when it is not dictionary-encoded.
fn dictionary_scalar_parts(value: &ScalarValue) -> (&ScalarValue, Option<&DataType>) {
    if let ScalarValue::Dictionary(key_type, inner) = value {
        (inner.as_ref(), Some(key_type.as_ref()))
    } else {
        (value, None)
    }
}

fn is_row_wise_batch_type(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Struct(_)
| DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _)
| DataType::Dictionary(_, _)
)
}

/// An accumulator to compute the maximum value
#[derive(Debug, Clone)]
pub struct MaxAccumulator {
Expand Down Expand Up @@ -760,44 +823,13 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
min_binary_view
)
}
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?,
DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?,
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?,
DataType::FixedSizeList(_, _) => {
min_max_batch_generic(values, Ordering::Greater)?
}
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
min_batch(values)?
data_type if is_row_wise_batch_type(data_type) => {
scalar_batch_extreme(values, Ordering::Greater)?
}
_ => min_max_batch!(values, min),
})
}

/// Generic min/max implementation for complex types
///
/// Scans the array row by row, keeping whichever non-null value wins under
/// `ordering`. Returns a typed null scalar when every row is null (or empty).
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
    // Entirely-null (or empty) input: produce a typed null scalar up front
    if array.null_count() == array.len() {
        return ScalarValue::try_from(array.data_type());
    }

    let mut best = ScalarValue::try_from_array(array, 0)?;
    for idx in 1..array.len() {
        let candidate = ScalarValue::try_from_array(array, idx)?;
        if candidate.is_null() {
            continue;
        }
        // A null holder (leading nulls) is always replaced; otherwise adopt
        // the candidate only when it beats the current holder.
        if best.is_null() || best.try_cmp(&candidate)? == ordering {
            best = candidate;
        }
    }

    Ok(best)
}

/// dynamically-typed max(array) -> ScalarValue
pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
Ok(match values.data_type() {
Expand Down Expand Up @@ -843,13 +875,8 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
let value = value.map(|e| e.to_vec());
ScalarValue::FixedSizeBinary(*size, value)
}
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
max_batch(values)?
data_type if is_row_wise_batch_type(data_type) => {
scalar_batch_extreme(values, Ordering::Less)?
}
_ => min_max_batch!(values, max),
})
Expand Down
153 changes: 152 additions & 1 deletion datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,158 @@ mod tests {
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let max_result = max_acc.evaluate()?;
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
Ok(())
}

// Wraps a plain scalar in one level of dictionary encoding.
fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue {
    let key = Box::new(key_type);
    let value = Box::new(inner);
    ScalarValue::Dictionary(key, value)
}

// Builds a dictionary-encoded Utf8 scalar for the given key type.
fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue {
    let inner = ScalarValue::Utf8(Some(value.to_owned()));
    dict_scalar(key_type, inner)
}

// Builds a Dictionary(Int32, Utf8) array from parallel value/key slices.
fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef {
    let dictionary_values: ArrayRef = Arc::new(StringArray::from(values.to_vec()));
    let dictionary_keys = Int32Array::from(keys.to_vec());
    Arc::new(DictionaryArray::try_new(dictionary_keys, dictionary_values).unwrap())
        as ArrayRef
}

// Like `string_dictionary_batch`, but dictionary values may themselves be null.
fn optional_string_dictionary_batch(
    values: &[Option<&str>],
    keys: &[Option<i32>],
) -> ArrayRef {
    let dictionary_values: ArrayRef = Arc::new(StringArray::from(values.to_vec()));
    let dictionary_keys = Int32Array::from(keys.to_vec());
    Arc::new(DictionaryArray::try_new(dictionary_keys, dictionary_values).unwrap())
        as ArrayRef
}

// Builds a Dictionary(Int32, Float32) array from parallel value/key slices.
fn float_dictionary_batch(values: &[f32], keys: &[Option<i32>]) -> ArrayRef {
    let dictionary_values: ArrayRef = Arc::new(Float32Array::from(values.to_vec()));
    let dictionary_keys = Int32Array::from(keys.to_vec());
    Arc::new(DictionaryArray::try_new(dictionary_keys, dictionary_values).unwrap())
        as ArrayRef
}

// Feeds every batch into the accumulator, then produces the final aggregate.
fn evaluate_dictionary_accumulator(
    mut acc: impl Accumulator,
    batches: &[ArrayRef],
) -> Result<ScalarValue> {
    batches
        .iter()
        .try_for_each(|batch| acc.update_batch(&[Arc::clone(batch)]))?;
    acc.evaluate()
}

// Runs both Min and Max accumulators over `batches` and asserts the results
// equal the expected dictionary-encoded Utf8 scalars.
fn assert_dictionary_min_max(
    dict_type: &DataType,
    batches: &[ArrayRef],
    expected_min: &str,
    expected_max: &str,
) -> Result<()> {
    // Expected results are dictionary scalars, so pull out the key type.
    let key_type = match dict_type {
        DataType::Dictionary(key_type, _) => key_type.as_ref().clone(),
        other => panic!("expected dictionary type, got {other:?}"),
    };

    // MIN over all batches
    let observed_min =
        evaluate_dictionary_accumulator(MinAccumulator::try_new(dict_type)?, batches)?;
    assert_eq!(observed_min, utf8_dict_scalar(key_type.clone(), expected_min));

    // MAX over all batches
    let observed_max =
        evaluate_dictionary_accumulator(MaxAccumulator::try_new(dict_type)?, batches)?;
    assert_eq!(observed_max, utf8_dict_scalar(key_type, expected_max));

    Ok(())
}

#[test]
fn test_min_max_dictionary_without_coercion() -> Result<()> {
    // Every key is valid, so min/max come straight from the values.
    let keys: Vec<Option<i32>> = (0..4).map(Some).collect();
    let batch = string_dictionary_batch(&["b", "c", "a", "d"], &keys);
    let dict_type = batch.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[batch], "a", "d")
}

#[test]
fn test_min_max_dictionary_with_nulls() -> Result<()> {
    // Null keys are ignored; the remaining rows reference "b", "c", "a".
    let batch = string_dictionary_batch(
        &["b", "c", "a"],
        &[None, Some(0), None, Some(1), Some(2)],
    );
    let dict_type = batch.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[batch], "a", "c")
}

#[test]
fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> {
    // "a" and "zz_unused" are never referenced by a valid key, so both the
    // min and the max come from the only referenced value, "z".
    let batch =
        string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]);
    let dict_type = batch.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[batch], "z", "z")
}

#[test]
fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> {
    // Key 1 points at a null dictionary value; it must not affect the result.
    let batch = optional_string_dictionary_batch(
        &[Some("b"), None, Some("a"), Some("d")],
        &[Some(0), Some(1), Some(2), Some(3)],
    );
    let dict_type = batch.data_type().clone();

    assert_dictionary_min_max(&dict_type, &[batch], "a", "d")
}

#[test]
fn test_min_max_dictionary_multi_batch() -> Result<()> {
    // Both extremes ("a" and "d") live in the second batch, so the
    // accumulator state must carry across update_batch calls.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let first = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]);
    let second = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]);

    assert_dictionary_min_max(&dict_type, &[first, second], "a", "d")
}

#[test]
fn test_min_max_dictionary_float_with_nans() -> Result<()> {
    // Float dictionary batches containing NaN and -inf, spread over two
    // batches so the accumulator must merge across updates.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Float32));
    let batch1 = float_dictionary_batch(&[0.0, f32::NAN], &[Some(0), Some(1)]);
    let batch2 = float_dictionary_batch(&[f32::NEG_INFINITY], &[Some(0)]);

    // MIN: -inf wins over 0.0 and NaN.
    let min_result = evaluate_dictionary_accumulator(
        MinAccumulator::try_new(&dict_type)?,
        &[Arc::clone(&batch1), Arc::clone(&batch2)],
    )?;
    assert_eq!(
        min_result,
        dict_scalar(
            DataType::Int32,
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
        )
    );

    // MAX: the expected value being NaN shows NaN sorts above both the
    // finite value and -inf here. NOTE(review): this assertion also relies
    // on ScalarValue comparing NaN equal to NaN (unlike IEEE `==`) —
    // presumably a total-order comparison; confirm against ScalarValue's
    // PartialEq implementation.
    let max_result = evaluate_dictionary_accumulator(
        MaxAccumulator::try_new(&dict_type)?,
        &[batch1, batch2],
    )?;
    assert_eq!(
        max_result,
        dict_scalar(DataType::Int32, ScalarValue::Float32(Some(f32::NAN)))
    );

    Ok(())
}
}
Loading