diff --git a/src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs b/src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs index 7df0d50991f61..efc20558e4360 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/cse/analyze.rs @@ -156,10 +156,13 @@ mod tests { use std::any::Any; use databend_common_catalog::table::Table; + use databend_common_expression::Scalar; use databend_common_expression::TableDataType; use databend_common_expression::TableField; use databend_common_expression::TableSchema; + use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; + use databend_common_expression::types::NumberScalar; use databend_common_meta_app::schema::CatalogInfo; use databend_common_meta_app::schema::DatabaseType; use databend_common_meta_app::schema::TableIdent; @@ -167,10 +170,22 @@ mod tests { use databend_common_meta_app::schema::TableMeta; use super::*; + use crate::ColumnBindingBuilder; + use crate::Symbol; + use crate::Visibility; use crate::planner::metadata::Metadata; + use crate::plans::Aggregate; + use crate::plans::AggregateFunction; + use crate::plans::AggregateMode; + use crate::plans::BoundColumnRef; + use crate::plans::ConstantExpr; + use crate::plans::EvalScalar; + use crate::plans::FunctionCall; use crate::plans::Join; use crate::plans::JoinType; use crate::plans::RelOperator; + use crate::plans::ScalarExpr; + use crate::plans::ScalarItem; use crate::plans::Scan; #[derive(Debug)] @@ -240,6 +255,62 @@ mod tests { }))) } + fn column_expr(metadata: &Metadata, table_index: usize) -> ScalarExpr { + let column = metadata.columns_by_table_index(table_index)[0].clone(); + BoundColumnRef { + span: None, + column: ColumnBindingBuilder::new( + column.name(), + column.index(), + Box::new(column.data_type()), + Visibility::Visible, + ) + .table_index(Some(table_index)) + .build(), + } + .into() + } + + fn max_aggregate_expr( + metadata: &Metadata, + table_index: usize, + output_index: Symbol, + with_group_by: bool, + ) -> SExpr { + let group_items = if with_group_by { + vec![ScalarItem { + scalar: column_expr(metadata, table_index), + index: Symbol::new(output_index.as_usize() + 1), + }] + } else { + vec![] + }; + + SExpr::create_unary( + Arc::new(RelOperator::Aggregate(Aggregate { + mode: AggregateMode::Initial, + group_items, + aggregate_functions: vec![ScalarItem { + scalar: ScalarExpr::AggregateFunction(AggregateFunction { + span: None, + func_name: "max".to_string(), + distinct: false, + params: vec![], + args: vec![column_expr(metadata, table_index)], + return_type: Box::new(DataType::Number(NumberDataType::UInt64)), + sort_descs: vec![], + display_name: "max(a)".to_string(), + }), + index: output_index, + }], + from_distinct: false, + rank_limit: None, + grouping_sets: None, + })), + Arc::new(scan_expr(metadata, table_index)), + ) + } + fn cross_join_expr(left: SExpr, right: SExpr) -> SExpr { SExpr::create_binary( Arc::new(RelOperator::Join(Join { @@ -251,6 +322,35 @@ mod tests { ) } + fn eval_scalar_expr( + metadata: &Metadata, + input: SExpr, + table_index: usize, + output_index: Symbol, + value: u64, + ) -> SExpr { + SExpr::create_unary( + Arc::new(RelOperator::EvalScalar(EvalScalar { + items: vec![ScalarItem { + scalar: ScalarExpr::FunctionCall(FunctionCall { + span: None, + func_name: "plus".to_string(), + params: vec![], + arguments: vec![ + column_expr(metadata, table_index), + ScalarExpr::ConstantExpr(ConstantExpr { + span: None, + value: Scalar::Number(NumberScalar::UInt64(value)), + }), + ], + }), + index: output_index, + }], + })), + Arc::new(input), + ) + } + #[test] fn test_analyze_common_subexpression_prefers_cross_join_subtree() { let mut metadata = Metadata::default(); @@ -313,4 +413,86 @@ mod tests { .all(|cte| matches!(cte.child(0).unwrap().plan(), RelOperator::Scan(_))) ); } + + #[test] + fn test_analyze_common_subexpression_matches_identical_aggregates() { + let mut metadata = Metadata::default(); + let t1 = fake_fuse_table(1, "t1"); + + let t1_left = add_table(&mut metadata, t1.clone()); + let t1_right = add_table(&mut metadata, t1); + + let left = max_aggregate_expr(&metadata, t1_left, Symbol::new(10), false); + let right = max_aggregate_expr(&metadata, t1_right, Symbol::new(11), false); + let root = cross_join_expr(left, right); + + let (replacements, materialized_ctes) = + analyze_common_subexpression(&root, &mut metadata).unwrap(); + + assert_eq!(replacements.len(), 2); + assert_eq!(materialized_ctes.len(), 1); + assert!(matches!( + materialized_ctes[0].child(0).unwrap().plan(), + RelOperator::Aggregate(_) + )); + } + + #[test] + fn test_analyze_common_subexpression_matches_identical_group_aggregates() { + let mut metadata = Metadata::default(); + let t1 = fake_fuse_table(1, "t1"); + + let t1_left = add_table(&mut metadata, t1.clone()); + let t1_right = add_table(&mut metadata, t1); + + let left = max_aggregate_expr(&metadata, t1_left, Symbol::new(10), true); + let right = max_aggregate_expr(&metadata, t1_right, Symbol::new(12), true); + let root = cross_join_expr(left, right); + + let (replacements, materialized_ctes) = + analyze_common_subexpression(&root, &mut metadata).unwrap(); + + assert_eq!(replacements.len(), 2); + assert_eq!(materialized_ctes.len(), 1); + assert!(matches!( + materialized_ctes[0].child(0).unwrap().plan(), + RelOperator::Aggregate(_) + )); + } + + #[test] + fn test_analyze_common_subexpression_does_not_materialize_eval_scalar_subtree() { + let mut metadata = Metadata::default(); + let t1 = fake_fuse_table(1, "t1"); + let t2 = fake_fuse_table(2, "t2"); + + let t1_left = add_table(&mut metadata, t1.clone()); + let t2_left = add_table(&mut metadata, t2.clone()); + let t1_right = add_table(&mut metadata, t1); + let t2_right = add_table(&mut metadata, t2); + + let left_input = + cross_join_expr(scan_expr(&metadata, t1_left), scan_expr(&metadata, t2_left)); + let right_input = cross_join_expr( + scan_expr(&metadata, t1_right), + scan_expr(&metadata, t2_right), + ); + let left = eval_scalar_expr(&metadata, left_input, t1_left, Symbol::new(20), 1); + let right = eval_scalar_expr(&metadata, right_input, t1_right, Symbol::new(21), 2); + let root = cross_join_expr(left, right); + + let (_replacements, materialized_ctes) = + analyze_common_subexpression(&root, &mut metadata).unwrap(); + + assert!( + materialized_ctes + .iter() + .all(|cte| !contains_eval_scalar(cte.child(0).unwrap())) + ); + } + + fn contains_eval_scalar(expr: &SExpr) -> bool { + matches!(expr.plan(), RelOperator::EvalScalar(_)) + || expr.children().any(contains_eval_scalar) + } } diff --git a/src/query/sql/src/planner/optimizer/optimizers/cse/table_signature.rs b/src/query/sql/src/planner/optimizer/optimizers/cse/table_signature.rs index 6552d5399243c..8420a5e6671c4 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/cse/table_signature.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/cse/table_signature.rs @@ -14,18 +14,65 @@ use std::collections::HashMap; +use databend_common_expression::ColumnId; +use databend_common_expression::FunctionKind; +use databend_common_functions::BUILTIN_FUNCTIONS; + +use crate::ColumnBindingBuilder; use crate::ColumnEntry; use crate::IndexType; +use crate::ScalarExpr; +use crate::Symbol; +use crate::Visibility; use crate::optimizer::ir::SExpr; use crate::planner::metadata::Metadata; +use crate::plans::Aggregate; +use crate::plans::AsyncFunctionCall; +use crate::plans::FunctionCall; use crate::plans::Join; use crate::plans::JoinType; use crate::plans::RelOperator; +use crate::plans::ScalarItem; use crate::plans::Scan; +use crate::plans::UDAFCall; +use crate::plans::UDFCall; +use crate::plans::UDFLambdaCall; +use crate::plans::Visitor; +use crate::plans::VisitorMut; +use crate::plans::WindowFunc; +use crate::plans::walk_expr_mut; -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct TableSignature { - pub tables: Vec, + pub scans: Vec, + pub aggregate: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ScanSignature { + pub table_id: IndexType, + pub columns: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ColumnSignature { + Base { + column_id: ColumnId, + path_indices: Option>, + virtual_expr: Option, + }, + Virtual { + source_column_id: ColumnId, + column_id: ColumnId, + key_paths: String, + is_try: bool, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct AggregateSignature { + pub aggregate: Aggregate, + pub input_items: Vec, } pub fn collect_table_signatures( @@ -43,7 +90,7 @@ fn collect_table_signatures_rec( path: &mut Vec, metadata: &Metadata, signature_to_exprs: &mut HashMap, SExpr)>>, -) -> Option> { +) -> Option> { let mut child_tables = Vec::with_capacity(expr.arity()); for (child_index, child) in expr.children().enumerate() { path.push(child_index); @@ -58,11 +105,12 @@ fn collect_table_signatures_rec( match expr.plan.as_ref() { RelOperator::Scan(scan) => { - let table_id = scan_signature(scan, metadata)?; - let tables = vec![table_id]; + let scan = scan_signature(scan, metadata)?; + let tables = vec![scan]; signature_to_exprs .entry(TableSignature { - tables: tables.clone(), + scans: tables.clone(), + aggregate: None, }) .or_default() .push((path.clone(), expr.clone())); @@ -80,17 +128,271 @@ fn collect_table_signatures_rec( // signature and get remapped positionally later. signature_to_exprs .entry(TableSignature { - tables: tables.clone(), + scans: tables.clone(), + aggregate: None, }) .or_default() .push((path.clone(), expr.clone())); Some(tables) } + RelOperator::Aggregate(aggregate) if child_tables.len() == 1 => { + let tables = child_tables[0] + .clone() + .or_else(|| aggregate_input_tables(expr.child(0).ok()?, metadata))?; + if let Some(aggregate_signature) = aggregate_signature(aggregate, expr.child(0).ok()?) { + signature_to_exprs + .entry(TableSignature { + scans: tables.clone(), + aggregate: Some(aggregate_signature), + }) + .or_default() + .push((path.clone(), expr.clone())); + } + None + } _ => None, } } -fn scan_signature(scan: &Scan, metadata: &Metadata) -> Option { +fn aggregate_input_tables(expr: &SExpr, metadata: &Metadata) -> Option> { + match expr.plan() { + RelOperator::EvalScalar(_) if expr.arity() == 1 => { + aggregate_input_tables_without_eval_scalar(expr.child(0).ok()?, metadata) + } + _ => aggregate_input_tables_without_eval_scalar(expr, metadata), + } +} + +fn aggregate_input_tables_without_eval_scalar( + expr: &SExpr, + metadata: &Metadata, +) -> Option> { + match expr.plan() { + RelOperator::Scan(scan) => Some(vec![scan_signature(scan, metadata)?]), + RelOperator::Join(join) if is_supported_cross_join(join) && expr.arity() == 2 => { + let mut tables = + aggregate_input_tables_without_eval_scalar(expr.child(0).ok()?, metadata)?; + tables.extend(aggregate_input_tables_without_eval_scalar( + expr.child(1).ok()?, + metadata, + )?); + Some(tables) + } + _ => None, + } +} + +fn aggregate_signature(aggregate: &Aggregate, input: &SExpr) -> Option { + if aggregate.rank_limit.is_some() || aggregate.grouping_sets.is_some() { + return None; + } + + let input_columns = input + .derive_relational_prop() + .ok()? + .output_columns + .iter() + .copied() + .enumerate() + .map(|(position, column)| (column, Symbol::new(position))) + .collect::>(); + + let mut aggregate = aggregate.clone(); + aggregate.group_items = normalize_scalar_items(&aggregate.group_items, &input_columns)?; + aggregate.aggregate_functions = + normalize_scalar_items(&aggregate.aggregate_functions, &input_columns)?; + let input_items = aggregate_input_items(input)?; + + if !scalar_items_are_deterministic(&aggregate.group_items) + || !scalar_items_are_deterministic(&aggregate.aggregate_functions) + || !scalar_items_are_deterministic(&input_items) + { + return None; + } + + Some(AggregateSignature { + aggregate, + input_items, + }) +} + +fn aggregate_input_items(input: &SExpr) -> Option> { + let RelOperator::EvalScalar(eval_scalar) = input.plan() else { + return Some(vec![]); + }; + + let input_columns = input + .derive_relational_prop() + .ok()? + .output_columns + .iter() + .copied() + .enumerate() + .map(|(position, column)| (column, Symbol::new(position))) + .collect::>(); + let child_columns = input + .child(0) + .ok()? + .derive_relational_prop() + .ok()? + .output_columns + .iter() + .copied() + .enumerate() + .map(|(position, column)| (column, Symbol::new(position))) + .collect::>(); + + eval_scalar + .items + .iter() + .map(|item| { + Some(ScalarItem { + scalar: normalize_scalar_expr(&item.scalar, &child_columns)?, + index: *input_columns.get(&item.index)?, + }) + }) + .collect() +} + +fn scalar_items_are_deterministic(items: &[ScalarItem]) -> bool { + items + .iter() + .all(|item| scalar_expr_is_deterministic(&item.scalar)) +} + +fn scalar_expr_is_deterministic(scalar: &ScalarExpr) -> bool { + let mut visitor = DeterministicVisitor { + deterministic: true, + }; + visitor.visit(scalar).is_ok() && visitor.deterministic +} + +struct DeterministicVisitor { + deterministic: bool, +} + +impl<'a> Visitor<'a> for DeterministicVisitor { + fn visit_function_call( + &mut self, + func: &'a FunctionCall, + ) -> databend_common_exception::Result<()> { + if BUILTIN_FUNCTIONS + .get_property(&func.func_name) + .map(|property| property.non_deterministic || property.kind == FunctionKind::SRF) + .unwrap_or(true) + { + self.deterministic = false; + return Ok(()); + } + + for expr in &func.arguments { + self.visit(expr)?; + } + Ok(()) + } + + fn visit_window_function( + &mut self, + _window: &'a WindowFunc, + ) -> databend_common_exception::Result<()> { + self.deterministic = false; + Ok(()) + } + + fn visit_udf_call(&mut self, _udf: &'a UDFCall) -> databend_common_exception::Result<()> { + self.deterministic = false; + Ok(()) + } + + fn visit_udf_lambda_call( + &mut self, + _udf: &'a UDFLambdaCall, + ) -> databend_common_exception::Result<()> { + self.deterministic = false; + Ok(()) + } + + fn visit_udaf_call(&mut self, _udaf: &'a UDAFCall) -> databend_common_exception::Result<()> { + self.deterministic = false; + Ok(()) + } + + fn visit_async_function_call( + &mut self, + _async_func: &'a AsyncFunctionCall, + ) -> databend_common_exception::Result<()> { + self.deterministic = false; + Ok(()) + } +} + +fn normalize_scalar_items( + items: &[ScalarItem], + input_columns: &HashMap, +) -> Option> { + items + .iter() + .enumerate() + .map(|(position, item)| { + Some(ScalarItem { + scalar: normalize_scalar_expr(&item.scalar, input_columns)?, + index: Symbol::new(position), + }) + }) + .collect() +} + +fn normalize_scalar_expr( + scalar: &ScalarExpr, + input_columns: &HashMap, +) -> Option { + let mut scalar = scalar.clone(); + let mut visitor = NormalizeColumnVisitor { input_columns }; + visitor.visit(&mut scalar).ok()?; + Some(scalar) +} + +struct NormalizeColumnVisitor<'a> { + input_columns: &'a HashMap, +} + +impl VisitorMut<'_> for NormalizeColumnVisitor<'_> { + fn visit(&mut self, expr: &mut ScalarExpr) -> databend_common_exception::Result<()> { + walk_expr_mut(self, expr) + } + + fn visit_bound_column_ref( + &mut self, + col: &mut crate::plans::BoundColumnRef, + ) -> databend_common_exception::Result<()> { + let Some(normalized) = self.input_columns.get(&col.column.index) else { + return Err(databend_common_exception::ErrorCode::Internal( + "aggregate CSE column is not produced by input", + )); + }; + col.column = ColumnBindingBuilder::new( + normalized.to_string(), + *normalized, + col.column.data_type.clone(), + Visibility::Visible, + ) + .build(); + Ok(()) + } + + fn visit_aggregate_function( + &mut self, + aggregate: &mut crate::plans::AggregateFunction, + ) -> databend_common_exception::Result<()> { + aggregate.display_name.clear(); + for expr in aggregate.exprs_mut() { + self.visit(expr)?; + } + Ok(()) + } +} + +fn scan_signature(scan: &Scan, metadata: &Metadata) -> Option { let has_internal_column = scan.columns.iter().any(|column_index| { let column = metadata.column(*column_index); matches!(column, ColumnEntry::InternalColumn(_)) @@ -115,7 +417,34 @@ fn scan_signature(scan: &Scan, metadata: &Metadata) -> Option { return None; } - Some(table.get_id() as IndexType) + let mut columns = scan + .columns + .iter() + .map(|column_index| column_signature(metadata.column(*column_index))) + .collect::>>()?; + columns.sort(); + + Some(ScanSignature { + table_id: table.get_id() as IndexType, + columns, + }) +} + +fn column_signature(column: &ColumnEntry) -> Option { + match column { + ColumnEntry::BaseTableColumn(base) => Some(ColumnSignature::Base { + column_id: base.column_id, + path_indices: base.path_indices.clone(), + virtual_expr: base.virtual_expr.clone(), + }), + ColumnEntry::VirtualColumn(virtual_column) => Some(ColumnSignature::Virtual { + source_column_id: virtual_column.source_column_id, + column_id: virtual_column.column_id, + key_paths: format!("{:?}", virtual_column.key_paths), + is_try: virtual_column.is_try, + }), + ColumnEntry::InternalColumn(_) | ColumnEntry::DerivedColumn(_) => None, + } } fn is_supported_cross_join(join: &Join) -> bool { diff --git a/tests/sqllogictests/suites/mode/standalone/explain/common_subexpression_optimizer.test b/tests/sqllogictests/suites/mode/standalone/explain/common_subexpression_optimizer.test index 9a5f3d0994a98..b314e74426c30 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain/common_subexpression_optimizer.test +++ b/tests/sqllogictests/suites/mode/standalone/explain/common_subexpression_optimizer.test @@ -677,3 +677,96 @@ drop table time_dim; statement ok drop table store; + +statement ok +create or replace table cse_agg_t as +select number as a, number % 3 as b +from numbers(6); + +query T nosort +explain select l.b, l.max_a, r.max_a +from (select b, max(a) as max_a from cse_agg_t group by b) l +join (select b, max(a) as max_a from cse_agg_t group by b) r +on l.b = r.b; +---- +Sequence +├── MaterializedCTE: cte_cse_0 +│ └── AggregateFinal +│ ├── output columns: [max(a) (#2), cse_agg_t.b (#1)] +│ ├── group by: [b] +│ ├── aggregate functions: [max(a)] +│ ├── estimated rows: 3.00 +│ └── AggregatePartial +│ ├── group by: [b] +│ ├── aggregate functions: [max(a)] +│ ├── estimated rows: 3.00 +│ └── TableScan +│ ├── table: default.default.cse_agg_t +│ ├── scan id: 0 +│ ├── output columns: [a (#0), b (#1)] +│ ├── read rows: 6 +│ ├── read size: < 1 KiB +│ ├── partitions total: 1 +│ ├── partitions scanned: 1 +│ ├── pruning stats: [segments: >, blocks: >] +│ ├── push downs: [filters: [], limit: NONE] +│ └── estimated rows: 6.00 +└── HashJoin + ├── output columns: [cse_agg_t.b (#1), max(a) (#2), max(a) (#5)] + ├── join type: INNER + ├── build keys: [r.b (#4)] + ├── probe keys: [l.b (#1)] + ├── keys is null equal: [false] + ├── filters: [] + ├── build join filters: + │ └── filter id:0, build key:r.b (#4), probe targets:[l.b (#1)@scan0], filter type:bloom,inlist,min_max + ├── estimated rows: 9.00 + ├── MaterializeCTERef(Build) + │ ├── cte_name: cte_cse_0 + │ ├── cte_schema: [b (#4), max(a) (#5)] + │ └── estimated rows: 3.00 + └── MaterializeCTERef(Probe) + ├── cte_name: cte_cse_0 + ├── cte_schema: [b (#1), max(a) (#2)] + └── estimated rows: 3.00 + +query III rowsort +select l.b, l.max_a, r.max_a +from (select b, max(a) as max_a from cse_agg_t group by b) l +join (select b, max(a) as max_a from cse_agg_t group by b) r +on l.b = r.b; +---- +0 3 3 +1 4 4 +2 5 5 + +query III rowsort +select l.b, l.max_a_plus_1, r.max_a_plus_2 +from (select b, max(a + 1) as max_a_plus_1 from cse_agg_t group by b) l +join (select b, max(a + 2) as max_a_plus_2 from cse_agg_t group by b) r +on l.b = r.b; +---- +0 4 5 +1 5 6 +2 6 7 + +query III rowsort +select l.b, l.max_x_plus_0, r.max_x_plus_0 +from (select b, max(x + 0) as max_x_plus_0 from (select b, a + 1 as x from cse_agg_t) group by b) l +join (select b, max(x + 0) as max_x_plus_0 from (select b, a + 2 as x from cse_agg_t) group by b) r +on l.b = r.b; +---- +0 4 5 +1 5 6 +2 6 7 + +query II rowsort +select l.max_a, r.max_b +from (select max(a) as max_a from cse_agg_t) l +cross join (select max(b) as max_b from cse_agg_t) r; +---- +5 2 + + +statement ok +drop table cse_agg_t;