Skip to content

Commit aa58eec

Browse files
committed
feat(query): support aggregate common subplans in CSE
1 parent 66a717a commit aa58eec

3 files changed

Lines changed: 301 additions & 1 deletion

File tree

src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ mod tests {
159159
use databend_common_expression::TableDataType;
160160
use databend_common_expression::TableField;
161161
use databend_common_expression::TableSchema;
162+
use databend_common_expression::types::DataType;
162163
use databend_common_expression::types::NumberDataType;
163164
use databend_common_meta_app::schema::CatalogInfo;
164165
use databend_common_meta_app::schema::DatabaseType;
@@ -167,10 +168,19 @@ mod tests {
167168
use databend_common_meta_app::schema::TableMeta;
168169

169170
use super::*;
171+
use crate::ColumnBindingBuilder;
172+
use crate::Symbol;
173+
use crate::Visibility;
170174
use crate::planner::metadata::Metadata;
175+
use crate::plans::Aggregate;
176+
use crate::plans::AggregateFunction;
177+
use crate::plans::AggregateMode;
178+
use crate::plans::BoundColumnRef;
171179
use crate::plans::Join;
172180
use crate::plans::JoinType;
173181
use crate::plans::RelOperator;
182+
use crate::plans::ScalarExpr;
183+
use crate::plans::ScalarItem;
174184
use crate::plans::Scan;
175185

176186
#[derive(Debug)]
@@ -240,6 +250,62 @@ mod tests {
240250
})))
241251
}
242252

