Skip to content

Commit fbd42e4

Browse files
committed
feat(query): support aggregate common subplans in CSE
1 parent 66a717a commit fbd42e4

3 files changed

Lines changed: 436 additions & 1 deletion

File tree

src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,36 @@ mod tests {
156156
use std::any::Any;
157157

158158
use databend_common_catalog::table::Table;
159+
use databend_common_expression::Scalar;
159160
use databend_common_expression::TableDataType;
160161
use databend_common_expression::TableField;
161162
use databend_common_expression::TableSchema;
163+
use databend_common_expression::types::DataType;
162164
use databend_common_expression::types::NumberDataType;
165+
use databend_common_expression::types::NumberScalar;
163166
use databend_common_meta_app::schema::CatalogInfo;
164167
use databend_common_meta_app::schema::DatabaseType;
165168
use databend_common_meta_app::schema::TableIdent;
166169
use databend_common_meta_app::schema::TableInfo;
167170
use databend_common_meta_app::schema::TableMeta;
168171

169172
use super::*;
173+
use crate::ColumnBindingBuilder;
174+
use crate::Symbol;
175+
use crate::Visibility;
170176
use crate::planner::metadata::Metadata;
177+
use crate::plans::Aggregate;
178+
use crate::plans::AggregateFunction;
179+
use crate::plans::AggregateMode;
180+
use crate::plans::BoundColumnRef;
181+
use crate::plans::ConstantExpr;
182+
use crate::plans::EvalScalar;
183+
use crate::plans::FunctionCall;
171184
use crate::plans::Join;
172185
use crate::plans::JoinType;
173186
use crate::plans::RelOperator;
187+
use crate::plans::ScalarExpr;
188+
use crate::plans::ScalarItem;
174189
use crate::plans::Scan;
175190

176191
#[derive(Debug)]
@@ -240,6 +255,62 @@ mod tests {
240255
})))
241256
}
242257

