Skip to content

Commit aabc196

Browse files
committed
fix
1 parent 2ea5478 commit aabc196

15 files changed

Lines changed: 392 additions & 63 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ benchmark/clickbench/results
9090
.lycheecache
9191

9292
.claude
93+
.codex
9394

9495
# tmp
9596
tmp

src/query/sql/src/planner/binder/aggregate.rs

Lines changed: 217 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,23 @@ use std::collections::HashSet;
1717
use std::collections::hash_map::Entry;
1818
use std::sync::Arc;
1919

20+
use databend_common_ast::ast::ColumnID;
2021
use databend_common_ast::ast::ColumnRef;
2122
use databend_common_ast::ast::Expr;
23+
use databend_common_ast::ast::FunctionCall as ASTFunctionCall;
2224
use databend_common_ast::ast::GroupBy;
2325
use databend_common_ast::ast::Literal;
26+
use databend_common_ast::ast::Query;
2427
use databend_common_ast::ast::SelectTarget;
2528
use databend_common_exception::ErrorCode;
2629
use databend_common_exception::Result;
2730
use databend_common_expression::Scalar;
2831
use databend_common_expression::types::DataType;
2932
use databend_common_expression::types::NumberDataType;
3033
use databend_common_expression::types::NumberScalar;
34+
use databend_common_functions::aggregates::AggregateFunctionFactory;
35+
use derive_visitor::Drive;
36+
use derive_visitor::Visitor;
3137
use indexmap::Equivalent;
3238
use itertools::Itertools;
3339

@@ -57,8 +63,8 @@ use crate::plans::GroupingSets;
5763
use crate::plans::ScalarExpr;
5864
use crate::plans::ScalarItem;
5965
use crate::plans::UDAFCall;
60-
use crate::plans::Visitor;
61-
use crate::plans::VisitorMut;
66+
use crate::plans::Visitor as ScalarVisitor;
67+
use crate::plans::VisitorMut as ScalarVisitorMut;
6268
use crate::plans::walk_expr_mut;
6369

6470
/// Information for `GROUPING SETS`.
@@ -748,7 +754,7 @@ struct ExistingAggregateRewriter<'a> {
748754
error_message: &'a str,
749755
}
750756