253+
fn column_expr(metadata: &Metadata, table_index: usize) -> ScalarExpr {
254+
let column = metadata.columns_by_table_index(table_index)[0].clone();
255+
BoundColumnRef {
256+
span: None,
257+
column: ColumnBindingBuilder::new(
258+
column.name(),
259+
column.index(),
260+
Box::new(column.data_type()),
261+
Visibility::Visible,
262+
)
263+
.table_index(Some(table_index))
264+
.build(),
265+
}
266+
.into()
267+
}
268+
269+
fn max_aggregate_expr(
270+
metadata: &Metadata,
271+
table_index: usize,
272+
output_index: Symbol,
273+
with_group_by: bool,
274+
) -> SExpr {
275+
let group_items = if with_group_by {
276+
vec![ScalarItem {
277+
scalar: column_expr(metadata, table_index),
278+
index: Symbol::new(output_index.as_usize() + 1),
279+
}]
280+
} else {
281+
vec![]
282+
};
283+
284+
SExpr::create_unary(
285+
Arc::new(RelOperator::Aggregate(Aggregate {
286+
mode: AggregateMode::Initial,
287+
group_items,
288+
aggregate_functions: vec![ScalarItem {
289+
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
290+
span: None,
291+
func_name: "max".to_string(),
292+
distinct: false,
293+
params: vec![],
294+
args: vec![column_expr(metadata, table_index)],
295+
return_type: Box::new(DataType::Number(NumberDataType::UInt64)),
296+
sort_descs: vec![],
297+
display_name: "max(a)".to_string(),
298+
}),
299+
index: output_index,
300+
}],
301+
from_distinct: false,
302+
rank_limit: None,
303+
grouping_sets: None,
304+
})),
305+
Arc::new(scan_expr(metadata, table_index)),
306+
)
307+
}
308+
243309
fn cross_join_expr(left: SExpr, right: SExpr) -> SExpr {
244310
SExpr::create_binary(
245311
Arc::new(RelOperator::Join(Join {
@@ -313,4 +379,50 @@ mod tests {
313379
.all(|cte| matches!(cte.child(0).unwrap().plan(), RelOperator::Scan(_)))
314380
);
315381
}
382+
383+
#[test]
384+
fn test_analyze_common_subexpression_matches_identical_aggregates() {
385+
let mut metadata = Metadata::default();
386+
let t1 = fake_fuse_table(1, "t1");
387+
388+
let t1_left = add_table(&mut metadata, t1.clone());
389+
let t1_right = add_table(&mut metadata, t1);
390+
391+
let left = max_aggregate_expr(&metadata, t1_left, Symbol::new(10), false);
392+
let right = max_aggregate_expr(&metadata, t1_right, Symbol::new(11), false);
393+
let root = cross_join_expr(left, right);
394+
395+
let (replacements, materialized_ctes) =
396+
analyze_common_subexpression(&root, &mut metadata).unwrap();
397+
398+
assert_eq!(replacements.len(), 2);
399+
assert_eq!(materialized_ctes.len(), 1);
400+
assert!(matches!(
401+
materialized_ctes[0].child(0).unwrap().plan(),
402+
RelOperator::Aggregate(_)
403+
));
404+
}
405+
406+
#[test]
407+
fn test_analyze_common_subexpression_matches_identical_group_aggregates() {
408+
let mut metadata = Metadata::default();
409+
let t1 = fake_fuse_table(1, "t1");
410+
411+
let t1_left = add_table(&mut metadata, t1.clone());
412+
let t1_right = add_table(&mut metadata, t1);
413+
414+
let left = max_aggregate_expr(&metadata, t1_left, Symbol::new(10), true);
415+
let right = max_aggregate_expr(&metadata, t1_right, Symbol::new(12), true);
416+
let root = cross_join_expr(left, right);
417+
418+
let (replacements, materialized_ctes) =
419+
analyze_common_subexpression(&root, &mut metadata).unwrap();
420+
421+
assert_eq!(replacements.len(), 2);
422+
assert_eq!(materialized_ctes.len(), 1);
423+
assert!(matches!(
424+
materialized_ctes[0].child(0).unwrap().plan(),
425+
RelOperator::Aggregate(_)
426+
));
427+
}
316428
}

src/query/sql/src/planner/optimizer/optimizers/cse/table_signature.rs

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,32 @@
1414

1515
use std::collections::HashMap;
1616

17+
use crate::ColumnBindingBuilder;
1718
use crate::ColumnEntry;
1819
use crate::IndexType;
20+
use crate::ScalarExpr;
21+
use crate::Symbol;
22+
use crate::Visibility;
1923
use crate::optimizer::ir::SExpr;
2024
use crate::planner::metadata::Metadata;
25+
use crate::plans::Aggregate;
2126
use crate::plans::Join;
2227
use crate::plans::JoinType;
2328
use crate::plans::RelOperator;
29+
use crate::plans::ScalarItem;
2430
use crate::plans::Scan;
31+
use crate::plans::VisitorMut;
32+
use crate::plans::walk_expr_mut;
2533

26-
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
34+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
2735
pub struct TableSignature {
2836
pub tables: Vec<IndexType>,
37+
pub aggregate: Option<AggregateSignature>,
38+
}
39+
40+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
41+
pub struct AggregateSignature {
42+
pub aggregate: Aggregate,
2943
}
3044

3145
pub fn collect_table_signatures(
@@ -63,6 +77,7 @@ fn collect_table_signatures_rec(
6377
signature_to_exprs
6478
.entry(TableSignature {
6579
tables: tables.clone(),
80+
aggregate: None,
6681
})
6782
.or_default()
6883
.push((path.clone(), expr.clone()));
@@ -81,15 +96,123 @@ fn collect_table_signatures_rec(
8196
signature_to_exprs
8297
.entry(TableSignature {
8398
tables: tables.clone(),
99+
aggregate: None,
84100
})
85101
.or_default()
86102
.push((path.clone(), expr.clone()));
87103
Some(tables)
88104
}
105+
RelOperator::EvalScalar(_) if child_tables.len() == 1 && child_tables[0].is_some() => {
106+
child_tables[0].clone()
107+
}
108+
RelOperator::Aggregate(aggregate)
109+
if child_tables.len() == 1 && child_tables[0].is_some() =>
110+
{
111+
let tables = child_tables[0].clone().unwrap();
112+
if let Some(aggregate_signature) = aggregate_signature(aggregate, expr.child(0).ok()?) {
113+
signature_to_exprs
114+
.entry(TableSignature {
115+
tables: tables.clone(),
116+
aggregate: Some(aggregate_signature),
117+
})
118+
.or_default()
119+
.push((path.clone(), expr.clone()));
120+
}
121+
None
122+
}
89123
_ => None,
90124
}
91125
}
92126

127+
fn aggregate_signature(aggregate: &Aggregate, input: &SExpr) -> Option<AggregateSignature> {
128+
if aggregate.rank_limit.is_some() || aggregate.grouping_sets.is_some() {
129+
return None;
130+
}
131+
132+
let input_columns = input
133+
.derive_relational_prop()
134+
.ok()?
135+
.output_columns
136+
.iter()
137+
.copied()
138+
.enumerate()
139+
.map(|(position, column)| (column, Symbol::new(position)))
140+
.collect::<HashMap<_, _>>();
141+
142+
let mut aggregate = aggregate.clone();
143+
aggregate.group_items = normalize_scalar_items(&aggregate.group_items, &input_columns)?;
144+
aggregate.aggregate_functions =
145+
normalize_scalar_items(&aggregate.aggregate_functions, &input_columns)?;
146+
147+
Some(AggregateSignature { aggregate })
148+
}
149+
150+
fn normalize_scalar_items(
151+
items: &[ScalarItem],
152+
input_columns: &HashMap<Symbol, Symbol>,
153+
) -> Option<Vec<ScalarItem>> {
154+
items
155+
.iter()
156+
.enumerate()
157+
.map(|(position, item)| {
158+
Some(ScalarItem {
159+
scalar: normalize_scalar_expr(&item.scalar, input_columns)?,
160+
index: Symbol::new(position),
161+
})
162+
})
163+
.collect()
164+
}
165+
166+
fn normalize_scalar_expr(
167+
scalar: &ScalarExpr,
168+
input_columns: &HashMap<Symbol, Symbol>,
169+
) -> Option<ScalarExpr> {
170+
let mut scalar = scalar.clone();
171+
let mut visitor = NormalizeColumnVisitor { input_columns };
172+
visitor.visit(&mut scalar).ok()?;
173+
Some(scalar)
174+
}
175+
176+
struct NormalizeColumnVisitor<'a> {
177+
input_columns: &'a HashMap<Symbol, Symbol>,
178+
}
179+
180+
impl VisitorMut<'_> for NormalizeColumnVisitor<'_> {
181+
fn visit(&mut self, expr: &mut ScalarExpr) -> databend_common_exception::Result<()> {
182+
walk_expr_mut(self, expr)
183+
}
184+
185+
fn visit_bound_column_ref(
186+
&mut self,
187+
col: &mut crate::plans::BoundColumnRef,
188+
) -> databend_common_exception::Result<()> {
189+
let Some(normalized) = self.input_columns.get(&col.column.index) else {
190+
return Err(databend_common_exception::ErrorCode::Internal(
191+
"aggregate CSE column is not produced by input",
192+
));
193+
};
194+
col.column = ColumnBindingBuilder::new(
195+
normalized.to_string(),
196+
*normalized,
197+
col.column.data_type.clone(),
198+
Visibility::Visible,
199+
)
200+
.build();
201+
Ok(())
202+
}
203+
204+
fn visit_aggregate_function(
205+
&mut self,
206+
aggregate: &mut crate::plans::AggregateFunction,
207+
) -> databend_common_exception::Result<()> {
208+
aggregate.display_name.clear();
209+
for expr in aggregate.exprs_mut() {
210+
self.visit(expr)?;
211+
}
212+
Ok(())
213+
}
214+
}
215+
93216
fn scan_signature(scan: &Scan, metadata: &Metadata) -> Option<IndexType> {
94217
let has_internal_column = scan.columns.iter().any(|column_index| {
95218
let column = metadata.column(*column_index);

tests/sqllogictests/suites/mode/standalone/explain/common_subexpression_optimizer.test

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,3 +677,68 @@ drop table time_dim;
677677

678678
statement ok
679679
drop table store;
680+
681+
statement ok
682+
create or replace table cse_agg_t as
683+
select number as a, number % 3 as b
684+
from numbers(6);
685+
686+
query T nosort
687+
explain select l.b, l.max_a, r.max_a
688+
from (select b, max(a) as max_a from cse_agg_t group by b) l
689+
join (select b, max(a) as max_a from cse_agg_t group by b) r
690+
on l.b = r.b;
691+
----
692+
Sequence
693+
├── MaterializedCTE: cte_cse_0
694+
│ └── AggregateFinal
695+
│ ├── output columns: [max(a) (#2), cse_agg_t.b (#1)]
696+
│ ├── group by: [b]
697+
│ ├── aggregate functions: [max(a)]
698+
│ ├── estimated rows: 3.00
699+
│ └── AggregatePartial
700+
│ ├── group by: [b]
701+
│ ├── aggregate functions: [max(a)]
702+
│ ├── estimated rows: 3.00
703+
│ └── TableScan
704+
│ ├── table: default.default.cse_agg_t
705+
│ ├── scan id: 0
706+
│ ├── output columns: [a (#0), b (#1)]
707+
│ ├── read rows: 6
708+
│ ├── read size: < 1 KiB
709+
│ ├── partitions total: 1
710+
│ ├── partitions scanned: 1
711+
│ ├── pruning stats: [segments: <range pruning: 1 to 1 cost: <slt:ignore>>, blocks: <range pruning: 1 to 1 cost: <slt:ignore>>]
712+
│ ├── push downs: [filters: [], limit: NONE]
713+
│ └── estimated rows: 6.00
714+
└── HashJoin
715+
├── output columns: [cse_agg_t.b (#1), max(a) (#2), max(a) (#5)]
716+
├── join type: INNER
717+
├── build keys: [r.b (#4)]
718+
├── probe keys: [l.b (#1)]
719+
├── keys is null equal: [false]
720+
├── filters: []
721+
├── build join filters:
722+
│ └── filter id:0, build key:r.b (#4), probe targets:[l.b (#1)@scan0], filter type:bloom,inlist,min_max
723+
├── estimated rows: 9.00
724+
├── MaterializeCTERef(Build)
725+
│ ├── cte_name: cte_cse_0
726+
│ ├── cte_schema: [b (#4), max(a) (#5)]
727+
│ └── estimated rows: 3.00
728+
└── MaterializeCTERef(Probe)
729+
├── cte_name: cte_cse_0
730+
├── cte_schema: [b (#1), max(a) (#2)]
731+
└── estimated rows: 3.00
732+
733+
query III rowsort
734+
select l.b, l.max_a, r.max_a
735+
from (select b, max(a) as max_a from cse_agg_t group by b) l
736+
join (select b, max(a) as max_a from cse_agg_t group by b) r
737+
on l.b = r.b;
738+
----
739+
0 3 3
740+
1 4 4
741+
2 5 5
742+
743+
statement ok
744+
drop table cse_agg_t;

0 commit comments

Comments
 (0)