258+
fn column_expr(metadata: &Metadata, table_index: usize) -> ScalarExpr {
259+
let column = metadata.columns_by_table_index(table_index)[0].clone();
260+
BoundColumnRef {
261+
span: None,
262+
column: ColumnBindingBuilder::new(
263+
column.name(),
264+
column.index(),
265+
Box::new(column.data_type()),
266+
Visibility::Visible,
267+
)
268+
.table_index(Some(table_index))
269+
.build(),
270+
}
271+
.into()
272+
}
273+
274+
fn max_aggregate_expr(
275+
metadata: &Metadata,
276+
table_index: usize,
277+
output_index: Symbol,
278+
with_group_by: bool,
279+
) -> SExpr {
280+
let group_items = if with_group_by {
281+
vec![ScalarItem {
282+
scalar: column_expr(metadata, table_index),
283+
index: Symbol::new(output_index.as_usize() + 1),
284+
}]
285+
} else {
286+
vec![]
287+
};
288+
289+
SExpr::create_unary(
290+
Arc::new(RelOperator::Aggregate(Aggregate {
291+
mode: AggregateMode::Initial,
292+
group_items,
293+
aggregate_functions: vec![ScalarItem {
294+
scalar: ScalarExpr::AggregateFunction(AggregateFunction {
295+
span: None,
296+
func_name: "max".to_string(),
297+
distinct: false,
298+
params: vec![],
299+
args: vec![column_expr(metadata, table_index)],
300+
return_type: Box::new(DataType::Number(NumberDataType::UInt64)),
301+
sort_descs: vec![],
302+
display_name: "max(a)".to_string(),
303+
}),
304+
index: output_index,
305+
}],
306+
from_distinct: false,
307+
rank_limit: None,
308+
grouping_sets: None,
309+
})),
310+
Arc::new(scan_expr(metadata, table_index)),
311+
)
312+
}
313+
243314
fn cross_join_expr(left: SExpr, right: SExpr) -> SExpr {
244315
SExpr::create_binary(
245316
Arc::new(RelOperator::Join(Join {
@@ -251,6 +322,35 @@ mod tests {
251322
)
252323
}
253324

325+
fn eval_scalar_expr(
326+
metadata: &Metadata,
327+
input: SExpr,
328+
table_index: usize,
329+
output_index: Symbol,
330+
value: u64,
331+
) -> SExpr {
332+
SExpr::create_unary(
333+
Arc::new(RelOperator::EvalScalar(EvalScalar {
334+
items: vec![ScalarItem {
335+
scalar: ScalarExpr::FunctionCall(FunctionCall {
336+
span: None,
337+
func_name: "plus".to_string(),
338+
params: vec![],
339+
arguments: vec![
340+
column_expr(metadata, table_index),
341+
ScalarExpr::ConstantExpr(ConstantExpr {
342+
span: None,
343+
value: Scalar::Number(NumberScalar::UInt64(value)),
344+
}),
345+
],
346+
}),
347+
index: output_index,
348+
}],
349+
})),
350+
Arc::new(input),
351+
)
352+
}
353+
254354
#[test]
255355
fn test_analyze_common_subexpression_prefers_cross_join_subtree() {
256356
let mut metadata = Metadata::default();
@@ -313,4 +413,86 @@ mod tests {
313413
.all(|cte| matches!(cte.child(0).unwrap().plan(), RelOperator::Scan(_)))
314414
);
315415
}
416+
417+
#[test]
fn test_analyze_common_subexpression_matches_identical_aggregates() {
    let mut metadata = Metadata::default();
    let table = fake_fuse_table(1, "t1");

    // Register the same table twice, once for each side of the cross join.
    let left_table = add_table(&mut metadata, table.clone());
    let right_table = add_table(&mut metadata, table);

    // Both sides are an ungrouped `max` aggregate; only the output symbols differ.
    let root = cross_join_expr(
        max_aggregate_expr(&metadata, left_table, Symbol::new(10), false),
        max_aggregate_expr(&metadata, right_table, Symbol::new(11), false),
    );

    let (replacements, materialized_ctes) =
        analyze_common_subexpression(&root, &mut metadata).unwrap();

    // Two replacements are expected, backed by one CTE whose child is the aggregate.
    assert_eq!(replacements.len(), 2);
    assert_eq!(materialized_ctes.len(), 1);

    let cte_body = materialized_ctes[0].child(0).unwrap();
    assert!(matches!(cte_body.plan(), RelOperator::Aggregate(_)));
}
439+
440+
#[test]
fn test_analyze_common_subexpression_matches_identical_group_aggregates() {
    let mut metadata = Metadata::default();
    let table = fake_fuse_table(1, "t1");

    // Register the same table twice, once for each side of the cross join.
    let left_table = add_table(&mut metadata, table.clone());
    let right_table = add_table(&mut metadata, table);

    // Both sides are a grouped `max` aggregate. The output symbols are two
    // apart because each side also binds its group key to `output_index + 1`.
    let root = cross_join_expr(
        max_aggregate_expr(&metadata, left_table, Symbol::new(10), true),
        max_aggregate_expr(&metadata, right_table, Symbol::new(12), true),
    );

    let (replacements, materialized_ctes) =
        analyze_common_subexpression(&root, &mut metadata).unwrap();

    // Two replacements are expected, backed by one CTE whose child is the aggregate.
    assert_eq!(replacements.len(), 2);
    assert_eq!(materialized_ctes.len(), 1);

    let cte_body = materialized_ctes[0].child(0).unwrap();
    assert!(matches!(cte_body.plan(), RelOperator::Aggregate(_)));
}
462+
463+
#[test]
fn test_analyze_common_subexpression_does_not_materialize_eval_scalar_subtree() {
    let mut metadata = Metadata::default();
    let first = fake_fuse_table(1, "t1");
    let second = fake_fuse_table(2, "t2");

    // Registration order matters for the returned indexes: t1, t2, t1, t2.
    let t1_left = add_table(&mut metadata, first.clone());
    let t2_left = add_table(&mut metadata, second.clone());
    let t1_right = add_table(&mut metadata, first);
    let t2_right = add_table(&mut metadata, second);

    // Each side is `EvalScalar(t1 x t2)`; the two sides add different constants (1 vs 2).
    let left_join = cross_join_expr(scan_expr(&metadata, t1_left), scan_expr(&metadata, t2_left));
    let right_join = cross_join_expr(
        scan_expr(&metadata, t1_right),
        scan_expr(&metadata, t2_right),
    );
    let left_side = eval_scalar_expr(&metadata, left_join, t1_left, Symbol::new(20), 1);
    let right_side = eval_scalar_expr(&metadata, right_join, t1_right, Symbol::new(21), 2);
    let root = cross_join_expr(left_side, right_side);

    let (_replacements, materialized_ctes) =
        analyze_common_subexpression(&root, &mut metadata).unwrap();

    // No materialized CTE body may contain an `EvalScalar` anywhere in its tree.
    let any_cte_has_eval_scalar = materialized_ctes
        .iter()
        .any(|cte| contains_eval_scalar(cte.child(0).unwrap()));
    assert!(!any_cte_has_eval_scalar);
}
493+
494+
fn contains_eval_scalar(expr: &SExpr) -> bool {
495+
matches!(expr.plan(), RelOperator::EvalScalar(_))
496+
|| expr.children().any(contains_eval_scalar)
497+
}
316498
}

0 commit comments

Comments
 (0)