751-
impl<'a> VisitorMut<'a> for ExistingAggregateRewriter<'a> {
757+
impl<'a> ScalarVisitorMut<'a> for ExistingAggregateRewriter<'a> {
752758
fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> {
753759
match expr {
754760
ScalarExpr::AggregateFunction(aggregate) => {
@@ -801,7 +807,7 @@ impl<'a> VisitorMut<'a> for ExistingAggregateRewriter<'a> {
801807
}
802808
}
803809

804-
impl<'a> VisitorMut<'a> for AggregateRewriter<'a> {
810+
impl<'a> ScalarVisitorMut<'a> for AggregateRewriter<'a> {
805811
fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> {
806812
match expr {
807813
ScalarExpr::AggregateFunction(aggregate) => {
@@ -845,6 +851,126 @@ impl<'a> VisitorMut<'a> for AggregateRewriter<'a> {
845851
}
846852
}
847853

854+
type AggregatePrepassAliases = Vec<(String, Expr)>;
855+
856+
struct AggregatePrepassFragment {
857+
expr: Expr,
858+
contains_subquery: bool,
859+
}
860+
861+
#[derive(Default, Visitor)]
862+
#[visitor(Query(enter))]
863+
struct ContainsSubqueryVisitor {
864+
found: bool,
865+
}
866+
867+
impl ContainsSubqueryVisitor {
868+
fn enter_query(&mut self, _query: &Query) {
869+
self.found = true;
870+
}
871+
}
872+
873+
#[derive(Visitor)]
874+
#[visitor(Expr(enter), ColumnRef(enter), Query)]
875+
struct AggregatePrepassScanner<'a> {
876+
name_resolution_ctx: &'a crate::NameResolutionContext,
877+
ast_aliases: &'a AggregatePrepassAliases,
878+
query_depth: usize,
879+
expanding_aliases: HashSet<String>,
880+
fragments: Vec<AggregatePrepassFragment>,
881+
}
882+
883+
impl AggregatePrepassScanner<'_> {
884+
fn scan(
885+
name_resolution_ctx: &crate::NameResolutionContext,
886+
ast_aliases: &AggregatePrepassAliases,
887+
expr: &Expr,
888+
) -> Vec<AggregatePrepassFragment> {
889+
let mut scanner = AggregatePrepassScanner {
890+
name_resolution_ctx,
891+
ast_aliases,
892+
query_depth: 0,
893+
expanding_aliases: HashSet::new(),
894+
fragments: Vec::new(),
895+
};
896+
expr.drive(&mut scanner);
897+
scanner.fragments
898+
}
899+
900+
fn enter_expr(&mut self, expr: &Expr) {
901+
if self.query_depth > 0 {
902+
return;
903+
}
904+
905+
match expr {
906+
Expr::CountAll { window: None, .. } => self.record_fragment(expr),
907+
Expr::FunctionCall { func, .. } if Binder::is_aggregate_prepass_target(func) => {
908+
self.record_fragment(expr)
909+
}
910+
_ => {}
911+
}
912+
}
913+
914+
fn enter_column_ref(&mut self, column: &ColumnRef) {
915+
if self.query_depth > 0 {
916+
return;
917+
}
918+
919+
let Some((alias, alias_expr)) =
920+
Self::find_aggregate_prepass_alias(self.name_resolution_ctx, column, self.ast_aliases)
921+
else {
922+
return;
923+
};
924+
925+
if self.expanding_aliases.insert(alias.clone()) {
926+
alias_expr.drive(self);
927+
self.expanding_aliases.remove(&alias);
928+
}
929+
}
930+
931+
fn enter_query(&mut self, _query: &Query) {
932+
self.query_depth += 1;
933+
}
934+
935+
fn exit_query(&mut self, _query: &Query) {
936+
self.query_depth -= 1;
937+
}
938+
939+
fn record_fragment(&mut self, expr: &Expr) {
940+
self.fragments.push(AggregatePrepassFragment {
941+
expr: expr.clone(),
942+
contains_subquery: Binder::aggregate_prepass_contains_subquery(expr),
943+
});
944+
}
945+
946+
fn find_aggregate_prepass_alias<'a>(
947+
name_resolution_ctx: &crate::NameResolutionContext,
948+
column: &ColumnRef,
949+
ast_aliases: &'a AggregatePrepassAliases,
950+
) -> Option<(String, &'a Expr)> {
951+
if column.database.is_some() || column.table.is_some() {
952+
return None;
953+
}
954+
955+
let ColumnID::Name(ident) = &column.column else {
956+
return None;
957+
};
958+
959+
let alias = normalize_identifier(ident, name_resolution_ctx).name;
960+
let mut matches = ast_aliases
961+
.iter()
962+
.filter(|(candidate, _)| candidate == &alias)
963+
.map(|(_, expr)| expr);
964+
965+
let expr = matches.next()?;
966+
if matches.next().is_some() {
967+
return None;
968+
}
969+
970+
Some((alias, expr))
971+
}
972+
}
973+
848974
impl Binder {
849975
/// Analyze aggregates in select clause, this will rewrite aggregate functions.
850976
/// See [`AggregateRewriter`] for more details.
@@ -864,6 +990,63 @@ impl Binder {
864990
Ok(())
865991
}
866992

993+
pub(super) fn collect_aggregate_prepass_aliases<'a>(
994+
&self,
995+
select_list: &'a SelectList<'a>,
996+
) -> AggregatePrepassAliases {
997+
select_list
998+
.items
999+
.iter()
1000+
.filter_map(|item| match item.select_target {
1001+
SelectTarget::AliasedExpr { expr, .. } => {
1002+
Some((item.alias.clone(), expr.as_ref().clone()))
1003+
}
1004+
_ => None,
1005+
})
1006+
.collect()
1007+
}
1008+
1009+
pub(super) fn pre_register_aggregate_fragments(
1010+
&mut self,
1011+
bind_context: &mut BindContext,
1012+
aliases: &[(String, ScalarExpr)],
1013+
ast_aliases: &AggregatePrepassAliases,
1014+
expr_context: ExprContext,
1015+
expr: &Expr,
1016+
) -> Result<()> {
1017+
for fragment in AggregatePrepassScanner::scan(&self.name_resolution_ctx, ast_aliases, expr)
1018+
{
1019+
if fragment.contains_subquery {
1020+
continue;
1021+
}
1022+
1023+
let _ = self.bind_and_rewrite_aggregate_expr(
1024+
bind_context,
1025+
aliases,
1026+
expr_context,
1027+
&fragment.expr,
1028+
)?;
1029+
}
1030+
1031+
Ok(())
1032+
}
1033+
1034+
fn is_aggregate_prepass_target(func: &ASTFunctionCall) -> bool {
1035+
if func.window.is_some() {
1036+
return false;
1037+
}
1038+
1039+
let func_name = func.name.name.to_lowercase();
1040+
AggregateFunctionFactory::instance().contains(func_name.as_str())
1041+
|| func_name.eq_ignore_ascii_case("grouping")
1042+
}
1043+
1044+
fn aggregate_prepass_contains_subquery(expr: &Expr) -> bool {
1045+
let mut detector = ContainsSubqueryVisitor::default();
1046+
expr.drive(&mut detector);
1047+
detector.found
1048+
}
1049+
8671050
/// We have supported three kinds of `group by` items:
8681051
///
8691052
/// - Index, a integral literal, e.g. `GROUP BY 1`. It choose the 1st item in select as
@@ -891,8 +1074,7 @@ impl Binder {
8911074
}
8921075
}
8931076

894-
let original_context = bind_context.expr_context.clone();
895-
bind_context.set_expr_context(ExprContext::GroupClaue);
1077+
let original_context = bind_context.replace_expr_context(ExprContext::GroupClaue);
8961078

