@@ -53,6 +53,7 @@ use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
5353use datafusion_macros:: user_doc;
5454use half:: f16;
5555use std:: mem:: size_of_val;
56+ use std:: ops:: Deref ;
5657
5758fn get_min_max_result_type ( input_types : & [ DataType ] ) -> Result < Vec < DataType > > {
5859 // make sure that the input types have only one element.
@@ -62,12 +63,17 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
6263 input_types. len( )
6364 ) ;
6465 }
65- // Preserve dictionary inputs so planned MIN/MAX execution uses the same
66- // dictionary-aware accumulator/state path as direct accumulator tests.
67- //
68- // TODO add checker for datatype which min and max supported.
69- // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function.
70- Ok ( input_types. to_vec ( ) )
66+ // MIN and MAX accept dictionary-encoded input:
67+ // unwrap the dictionary and use its value type as the result type.
68+ match & input_types[ 0 ] {
69+ DataType :: Dictionary ( _, dict_value_type) => {
70+ // TODO: add a check for when the value type is a complex data type
71+ Ok ( vec ! [ dict_value_type. deref( ) . clone( ) ] )
72+ }
73+ // TODO: add a check for the data types that MIN and MAX support.
74+ // For example, the `Struct` and `Map` types are not supported by the MIN and MAX functions.
75+ _ => Ok ( input_types. to_vec ( ) ) ,
76+ }
7177}
7278
7379#[ user_doc(
@@ -1209,23 +1215,31 @@ mod tests {
12091215
12101216 #[ test]
12111217 fn test_min_max_coerce_types ( ) {
1212- // the coerced types is same with input types
12131218 let funs: Vec < Box < dyn AggregateUDFImpl > > =
12141219 vec ! [ Box :: new( Min :: new( ) ) , Box :: new( Max :: new( ) ) ] ;
1215- let input_types = vec ! [
1216- vec![ DataType :: Int32 ] ,
1217- vec![ DataType :: Decimal128 ( 10 , 2 ) ] ,
1218- vec![ DataType :: Decimal256 ( 1 , 1 ) ] ,
1219- vec![ DataType :: Utf8 ] ,
1220- vec![ DataType :: Dictionary (
1221- Box :: new( DataType :: Int32 ) ,
1222- Box :: new( DataType :: Utf8 ) ,
1223- ) ] ,
1220+ let cases = vec ! [
1221+ ( vec![ DataType :: Int32 ] , vec![ DataType :: Int32 ] ) ,
1222+ (
1223+ vec![ DataType :: Decimal128 ( 10 , 2 ) ] ,
1224+ vec![ DataType :: Decimal128 ( 10 , 2 ) ] ,
1225+ ) ,
1226+ (
1227+ vec![ DataType :: Decimal256 ( 1 , 1 ) ] ,
1228+ vec![ DataType :: Decimal256 ( 1 , 1 ) ] ,
1229+ ) ,
1230+ ( vec![ DataType :: Utf8 ] , vec![ DataType :: Utf8 ] ) ,
1231+ (
1232+ vec![ DataType :: Dictionary (
1233+ Box :: new( DataType :: Int32 ) ,
1234+ Box :: new( DataType :: Utf8 ) ,
1235+ ) ] ,
1236+ vec![ DataType :: Utf8 ] ,
1237+ ) ,
12241238 ] ;
12251239 for fun in funs {
1226- for input_type in & input_types {
1240+ for ( input_type, expected_type ) in & cases {
12271241 let result = fun. coerce_types ( input_type) ;
1228- assert_eq ! ( * input_type , result. unwrap( ) ) ;
1242+ assert_eq ! ( * expected_type , result. unwrap( ) ) ;
12291243 }
12301244 }
12311245 }
@@ -1235,18 +1249,12 @@ mod tests {
12351249 let data_type =
12361250 DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ;
12371251 let result = get_min_max_result_type ( & [ data_type] ) ?;
1238- assert_eq ! (
1239- result,
1240- vec![ DataType :: Dictionary (
1241- Box :: new( DataType :: Int32 ) ,
1242- Box :: new( DataType :: Utf8 ) ,
1243- ) ]
1244- ) ;
1252+ assert_eq ! ( result, vec![ DataType :: Utf8 ] ) ;
12451253 Ok ( ( ) )
12461254 }
12471255
12481256 #[ test]
1249- fn test_min_max_dictionary ( ) -> Result < ( ) > {
1257+ fn test_min_max_dictionary_after_coercion ( ) -> Result < ( ) > {
12501258 let values = StringArray :: from ( vec ! [ "b" , "c" , "a" , "🦀" , "d" ] ) ;
12511259 let keys = Int32Array :: from ( vec ! [ Some ( 0 ) , Some ( 1 ) , Some ( 2 ) , None , Some ( 4 ) ] ) ;
12521260 let dict_array =
@@ -1258,18 +1266,12 @@ mod tests {
12581266 let mut min_acc = MinAccumulator :: try_new ( & rt_type) ?;
12591267 min_acc. update_batch ( & [ Arc :: clone ( & dict_array_ref) ] ) ?;
12601268 let min_result = min_acc. evaluate ( ) ?;
1261- assert_eq ! (
1262- min_result,
1263- dict_scalar( DataType :: Int32 , ScalarValue :: Utf8 ( Some ( "a" . to_string( ) ) ) )
1264- ) ;
1269+ assert_eq ! ( min_result, ScalarValue :: Utf8 ( Some ( "a" . to_string( ) ) ) ) ;
12651270
12661271 let mut max_acc = MaxAccumulator :: try_new ( & rt_type) ?;
12671272 max_acc. update_batch ( & [ Arc :: clone ( & dict_array_ref) ] ) ?;
12681273 let max_result = max_acc. evaluate ( ) ?;
1269- assert_eq ! (
1270- max_result,
1271- dict_scalar( DataType :: Int32 , ScalarValue :: Utf8 ( Some ( "d" . to_string( ) ) ) )
1272- ) ;
1274+ assert_eq ! ( max_result, ScalarValue :: Utf8 ( Some ( "d" . to_string( ) ) ) ) ;
12731275 Ok ( ( ) )
12741276 }
12751277
0 commit comments