diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 27620221cf23c..ef48fd3f69c9b 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -18,8 +18,8 @@ //! Basic min/max functionality shared across DataFusion aggregate functions use arrow::array::{ - ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, @@ -413,6 +413,22 @@ macro_rules! min_max { min_max_generic!(lhs, rhs, $OP) } + (lhs, rhs) + if matches!(lhs, ScalarValue::Dictionary(_, _)) + || matches!(rhs, ScalarValue::Dictionary(_, _)) => + { + let (lhs, lhs_key_type) = dictionary_scalar_parts(lhs); + let (rhs, rhs_key_type) = dictionary_scalar_parts(rhs); + let result = min_max_generic!(lhs, rhs, $OP); + + match lhs_key_type.zip(rhs_key_type) { + Some((key_type, _)) => { + ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(result)) + } + None => result, + } + } + e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", @@ -423,6 +439,53 @@ macro_rules! 
min_max { }}; } +fn scalar_batch_extreme(values: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { + let mut index = 0; + let mut extreme = loop { + if index == values.len() { + return ScalarValue::try_from(values.data_type()); + } + + let current = ScalarValue::try_from_array(values, index)?; + index += 1; + + if !current.is_null() { + break current; + } + }; + + while index < values.len() { + let current = ScalarValue::try_from_array(values, index)?; + index += 1; + + if !current.is_null() && extreme.try_cmp(&current)? == ordering { + extreme = current; + } + } + + Ok(extreme) +} + +fn dictionary_scalar_parts(value: &ScalarValue) -> (&ScalarValue, Option<&DataType>) { + match value { + ScalarValue::Dictionary(key_type, inner) => { + (inner.as_ref(), Some(key_type.as_ref())) + } + other => (other, None), + } +} + +fn is_row_wise_batch_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Dictionary(_, _) + ) +} + /// An accumulator to compute the maximum value #[derive(Debug, Clone)] pub struct MaxAccumulator { @@ -760,44 +823,13 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> { min_binary_view ) } - DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::FixedSizeList(_, _) => { - min_max_batch_generic(values, Ordering::Greater)? - } - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - min_batch(values)? + data_type if is_row_wise_batch_type(data_type) => { + scalar_batch_extreme(values, Ordering::Greater)? 
} _ => min_max_batch!(values, min), }) } -/// Generic min/max implementation for complex types -fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { - if array.len() == array.null_count() { - return ScalarValue::try_from(array.data_type()); - } - let mut extreme = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { - let current = ScalarValue::try_from_array(array, i)?; - if current.is_null() { - continue; - } - if extreme.is_null() { - extreme = current; - continue; - } - let cmp = extreme.try_cmp(&current)?; - if cmp == ordering { - extreme = current; - } - } - - Ok(extreme) -} - /// dynamically-typed max(array) -> ScalarValue pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { Ok(match values.data_type() { @@ -843,13 +875,8 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { let value = value.map(|e| e.to_vec()); ScalarValue::FixedSizeBinary(*size, value) } - DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - max_batch(values)? + data_type if is_row_wise_batch_type(data_type) => { + scalar_batch_extreme(values, Ordering::Less)? 
} _ => min_max_batch!(values, max), }) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 9d05c57b02e93..9bd7e153d382b 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1259,7 +1259,158 @@ mod tests { let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); + Ok(()) + } + + fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) + } + + fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue { + dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) + } + + fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn optional_string_dictionary_batch( + values: &[Option<&str>], + keys: &[Option<i32>], + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn float_dictionary_batch(values: &[f32], keys: &[Option<i32>]) -> ArrayRef { + let values = Arc::new(Float32Array::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn evaluate_dictionary_accumulator( + mut acc: impl Accumulator, + batches: &[ArrayRef], + ) -> Result<ScalarValue> { + for batch in batches { + acc.update_batch(&[Arc::clone(batch)])?; + } + acc.evaluate() + } + + fn assert_dictionary_min_max( + dict_type: &DataType, + batches: 
&[ArrayRef], + expected_min: &str, + expected_max: &str, + ) -> Result<()> { + let key_type = match dict_type { + DataType::Dictionary(key_type, _) => key_type.as_ref().clone(), + other => panic!("expected dictionary type, got {other:?}"), + }; + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min)); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max)); + + Ok(()) + } + + #[test] + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a", "d"], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a"], + &[None, Some(0), None, Some(1), Some(2)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") + } + + #[test] + fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { + let dict_array_ref = + string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") + } + + #[test] + fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> { + let dict_array_ref = optional_string_dictionary_batch( + &[Some("b"), None, Some("a"), Some("d")], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn 
test_min_max_dictionary_multi_batch() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]); + let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]); + + assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") + } + + #[test] + fn test_min_max_dictionary_float_with_nans() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Float32)); + let batch1 = float_dictionary_batch(&[0.0, f32::NAN], &[Some(0), Some(1)]); + let batch2 = float_dictionary_batch(&[f32::NEG_INFINITY], &[Some(0)]); + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(&dict_type)?, + &[Arc::clone(&batch1), Arc::clone(&batch2)], + )?; + assert_eq!( + min_result, + dict_scalar( + DataType::Int32, + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ) + ); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(&dict_type)?, + &[batch1, batch2], + )?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Float32(Some(f32::NAN))) + ); + Ok(()) } }