@@ -353,18 +353,15 @@ impl ScalarExpression {
353353 if !self . only_column_ref {
354354 self . columns . push ( expr. output_column ( ) ) ;
355355 }
356- walk_expr ( self , expr)
356+ walk_expr ( self , expr)
357357 }
358-
359358 fn visit_column_ref ( & mut self , col : & ColumnRef ) -> Result < ( ) , DatabaseError > {
360359 if self . only_column_ref {
361360 self . columns . push ( col. clone ( ) ) ;
362-
363361 }
364362 Ok ( ( ) )
365363 }
366364 }
367-
368365 let mut collector = ColumnCollector {
369366 columns : Vec :: new ( ) ,
370367 only_column_ref,
@@ -374,213 +371,51 @@ impl ScalarExpression {
374371 }
375372
376373 pub fn has_table_ref_column ( & self ) -> bool {
377- match self {
378- ScalarExpression :: Constant ( _) => false ,
379- ScalarExpression :: ColumnRef ( column) => {
380- column. table_name ( ) . is_some ( ) && column. id ( ) . is_some ( )
381- }
382- ScalarExpression :: Alias { expr, .. } => expr. has_table_ref_column ( ) ,
383- ScalarExpression :: TypeCast { expr, .. } | ScalarExpression :: IsNull { expr, .. } => {
384- expr. has_table_ref_column ( )
385- }
386- ScalarExpression :: Unary { expr, .. } => expr. has_table_ref_column ( ) ,
387- ScalarExpression :: Binary {
388- left_expr,
389- right_expr,
390- ..
391- } => left_expr. has_table_ref_column ( ) || right_expr. has_table_ref_column ( ) ,
392- ScalarExpression :: AggCall { args, .. } => {
393- args. iter ( ) . any ( ScalarExpression :: has_table_ref_column)
394- }
395- ScalarExpression :: In { expr, args, .. } => {
396- expr. has_table_ref_column ( )
397- || args. iter ( ) . any ( ScalarExpression :: has_table_ref_column)
398- }
399- ScalarExpression :: Between {
400- expr,
401- left_expr,
402- right_expr,
403- ..
404- } => {
405- expr. has_table_ref_column ( )
406- || left_expr. has_table_ref_column ( )
407- || right_expr. has_table_ref_column ( )
408- }
409- ScalarExpression :: SubString {
410- expr,
411- for_expr,
412- from_expr,
413- } => {
414- expr. has_table_ref_column ( )
415- || for_expr
416- . as_deref ( )
417- . map ( ScalarExpression :: has_table_ref_column)
418- . unwrap_or ( false )
419- || from_expr
420- . as_deref ( )
421- . map ( ScalarExpression :: has_table_ref_column)
422- . unwrap_or ( false )
423- }
424- ScalarExpression :: Position { expr, in_expr } => {
425- expr. has_table_ref_column ( ) || in_expr. has_table_ref_column ( )
426- }
427- ScalarExpression :: Trim {
428- expr,
429- trim_what_expr,
430- ..
431- } => {
432- expr. has_table_ref_column ( )
433- || trim_what_expr
434- . as_deref ( )
435- . map ( ScalarExpression :: has_table_ref_column)
436- . unwrap_or ( false )
437- }
438- ScalarExpression :: Empty => false ,
439- ScalarExpression :: Reference { expr, .. } => expr. has_table_ref_column ( ) ,
440- ScalarExpression :: Tuple ( exprs) => {
441- exprs. iter ( ) . any ( ScalarExpression :: has_table_ref_column)
442- }
443- ScalarExpression :: ScalaFunction ( function) => function
444- . args
445- . iter ( )
446- . any ( ScalarExpression :: has_table_ref_column) ,
447- ScalarExpression :: TableFunction ( function) => function
448- . args
449- . iter ( )
450- . any ( ScalarExpression :: has_table_ref_column) ,
451- ScalarExpression :: If {
452- condition,
453- left_expr,
454- right_expr,
455- ..
456- } => {
457- condition. has_table_ref_column ( )
458- || left_expr. has_table_ref_column ( )
459- || right_expr. has_table_ref_column ( )
460- }
461- ScalarExpression :: IfNull {
462- left_expr,
463- right_expr,
464- ..
465- } => left_expr. has_table_ref_column ( ) || right_expr. has_table_ref_column ( ) ,
466- ScalarExpression :: NullIf {
467- left_expr,
468- right_expr,
469- ..
470- } => left_expr. has_table_ref_column ( ) || right_expr. has_table_ref_column ( ) ,
471- ScalarExpression :: Coalesce { exprs, .. } => {
472- exprs. iter ( ) . any ( ScalarExpression :: has_table_ref_column)
473- }
474- ScalarExpression :: CaseWhen {
475- operand_expr,
476- expr_pairs,
477- else_expr,
478- ..
479- } => {
480- operand_expr
481- . as_deref ( )
482- . map ( ScalarExpression :: has_table_ref_column)
483- . unwrap_or ( false )
484- || else_expr
485- . as_deref ( )
486- . map ( ScalarExpression :: has_table_ref_column)
487- . unwrap_or ( false )
488- || expr_pairs. iter ( ) . any ( |( left_expr, right_expr) | {
489- left_expr. has_table_ref_column ( ) || right_expr. has_table_ref_column ( )
490- } )
374+ struct TableRefChecker {
375+ found : bool ,
376+ }
377+ impl < ' a > Visitor < ' a > for TableRefChecker {
378+ fn visit_column_ref ( & mut self , col : & ColumnRef ) -> Result < ( ) , DatabaseError > {
379+ if col. table_name ( ) . is_some ( ) && col. id ( ) . is_some ( ) {
380+ self . found = true ;
381+ }
382+ Ok ( ( ) )
491383 }
492384 }
385+ let mut checker = TableRefChecker { found : false } ;
386+ checker. visit ( self ) . unwrap ( ) ;
387+ checker. found
493388 }
494389
390+
495391 pub fn has_agg_call ( & self ) -> bool {
496- match self {
497- ScalarExpression :: AggCall { .. } => true ,
498- ScalarExpression :: Constant ( _) => false ,
499- ScalarExpression :: ColumnRef ( _) => false ,
500- ScalarExpression :: Alias { expr, .. } => expr. has_agg_call ( ) ,
501- ScalarExpression :: TypeCast { expr, .. } => expr. has_agg_call ( ) ,
502- ScalarExpression :: IsNull { expr, .. } => expr. has_agg_call ( ) ,
503- ScalarExpression :: Unary { expr, .. } => expr. has_agg_call ( ) ,
504- ScalarExpression :: Binary {
505- left_expr,
506- right_expr,
507- ..
508- } => left_expr. has_agg_call ( ) || right_expr. has_agg_call ( ) ,
509- ScalarExpression :: In { expr, args, .. } => {
510- expr. has_agg_call ( ) || args. iter ( ) . any ( |arg| arg. has_agg_call ( ) )
511- }
512- ScalarExpression :: Between {
513- expr,
514- left_expr,
515- right_expr,
516- ..
517- } => expr. has_agg_call ( ) || left_expr. has_agg_call ( ) || right_expr. has_agg_call ( ) ,
518- ScalarExpression :: SubString {
519- expr,
520- for_expr,
521- from_expr,
522- } => {
523- expr. has_agg_call ( )
524- || matches ! (
525- for_expr. as_ref( ) . map( |expr| expr. has_agg_call( ) ) ,
526- Some ( true )
527- )
528- || matches ! (
529- from_expr. as_ref( ) . map( |expr| expr. has_agg_call( ) ) ,
530- Some ( true )
531- )
532- }
533- ScalarExpression :: Position { expr, in_expr } => {
534- expr. has_agg_call ( ) || in_expr. has_agg_call ( )
535- }
536- ScalarExpression :: Trim {
537- expr,
538- trim_what_expr,
539- ..
540- } => {
541- expr. has_agg_call ( )
542- || trim_what_expr. as_ref ( ) . map ( |expr| expr. has_agg_call ( ) ) == Some ( true )
543- }
544- ScalarExpression :: Reference { .. }
545- | ScalarExpression :: Empty
546- | ScalarExpression :: TableFunction ( _) => unreachable ! ( ) ,
547- ScalarExpression :: Tuple ( args)
548- | ScalarExpression :: ScalaFunction ( ScalarFunction { args, .. } )
549- | ScalarExpression :: Coalesce { exprs : args, .. } => args. iter ( ) . any ( Self :: has_agg_call) ,
550- ScalarExpression :: If {
551- condition,
552- left_expr,
553- right_expr,
554- ..
555- } => condition. has_agg_call ( ) || left_expr. has_agg_call ( ) || right_expr. has_agg_call ( ) ,
556- ScalarExpression :: IfNull {
557- left_expr,
558- right_expr,
559- ..
392+ struct AggCallChecker {
393+ has_agg : bool ,
394+ }
395+ impl < ' a > Visitor < ' a > for AggCallChecker {
396+ fn visit ( & mut self , expr : & ' a ScalarExpression ) -> Result < ( ) , DatabaseError > {
397+ if self . has_agg {
398+ return Ok ( ( ) ) ;
399+ }
400+ walk_expr ( self , expr)
560401 }
561- | ScalarExpression :: NullIf {
562- left_expr,
563- right_expr,
564- ..
565- } => left_expr. has_agg_call ( ) || right_expr. has_agg_call ( ) ,
566- ScalarExpression :: CaseWhen {
567- operand_expr,
568- expr_pairs,
569- else_expr,
570- ..
571- } => {
572- matches ! (
573- operand_expr. as_ref( ) . map( |expr| expr. has_agg_call( ) ) ,
574- Some ( true )
575- ) || expr_pairs
576- . iter ( )
577- . any ( |( expr_1, expr_2) | expr_1. has_agg_call ( ) || expr_2. has_agg_call ( ) )
578- || matches ! (
579- else_expr. as_ref( ) . map( |expr| expr. has_agg_call( ) ) ,
580- Some ( true )
581- )
402+ fn visit_agg ( & mut self ,
403+ _distinct : bool ,
404+ _kind : & ' a AggKind ,
405+ args : & ' a [ ScalarExpression ] ,
406+ _ty : & ' a LogicalType ) -> Result < ( ) , DatabaseError > {
407+ for arg in args {
408+ self . visit ( arg) ?;
409+ }
410+ self . has_agg = true ;
411+ Ok ( ( ) )
582412 }
413+
414+
583415 }
416+ let mut checker = AggCallChecker { has_agg : false } ;
417+ checker. visit ( self ) . unwrap ( ) ;
418+ checker. has_agg
584419 }
585420
586421 pub fn output_name ( & self ) -> String {
0 commit comments