@@ -156,21 +156,36 @@ mod tests {
156156 use std:: any:: Any ;
157157
158158 use databend_common_catalog:: table:: Table ;
159+ use databend_common_expression:: Scalar ;
159160 use databend_common_expression:: TableDataType ;
160161 use databend_common_expression:: TableField ;
161162 use databend_common_expression:: TableSchema ;
163+ use databend_common_expression:: types:: DataType ;
162164 use databend_common_expression:: types:: NumberDataType ;
165+ use databend_common_expression:: types:: NumberScalar ;
163166 use databend_common_meta_app:: schema:: CatalogInfo ;
164167 use databend_common_meta_app:: schema:: DatabaseType ;
165168 use databend_common_meta_app:: schema:: TableIdent ;
166169 use databend_common_meta_app:: schema:: TableInfo ;
167170 use databend_common_meta_app:: schema:: TableMeta ;
168171
169172 use super :: * ;
173+ use crate :: ColumnBindingBuilder ;
174+ use crate :: Symbol ;
175+ use crate :: Visibility ;
170176 use crate :: planner:: metadata:: Metadata ;
177+ use crate :: plans:: Aggregate ;
178+ use crate :: plans:: AggregateFunction ;
179+ use crate :: plans:: AggregateMode ;
180+ use crate :: plans:: BoundColumnRef ;
181+ use crate :: plans:: ConstantExpr ;
182+ use crate :: plans:: EvalScalar ;
183+ use crate :: plans:: FunctionCall ;
171184 use crate :: plans:: Join ;
172185 use crate :: plans:: JoinType ;
173186 use crate :: plans:: RelOperator ;
187+ use crate :: plans:: ScalarExpr ;
188+ use crate :: plans:: ScalarItem ;
174189 use crate :: plans:: Scan ;
175190
176191 #[ derive( Debug ) ]
@@ -240,6 +255,62 @@ mod tests {
240255 } ) ) )
241256 }
242257
258+ fn column_expr ( metadata : & Metadata , table_index : usize ) -> ScalarExpr {
259+ let column = metadata. columns_by_table_index ( table_index) [ 0 ] . clone ( ) ;
260+ BoundColumnRef {
261+ span : None ,
262+ column : ColumnBindingBuilder :: new (
263+ column. name ( ) ,
264+ column. index ( ) ,
265+ Box :: new ( column. data_type ( ) ) ,
266+ Visibility :: Visible ,
267+ )
268+ . table_index ( Some ( table_index) )
269+ . build ( ) ,
270+ }
271+ . into ( )
272+ }
273+
274+ fn max_aggregate_expr (
275+ metadata : & Metadata ,
276+ table_index : usize ,
277+ output_index : Symbol ,
278+ with_group_by : bool ,
279+ ) -> SExpr {
280+ let group_items = if with_group_by {
281+ vec ! [ ScalarItem {
282+ scalar: column_expr( metadata, table_index) ,
283+ index: Symbol :: new( output_index. as_usize( ) + 1 ) ,
284+ } ]
285+ } else {
286+ vec ! [ ]
287+ } ;
288+
289+ SExpr :: create_unary (
290+ Arc :: new ( RelOperator :: Aggregate ( Aggregate {
291+ mode : AggregateMode :: Initial ,
292+ group_items,
293+ aggregate_functions : vec ! [ ScalarItem {
294+ scalar: ScalarExpr :: AggregateFunction ( AggregateFunction {
295+ span: None ,
296+ func_name: "max" . to_string( ) ,
297+ distinct: false ,
298+ params: vec![ ] ,
299+ args: vec![ column_expr( metadata, table_index) ] ,
300+ return_type: Box :: new( DataType :: Number ( NumberDataType :: UInt64 ) ) ,
301+ sort_descs: vec![ ] ,
302+ display_name: "max(a)" . to_string( ) ,
303+ } ) ,
304+ index: output_index,
305+ } ] ,
306+ from_distinct : false ,
307+ rank_limit : None ,
308+ grouping_sets : None ,
309+ } ) ) ,
310+ Arc :: new ( scan_expr ( metadata, table_index) ) ,
311+ )
312+ }
313+
243314 fn cross_join_expr ( left : SExpr , right : SExpr ) -> SExpr {
244315 SExpr :: create_binary (
245316 Arc :: new ( RelOperator :: Join ( Join {
@@ -251,6 +322,35 @@ mod tests {
251322 )
252323 }
253324
325+ fn eval_scalar_expr (
326+ metadata : & Metadata ,
327+ input : SExpr ,
328+ table_index : usize ,
329+ output_index : Symbol ,
330+ value : u64 ,
331+ ) -> SExpr {
332+ SExpr :: create_unary (
333+ Arc :: new ( RelOperator :: EvalScalar ( EvalScalar {
334+ items : vec ! [ ScalarItem {
335+ scalar: ScalarExpr :: FunctionCall ( FunctionCall {
336+ span: None ,
337+ func_name: "plus" . to_string( ) ,
338+ params: vec![ ] ,
339+ arguments: vec![
340+ column_expr( metadata, table_index) ,
341+ ScalarExpr :: ConstantExpr ( ConstantExpr {
342+ span: None ,
343+ value: Scalar :: Number ( NumberScalar :: UInt64 ( value) ) ,
344+ } ) ,
345+ ] ,
346+ } ) ,
347+ index: output_index,
348+ } ] ,
349+ } ) ) ,
350+ Arc :: new ( input) ,
351+ )
352+ }
353+
254354 #[ test]
255355 fn test_analyze_common_subexpression_prefers_cross_join_subtree ( ) {
256356 let mut metadata = Metadata :: default ( ) ;
@@ -313,4 +413,86 @@ mod tests {
313413 . all( |cte| matches!( cte. child( 0 ) . unwrap( ) . plan( ) , RelOperator :: Scan ( _) ) )
314414 ) ;
315415 }
416+
417+ #[ test]
418+ fn test_analyze_common_subexpression_matches_identical_aggregates ( ) {
419+ let mut metadata = Metadata :: default ( ) ;
420+ let t1 = fake_fuse_table ( 1 , "t1" ) ;
421+
422+ let t1_left = add_table ( & mut metadata, t1. clone ( ) ) ;
423+ let t1_right = add_table ( & mut metadata, t1) ;
424+
425+ let left = max_aggregate_expr ( & metadata, t1_left, Symbol :: new ( 10 ) , false ) ;
426+ let right = max_aggregate_expr ( & metadata, t1_right, Symbol :: new ( 11 ) , false ) ;
427+ let root = cross_join_expr ( left, right) ;
428+
429+ let ( replacements, materialized_ctes) =
430+ analyze_common_subexpression ( & root, & mut metadata) . unwrap ( ) ;
431+
432+ assert_eq ! ( replacements. len( ) , 2 ) ;
433+ assert_eq ! ( materialized_ctes. len( ) , 1 ) ;
434+ assert ! ( matches!(
435+ materialized_ctes[ 0 ] . child( 0 ) . unwrap( ) . plan( ) ,
436+ RelOperator :: Aggregate ( _)
437+ ) ) ;
438+ }
439+
440+ #[ test]
441+ fn test_analyze_common_subexpression_matches_identical_group_aggregates ( ) {
442+ let mut metadata = Metadata :: default ( ) ;
443+ let t1 = fake_fuse_table ( 1 , "t1" ) ;
444+
445+ let t1_left = add_table ( & mut metadata, t1. clone ( ) ) ;
446+ let t1_right = add_table ( & mut metadata, t1) ;
447+
448+ let left = max_aggregate_expr ( & metadata, t1_left, Symbol :: new ( 10 ) , true ) ;
449+ let right = max_aggregate_expr ( & metadata, t1_right, Symbol :: new ( 12 ) , true ) ;
450+ let root = cross_join_expr ( left, right) ;
451+
452+ let ( replacements, materialized_ctes) =
453+ analyze_common_subexpression ( & root, & mut metadata) . unwrap ( ) ;
454+
455+ assert_eq ! ( replacements. len( ) , 2 ) ;
456+ assert_eq ! ( materialized_ctes. len( ) , 1 ) ;
457+ assert ! ( matches!(
458+ materialized_ctes[ 0 ] . child( 0 ) . unwrap( ) . plan( ) ,
459+ RelOperator :: Aggregate ( _)
460+ ) ) ;
461+ }
462+
463+ #[ test]
464+ fn test_analyze_common_subexpression_does_not_materialize_eval_scalar_subtree ( ) {
465+ let mut metadata = Metadata :: default ( ) ;
466+ let t1 = fake_fuse_table ( 1 , "t1" ) ;
467+ let t2 = fake_fuse_table ( 2 , "t2" ) ;
468+
469+ let t1_left = add_table ( & mut metadata, t1. clone ( ) ) ;
470+ let t2_left = add_table ( & mut metadata, t2. clone ( ) ) ;
471+ let t1_right = add_table ( & mut metadata, t1) ;
472+ let t2_right = add_table ( & mut metadata, t2) ;
473+
474+ let left_input =
475+ cross_join_expr ( scan_expr ( & metadata, t1_left) , scan_expr ( & metadata, t2_left) ) ;
476+ let right_input = cross_join_expr (
477+ scan_expr ( & metadata, t1_right) ,
478+ scan_expr ( & metadata, t2_right) ,
479+ ) ;
480+ let left = eval_scalar_expr ( & metadata, left_input, t1_left, Symbol :: new ( 20 ) , 1 ) ;
481+ let right = eval_scalar_expr ( & metadata, right_input, t1_right, Symbol :: new ( 21 ) , 2 ) ;
482+ let root = cross_join_expr ( left, right) ;
483+
484+ let ( _replacements, materialized_ctes) =
485+ analyze_common_subexpression ( & root, & mut metadata) . unwrap ( ) ;
486+
487+ assert ! (
488+ materialized_ctes
489+ . iter( )
490+ . all( |cte| !contains_eval_scalar( cte. child( 0 ) . unwrap( ) ) )
491+ ) ;
492+ }
493+
494+ fn contains_eval_scalar ( expr : & SExpr ) -> bool {
495+ matches ! ( expr. plan( ) , RelOperator :: EvalScalar ( _) )
496+ || expr. children ( ) . any ( contains_eval_scalar)
497+ }
316498}
0 commit comments