1717
1818//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20+ use std:: any:: { type_name, Any } ;
2021use std:: cmp:: Ordering ;
2122use std:: collections:: { HashSet , VecDeque } ;
22- use std:: mem:: { size_of, size_of_val} ;
23+ use std:: hash:: { DefaultHasher , Hash , Hasher } ;
24+ use std:: mem:: { size_of, size_of_val, take} ;
2325use std:: sync:: Arc ;
2426
2527use arrow:: array:: {
@@ -31,7 +33,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields};
3133
3234use datafusion_common:: cast:: as_list_array;
3335use datafusion_common:: scalar:: copy_array_data;
34- use datafusion_common:: utils:: { get_row_at_idx, SingleRowListArrayBuilder } ;
36+ use datafusion_common:: utils:: { compare_rows , get_row_at_idx, SingleRowListArrayBuilder } ;
3537use datafusion_common:: { exec_err, internal_err, Result , ScalarValue } ;
3638use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
3739use datafusion_expr:: utils:: format_state_name;
@@ -74,22 +76,24 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp
7476"# ,
7577 standard_argument( name = "expression" , )
7678) ]
77- #[ derive( Debug ) ]
79+ #[ derive( Debug , PartialEq , Eq , Hash ) ]
7880/// ARRAY_AGG aggregate expression
7981pub struct ArrayAgg {
8082 signature : Signature ,
83+ is_input_pre_ordered : bool ,
8184}
8285
8386impl Default for ArrayAgg {
8487 fn default ( ) -> Self {
8588 Self {
8689 signature : Signature :: any ( 1 , Volatility :: Immutable ) ,
90+ is_input_pre_ordered : false ,
8791 }
8892 }
8993}
9094
9195impl AggregateUDFImpl for ArrayAgg {
92- fn as_any ( & self ) -> & dyn std :: any :: Any {
96+ fn as_any ( & self ) -> & dyn Any {
9397 self
9498 }
9599
@@ -144,6 +148,16 @@ impl AggregateUDFImpl for ArrayAgg {
144148 Ok ( fields)
145149 }
146150
151+ fn with_beneficial_ordering (
152+ self : Arc < Self > ,
153+ beneficial_ordering : bool ,
154+ ) -> Result < Option < Arc < dyn AggregateUDFImpl > > > {
155+ Ok ( Some ( Arc :: new ( Self {
156+ signature : self . signature . clone ( ) ,
157+ is_input_pre_ordered : beneficial_ordering,
158+ } ) ) )
159+ }
160+
147161 fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
148162 let data_type = acc_args. exprs [ 0 ] . data_type ( acc_args. schema ) ?;
149163 let ignore_nulls =
@@ -196,6 +210,7 @@ impl AggregateUDFImpl for ArrayAgg {
196210 & data_type,
197211 & ordering_dtypes,
198212 ordering,
213+ self . is_input_pre_ordered ,
199214 acc_args. is_reversed ,
200215 ignore_nulls,
201216 )
@@ -209,6 +224,23 @@ impl AggregateUDFImpl for ArrayAgg {
209224 fn documentation ( & self ) -> Option < & Documentation > {
210225 self . doc ( )
211226 }
227+
228+ fn equals ( & self , other : & dyn AggregateUDFImpl ) -> bool {
229+ let Some ( other) = <dyn Any + ' static >:: downcast_ref :: < Self > ( other. as_any ( ) )
230+ else {
231+ return false ;
232+ } ;
233+ fn assert_self_impls_eq < T : Eq > ( ) { }
234+ assert_self_impls_eq :: < Self > ( ) ;
235+ PartialEq :: eq ( self , other)
236+ }
237+
238+ fn hash_value ( & self ) -> u64 {
239+ let hasher = & mut DefaultHasher :: new ( ) ;
240+ type_name :: < Self > ( ) . hash ( hasher) ;
241+ Hash :: hash ( self , hasher) ;
242+ Hasher :: finish ( hasher)
243+ }
212244}
213245
214246#[ derive( Debug ) ]
@@ -518,6 +550,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
518550 datatypes : Vec < DataType > ,
519551 /// Stores the ordering requirement of the `Accumulator`.
520552 ordering_req : LexOrdering ,
553+ /// Whether the input is known to be pre-ordered
554+ is_input_pre_ordered : bool ,
521555 /// Whether the aggregation is running in reverse.
522556 reverse : bool ,
523557 /// Whether the aggregation should ignore null values.
@@ -531,6 +565,7 @@ impl OrderSensitiveArrayAggAccumulator {
531565 datatype : & DataType ,
532566 ordering_dtypes : & [ DataType ] ,
533567 ordering_req : LexOrdering ,
568+ is_input_pre_ordered : bool ,
534569 reverse : bool ,
535570 ignore_nulls : bool ,
536571 ) -> Result < Self > {
@@ -541,11 +576,34 @@ impl OrderSensitiveArrayAggAccumulator {
541576 ordering_values : vec ! [ ] ,
542577 datatypes,
543578 ordering_req,
579+ is_input_pre_ordered,
544580 reverse,
545581 ignore_nulls,
546582 } )
547583 }
548584
585+ fn sort ( & mut self ) {
586+ let sort_options = self
587+ . ordering_req
588+ . iter ( )
589+ . map ( |sort_expr| sort_expr. options )
590+ . collect :: < Vec < _ > > ( ) ;
591+ let mut values = take ( & mut self . values )
592+ . into_iter ( )
593+ . zip ( take ( & mut self . ordering_values ) )
594+ . collect :: < Vec < _ > > ( ) ;
595+ let mut delayed_cmp_err = Ok ( ( ) ) ;
596+ values. sort_by ( |( _, left_ordering) , ( _, right_ordering) | {
597+ compare_rows ( left_ordering, right_ordering, & sort_options) . unwrap_or_else (
598+ |err| {
599+ delayed_cmp_err = Err ( err) ;
600+ Ordering :: Equal
601+ } ,
602+ )
603+ } ) ;
604+ ( self . values , self . ordering_values ) = values. into_iter ( ) . unzip ( ) ;
605+ }
606+
549607 fn evaluate_orderings ( & self ) -> Result < ScalarValue > {
550608 let fields = ordering_fields ( & self . ordering_req , & self . datatypes [ 1 ..] ) ;
551609
@@ -629,6 +687,9 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
629687 let mut partition_ordering_values = vec ! [ ] ;
630688
631689 // Existing values should be merged also.
690+ if !self . is_input_pre_ordered {
691+ self . sort ( ) ;
692+ }
632693 partition_values. push ( self . values . clone ( ) . into ( ) ) ;
633694 partition_ordering_values. push ( self . ordering_values . clone ( ) . into ( ) ) ;
634695
@@ -679,13 +740,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
679740 }
680741
681742 fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
743+ if !self . is_input_pre_ordered {
744+ self . sort ( ) ;
745+ }
746+
682747 let mut result = vec ! [ self . evaluate( ) ?] ;
683748 result. push ( self . evaluate_orderings ( ) ?) ;
684749
685750 Ok ( result)
686751 }
687752
688753 fn evaluate ( & mut self ) -> Result < ScalarValue > {
754+ if !self . is_input_pre_ordered {
755+ self . sort ( ) ;
756+ }
757+
689758 if self . values . is_empty ( ) {
690759 return Ok ( ScalarValue :: new_null_list (
691760 self . datatypes [ 0 ] . clone ( ) ,
0 commit comments