@@ -17,17 +17,23 @@ use std::collections::HashSet;
1717use std:: collections:: hash_map:: Entry ;
1818use std:: sync:: Arc ;
1919
20+ use databend_common_ast:: ast:: ColumnID ;
2021use databend_common_ast:: ast:: ColumnRef ;
2122use databend_common_ast:: ast:: Expr ;
23+ use databend_common_ast:: ast:: FunctionCall as ASTFunctionCall ;
2224use databend_common_ast:: ast:: GroupBy ;
2325use databend_common_ast:: ast:: Literal ;
26+ use databend_common_ast:: ast:: Query ;
2427use databend_common_ast:: ast:: SelectTarget ;
2528use databend_common_exception:: ErrorCode ;
2629use databend_common_exception:: Result ;
2730use databend_common_expression:: Scalar ;
2831use databend_common_expression:: types:: DataType ;
2932use databend_common_expression:: types:: NumberDataType ;
3033use databend_common_expression:: types:: NumberScalar ;
34+ use databend_common_functions:: aggregates:: AggregateFunctionFactory ;
35+ use derive_visitor:: Drive ;
36+ use derive_visitor:: Visitor ;
3137use indexmap:: Equivalent ;
3238use itertools:: Itertools ;
3339
@@ -57,8 +63,8 @@ use crate::plans::GroupingSets;
5763use crate :: plans:: ScalarExpr ;
5864use crate :: plans:: ScalarItem ;
5965use crate :: plans:: UDAFCall ;
60- use crate :: plans:: Visitor ;
61- use crate :: plans:: VisitorMut ;
66+ use crate :: plans:: Visitor as ScalarVisitor ;
67+ use crate :: plans:: VisitorMut as ScalarVisitorMut ;
6268use crate :: plans:: walk_expr_mut;
6369
6470/// Information for `GROUPING SETS`.
@@ -748,7 +754,7 @@ struct ExistingAggregateRewriter<'a> {
748754 error_message : & ' a str ,
749755}
750756
751- impl < ' a > VisitorMut < ' a > for ExistingAggregateRewriter < ' a > {
757+ impl < ' a > ScalarVisitorMut < ' a > for ExistingAggregateRewriter < ' a > {
752758 fn visit ( & mut self , expr : & ' a mut ScalarExpr ) -> Result < ( ) > {
753759 match expr {
754760 ScalarExpr :: AggregateFunction ( aggregate) => {
@@ -801,7 +807,7 @@ impl<'a> VisitorMut<'a> for ExistingAggregateRewriter<'a> {
801807 }
802808}
803809
804- impl < ' a > VisitorMut < ' a > for AggregateRewriter < ' a > {
810+ impl < ' a > ScalarVisitorMut < ' a > for AggregateRewriter < ' a > {
805811 fn visit ( & mut self , expr : & ' a mut ScalarExpr ) -> Result < ( ) > {
806812 match expr {
807813 ScalarExpr :: AggregateFunction ( aggregate) => {
@@ -845,6 +851,126 @@ impl<'a> VisitorMut<'a> for AggregateRewriter<'a> {
845851 }
846852}
847853
854+ type AggregatePrepassAliases = Vec < ( String , Expr ) > ;
855+
856+ struct AggregatePrepassFragment {
857+ expr : Expr ,
858+ contains_subquery : bool ,
859+ }
860+
861+ #[ derive( Default , Visitor ) ]
862+ #[ visitor( Query ( enter) ) ]
863+ struct ContainsSubqueryVisitor {
864+ found : bool ,
865+ }
866+
867+ impl ContainsSubqueryVisitor {
868+ fn enter_query ( & mut self , _query : & Query ) {
869+ self . found = true ;
870+ }
871+ }
872+
873+ #[ derive( Visitor ) ]
874+ #[ visitor( Expr ( enter) , ColumnRef ( enter) , Query ) ]
875+ struct AggregatePrepassScanner < ' a > {
876+ name_resolution_ctx : & ' a crate :: NameResolutionContext ,
877+ ast_aliases : & ' a AggregatePrepassAliases ,
878+ query_depth : usize ,
879+ expanding_aliases : HashSet < String > ,
880+ fragments : Vec < AggregatePrepassFragment > ,
881+ }
882+
883+ impl AggregatePrepassScanner < ' _ > {
884+ fn scan (
885+ name_resolution_ctx : & crate :: NameResolutionContext ,
886+ ast_aliases : & AggregatePrepassAliases ,
887+ expr : & Expr ,
888+ ) -> Vec < AggregatePrepassFragment > {
889+ let mut scanner = AggregatePrepassScanner {
890+ name_resolution_ctx,
891+ ast_aliases,
892+ query_depth : 0 ,
893+ expanding_aliases : HashSet :: new ( ) ,
894+ fragments : Vec :: new ( ) ,
895+ } ;
896+ expr. drive ( & mut scanner) ;
897+ scanner. fragments
898+ }
899+
900+ fn enter_expr ( & mut self , expr : & Expr ) {
901+ if self . query_depth > 0 {
902+ return ;
903+ }
904+
905+ match expr {
906+ Expr :: CountAll { window : None , .. } => self . record_fragment ( expr) ,
907+ Expr :: FunctionCall { func, .. } if Binder :: is_aggregate_prepass_target ( func) => {
908+ self . record_fragment ( expr)
909+ }
910+ _ => { }
911+ }
912+ }
913+
914+ fn enter_column_ref ( & mut self , column : & ColumnRef ) {
915+ if self . query_depth > 0 {
916+ return ;
917+ }
918+
919+ let Some ( ( alias, alias_expr) ) =
920+ Self :: find_aggregate_prepass_alias ( self . name_resolution_ctx , column, self . ast_aliases )
921+ else {
922+ return ;
923+ } ;
924+
925+ if self . expanding_aliases . insert ( alias. clone ( ) ) {
926+ alias_expr. drive ( self ) ;
927+ self . expanding_aliases . remove ( & alias) ;
928+ }
929+ }
930+
931+ fn enter_query ( & mut self , _query : & Query ) {
932+ self . query_depth += 1 ;
933+ }
934+
935+ fn exit_query ( & mut self , _query : & Query ) {
936+ self . query_depth -= 1 ;
937+ }
938+
939+ fn record_fragment ( & mut self , expr : & Expr ) {
940+ self . fragments . push ( AggregatePrepassFragment {
941+ expr : expr. clone ( ) ,
942+ contains_subquery : Binder :: aggregate_prepass_contains_subquery ( expr) ,
943+ } ) ;
944+ }
945+
946+ fn find_aggregate_prepass_alias < ' a > (
947+ name_resolution_ctx : & crate :: NameResolutionContext ,
948+ column : & ColumnRef ,
949+ ast_aliases : & ' a AggregatePrepassAliases ,
950+ ) -> Option < ( String , & ' a Expr ) > {
951+ if column. database . is_some ( ) || column. table . is_some ( ) {
952+ return None ;
953+ }
954+
955+ let ColumnID :: Name ( ident) = & column. column else {
956+ return None ;
957+ } ;
958+
959+ let alias = normalize_identifier ( ident, name_resolution_ctx) . name ;
960+ let mut matches = ast_aliases
961+ . iter ( )
962+ . filter ( |( candidate, _) | candidate == & alias)
963+ . map ( |( _, expr) | expr) ;
964+
965+ let expr = matches. next ( ) ?;
966+ if matches. next ( ) . is_some ( ) {
967+ return None ;
968+ }
969+
970+ Some ( ( alias, expr) )
971+ }
972+ }
973+
848974impl Binder {
849975 /// Analyze aggregates in select clause, this will rewrite aggregate functions.
850976 /// See [`AggregateRewriter`] for more details.
@@ -864,6 +990,63 @@ impl Binder {
864990 Ok ( ( ) )
865991 }
866992
993+ pub ( super ) fn collect_aggregate_prepass_aliases < ' a > (
994+ & self ,
995+ select_list : & ' a SelectList < ' a > ,
996+ ) -> AggregatePrepassAliases {
997+ select_list
998+ . items
999+ . iter ( )
1000+ . filter_map ( |item| match item. select_target {
1001+ SelectTarget :: AliasedExpr { expr, .. } => {
1002+ Some ( ( item. alias . clone ( ) , expr. as_ref ( ) . clone ( ) ) )
1003+ }
1004+ _ => None ,
1005+ } )
1006+ . collect ( )
1007+ }
1008+
1009+ pub ( super ) fn pre_register_aggregate_fragments (
1010+ & mut self ,
1011+ bind_context : & mut BindContext ,
1012+ aliases : & [ ( String , ScalarExpr ) ] ,
1013+ ast_aliases : & AggregatePrepassAliases ,
1014+ expr_context : ExprContext ,
1015+ expr : & Expr ,
1016+ ) -> Result < ( ) > {
1017+ for fragment in AggregatePrepassScanner :: scan ( & self . name_resolution_ctx , ast_aliases, expr)
1018+ {
1019+ if fragment. contains_subquery {
1020+ continue ;
1021+ }
1022+
1023+ let _ = self . bind_and_rewrite_aggregate_expr (
1024+ bind_context,
1025+ aliases,
1026+ expr_context,
1027+ & fragment. expr ,
1028+ ) ?;
1029+ }
1030+
1031+ Ok ( ( ) )
1032+ }
1033+
1034+ fn is_aggregate_prepass_target ( func : & ASTFunctionCall ) -> bool {
1035+ if func. window . is_some ( ) {
1036+ return false ;
1037+ }
1038+
1039+ let func_name = func. name . name . to_lowercase ( ) ;
1040+ AggregateFunctionFactory :: instance ( ) . contains ( func_name. as_str ( ) )
1041+ || func_name. eq_ignore_ascii_case ( "grouping" )
1042+ }
1043+
1044+ fn aggregate_prepass_contains_subquery ( expr : & Expr ) -> bool {
1045+ let mut detector = ContainsSubqueryVisitor :: default ( ) ;
1046+ expr. drive ( & mut detector) ;
1047+ detector. found
1048+ }
1049+
8671050 /// We have supported three kinds of `group by` items:
8681051 ///
8691052 /// - Index, a integral literal, e.g. `GROUP BY 1`. It choose the 1st item in select as
@@ -891,8 +1074,7 @@ impl Binder {
8911074 }
8921075 }
8931076
894- let original_context = bind_context. expr_context . clone ( ) ;
895- bind_context. set_expr_context ( ExprContext :: GroupClaue ) ;
1077+ let original_context = bind_context. replace_expr_context ( ExprContext :: GroupClaue ) ;
8961078
8971079 let group_by = Self :: expand_group ( group_by. clone ( ) ) ?;
8981080 match & group_by {
@@ -920,7 +1102,7 @@ impl Binder {
9201102 }
9211103 _ => unreachable ! ( ) ,
9221104 }
923- bind_context. set_expr_context ( original_context) ;
1105+ bind_context. expr_context = original_context;
9241106 Ok ( ( ) )
9251107 }
9261108
@@ -1448,6 +1630,34 @@ impl Binder {
14481630 Ok ( ( scalar. clone ( ) , scalar. data_type ( ) ?) )
14491631 }
14501632 }
1633+
1634+ pub ( super ) fn bind_and_rewrite_aggregate_expr (
1635+ & mut self ,
1636+ bind_context : & mut BindContext ,
1637+ aliases : & [ ( String , ScalarExpr ) ] ,
1638+ expr_context : ExprContext ,
1639+ expr : & Expr ,
1640+ ) -> Result < ScalarExpr > {
1641+ let original_context = bind_context. replace_expr_context ( expr_context) ;
1642+
1643+ let mut scalar_binder = ScalarBinder :: new (
1644+ bind_context,
1645+ self . ctx . clone ( ) ,
1646+ & self . name_resolution_ctx ,
1647+ self . metadata . clone ( ) ,
1648+ aliases,
1649+ ) ;
1650+
1651+ let ( mut result, _) = scalar_binder. bind ( expr) ?;
1652+ AggregateRewriter :: rewrite_expr (
1653+ & mut bind_context. aggregate_info ,
1654+ self . metadata . clone ( ) ,
1655+ & mut result,
1656+ ) ?;
1657+
1658+ bind_context. expr_context = original_context;
1659+ Ok ( result)
1660+ }
14511661}
14521662
14531663fn build_replaced_aggregate_column (
0 commit comments