8971079
let group_by = Self::expand_group(group_by.clone())?;
8981080
match &group_by {
@@ -920,7 +1102,7 @@ impl Binder {
9201102
}
9211103
_ => unreachable!(),
9221104
}
923-
bind_context.set_expr_context(original_context);
1105+
bind_context.expr_context = original_context;
9241106
Ok(())
9251107
}
9261108

@@ -1448,6 +1630,34 @@ impl Binder {
14481630
Ok((scalar.clone(), scalar.data_type()?))
14491631
}
14501632
}
1633+
1634+
pub(super) fn bind_and_rewrite_aggregate_expr(
1635+
&mut self,
1636+
bind_context: &mut BindContext,
1637+
aliases: &[(String, ScalarExpr)],
1638+
expr_context: ExprContext,
1639+
expr: &Expr,
1640+
) -> Result<ScalarExpr> {
1641+
let original_context = bind_context.replace_expr_context(expr_context);
1642+
1643+
let mut scalar_binder = ScalarBinder::new(
1644+
bind_context,
1645+
self.ctx.clone(),
1646+
&self.name_resolution_ctx,
1647+
self.metadata.clone(),
1648+
aliases,
1649+
);
1650+
1651+
let (mut result, _) = scalar_binder.bind(expr)?;
1652+
AggregateRewriter::rewrite_expr(
1653+
&mut bind_context.aggregate_info,
1654+
self.metadata.clone(),
1655+
&mut result,
1656+
)?;
1657+
1658+
bind_context.expr_context = original_context;
1659+
Ok(result)
1660+
}
14511661
}
14521662

14531663
fn build_replaced_aggregate_column(

src/query/sql/src/planner/binder/bind_context.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ use crate::plans::ScalarExpr;
6565

6666
/// Context of current expression, this is used to check if
6767
/// the expression is valid in current context.
68-
#[derive(Debug, Clone, Default, EnumAsInner)]
68+
#[derive(Debug, Clone, Copy, Default, EnumAsInner)]
6969
pub enum ExprContext {
7070
SelectClause,
7171
WhereClause,
@@ -919,8 +919,10 @@ impl BindContext {
919919
self.columns.iter().map(|c| c.index).collect()
920920
}
921921

922-
pub fn set_expr_context(&mut self, expr_context: ExprContext) {
923-
self.expr_context = expr_context;
922+
pub fn replace_expr_context(&mut self, new: ExprContext) -> ExprContext {
923+
let old = self.expr_context;
924+
self.expr_context = new;
925+
old
924926
}
925927
}
926928

src/query/sql/src/planner/binder/bind_query/bind_select.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use crate::optimizer::ir::SExpr;
5454
use crate::planner::QueryExecutor;
5555
use crate::planner::binder::BindContext;
5656
use crate::planner::binder::Binder;
57+
use crate::planner::binder::ExprContext;
5758

5859
impl Binder {
5960
#[async_backtrace::framed]
@@ -142,6 +143,27 @@ impl Binder {
142143
}
143144

144145
self.analyze_aggregate_select(&mut from_context, &mut select_list)?;
146+
let aggregate_prepass_aliases = self.collect_aggregate_prepass_aliases(&select_list);
147+
148+
if let Some(having) = &stmt.having {
149+
self.pre_register_aggregate_fragments(
150+
&mut from_context,
151+
&semantic_aliases,
152+
&aggregate_prepass_aliases,
153+
ExprContext::HavingClause,
154+
having,
155+
)?;
156+
}
157+
158+
for order in order_by {
159+
self.pre_register_aggregate_fragments(
160+
&mut from_context,
161+
&semantic_aliases,
162+
&aggregate_prepass_aliases,
163+
ExprContext::OrderByClause,
164+
&order.expr,
165+
)?;
166+
}
145167

146168
// `analyze_window` should behind `analyze_aggregate_select`,
147169
// because `analyze_window` will rewrite the aggregate functions in the window function's arguments.
@@ -247,7 +269,7 @@ impl Binder {
247269
}
248270

249271
if stmt.distinct {
250-
s_expr = self.bind_distinct(stmt.span, &mut from_context, &mut select_info, s_expr)?;
272+
s_expr = self.bind_distinct(stmt.span, &mut select_info, s_expr)?;
251273
}
252274

253275
s_expr = self.bind_projection(&mut from_context, select_info, s_expr)?;

src/query/sql/src/planner/binder/distinct.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use std::sync::Arc;
1717
use databend_common_ast::Span;
1818
use databend_common_exception::Result;
1919

20-
use crate::BindContext;
2120
use crate::binder::Binder;
2221
use crate::binder::project::SelectInfo;
2322
use crate::optimizer::ir::SExpr;
@@ -28,7 +27,6 @@ impl Binder {
2827
pub fn bind_distinct(
2928
&self,
3029
span: Span,
31-
_bind_context: &mut BindContext,
3230
select_info: &mut SelectInfo,
3331
child: SExpr,
3432
) -> Result<SExpr> {

0 commit comments

Comments
 (0)