Skip to content

Commit 04e44f0

Browse files
committed
reimplement referenced_columns & has_table_ref_column & has_agg_call using Visitor, solving issue #273
1 parent 18acb70 commit 04e44f0

1 file changed

Lines changed: 38 additions & 203 deletions

File tree

src/expression/mod.rs

Lines changed: 38 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -353,18 +353,15 @@ impl ScalarExpression {
353353
if !self.only_column_ref {
354354
self.columns.push(expr.output_column());
355355
}
356-
walk_expr(self, expr)
356+
walk_expr(self, expr)
357357
}
358-
359358
fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> {
360359
if self.only_column_ref {
361360
self.columns.push(col.clone());
362-
363361
}
364362
Ok(())
365363
}
366364
}
367-
368365
let mut collector = ColumnCollector {
369366
columns: Vec::new(),
370367
only_column_ref,
@@ -374,213 +371,51 @@ impl ScalarExpression {
374371
}
375372

376373
pub fn has_table_ref_column(&self) -> bool {
377-
match self {
378-
ScalarExpression::Constant(_) => false,
379-
ScalarExpression::ColumnRef(column) => {
380-
column.table_name().is_some() && column.id().is_some()
381-
}
382-
ScalarExpression::Alias { expr, .. } => expr.has_table_ref_column(),
383-
ScalarExpression::TypeCast { expr, .. } | ScalarExpression::IsNull { expr, .. } => {
384-
expr.has_table_ref_column()
385-
}
386-
ScalarExpression::Unary { expr, .. } => expr.has_table_ref_column(),
387-
ScalarExpression::Binary {
388-
left_expr,
389-
right_expr,
390-
..
391-
} => left_expr.has_table_ref_column() || right_expr.has_table_ref_column(),
392-
ScalarExpression::AggCall { args, .. } => {
393-
args.iter().any(ScalarExpression::has_table_ref_column)
394-
}
395-
ScalarExpression::In { expr, args, .. } => {
396-
expr.has_table_ref_column()
397-
|| args.iter().any(ScalarExpression::has_table_ref_column)
398-
}
399-
ScalarExpression::Between {
400-
expr,
401-
left_expr,
402-
right_expr,
403-
..
404-
} => {
405-
expr.has_table_ref_column()
406-
|| left_expr.has_table_ref_column()
407-
|| right_expr.has_table_ref_column()
408-
}
409-
ScalarExpression::SubString {
410-
expr,
411-
for_expr,
412-
from_expr,
413-
} => {
414-
expr.has_table_ref_column()
415-
|| for_expr
416-
.as_deref()
417-
.map(ScalarExpression::has_table_ref_column)
418-
.unwrap_or(false)
419-
|| from_expr
420-
.as_deref()
421-
.map(ScalarExpression::has_table_ref_column)
422-
.unwrap_or(false)
423-
}
424-
ScalarExpression::Position { expr, in_expr } => {
425-
expr.has_table_ref_column() || in_expr.has_table_ref_column()
426-
}
427-
ScalarExpression::Trim {
428-
expr,
429-
trim_what_expr,
430-
..
431-
} => {
432-
expr.has_table_ref_column()
433-
|| trim_what_expr
434-
.as_deref()
435-
.map(ScalarExpression::has_table_ref_column)
436-
.unwrap_or(false)
437-
}
438-
ScalarExpression::Empty => false,
439-
ScalarExpression::Reference { expr, .. } => expr.has_table_ref_column(),
440-
ScalarExpression::Tuple(exprs) => {
441-
exprs.iter().any(ScalarExpression::has_table_ref_column)
442-
}
443-
ScalarExpression::ScalaFunction(function) => function
444-
.args
445-
.iter()
446-
.any(ScalarExpression::has_table_ref_column),
447-
ScalarExpression::TableFunction(function) => function
448-
.args
449-
.iter()
450-
.any(ScalarExpression::has_table_ref_column),
451-
ScalarExpression::If {
452-
condition,
453-
left_expr,
454-
right_expr,
455-
..
456-
} => {
457-
condition.has_table_ref_column()
458-
|| left_expr.has_table_ref_column()
459-
|| right_expr.has_table_ref_column()
460-
}
461-
ScalarExpression::IfNull {
462-
left_expr,
463-
right_expr,
464-
..
465-
} => left_expr.has_table_ref_column() || right_expr.has_table_ref_column(),
466-
ScalarExpression::NullIf {
467-
left_expr,
468-
right_expr,
469-
..
470-
} => left_expr.has_table_ref_column() || right_expr.has_table_ref_column(),
471-
ScalarExpression::Coalesce { exprs, .. } => {
472-
exprs.iter().any(ScalarExpression::has_table_ref_column)
473-
}
474-
ScalarExpression::CaseWhen {
475-
operand_expr,
476-
expr_pairs,
477-
else_expr,
478-
..
479-
} => {
480-
operand_expr
481-
.as_deref()
482-
.map(ScalarExpression::has_table_ref_column)
483-
.unwrap_or(false)
484-
|| else_expr
485-
.as_deref()
486-
.map(ScalarExpression::has_table_ref_column)
487-
.unwrap_or(false)
488-
|| expr_pairs.iter().any(|(left_expr, right_expr)| {
489-
left_expr.has_table_ref_column() || right_expr.has_table_ref_column()
490-
})
374+
struct TableRefChecker {
375+
found: bool,
376+
}
377+
impl<'a> Visitor<'a> for TableRefChecker {
378+
fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> {
379+
if col.table_name().is_some() && col.id().is_some() {
380+
self.found = true;
381+
}
382+
Ok(())
491383
}
492384
}
385+
let mut checker = TableRefChecker { found: false };
386+
checker.visit(self).unwrap();
387+
checker.found
493388
}
494389

390+
495391
pub fn has_agg_call(&self) -> bool {
496-
match self {
497-
ScalarExpression::AggCall { .. } => true,
498-
ScalarExpression::Constant(_) => false,
499-
ScalarExpression::ColumnRef(_) => false,
500-
ScalarExpression::Alias { expr, .. } => expr.has_agg_call(),
501-
ScalarExpression::TypeCast { expr, .. } => expr.has_agg_call(),
502-
ScalarExpression::IsNull { expr, .. } => expr.has_agg_call(),
503-
ScalarExpression::Unary { expr, .. } => expr.has_agg_call(),
504-
ScalarExpression::Binary {
505-
left_expr,
506-
right_expr,
507-
..
508-
} => left_expr.has_agg_call() || right_expr.has_agg_call(),
509-
ScalarExpression::In { expr, args, .. } => {
510-
expr.has_agg_call() || args.iter().any(|arg| arg.has_agg_call())
511-
}
512-
ScalarExpression::Between {
513-
expr,
514-
left_expr,
515-
right_expr,
516-
..
517-
} => expr.has_agg_call() || left_expr.has_agg_call() || right_expr.has_agg_call(),
518-
ScalarExpression::SubString {
519-
expr,
520-
for_expr,
521-
from_expr,
522-
} => {
523-
expr.has_agg_call()
524-
|| matches!(
525-
for_expr.as_ref().map(|expr| expr.has_agg_call()),
526-
Some(true)
527-
)
528-
|| matches!(
529-
from_expr.as_ref().map(|expr| expr.has_agg_call()),
530-
Some(true)
531-
)
532-
}
533-
ScalarExpression::Position { expr, in_expr } => {
534-
expr.has_agg_call() || in_expr.has_agg_call()
535-
}
536-
ScalarExpression::Trim {
537-
expr,
538-
trim_what_expr,
539-
..
540-
} => {
541-
expr.has_agg_call()
542-
|| trim_what_expr.as_ref().map(|expr| expr.has_agg_call()) == Some(true)
543-
}
544-
ScalarExpression::Reference { .. }
545-
| ScalarExpression::Empty
546-
| ScalarExpression::TableFunction(_) => unreachable!(),
547-
ScalarExpression::Tuple(args)
548-
| ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
549-
| ScalarExpression::Coalesce { exprs: args, .. } => args.iter().any(Self::has_agg_call),
550-
ScalarExpression::If {
551-
condition,
552-
left_expr,
553-
right_expr,
554-
..
555-
} => condition.has_agg_call() || left_expr.has_agg_call() || right_expr.has_agg_call(),
556-
ScalarExpression::IfNull {
557-
left_expr,
558-
right_expr,
559-
..
392+
struct AggCallChecker {
393+
has_agg: bool,
394+
}
395+
impl<'a> Visitor<'a> for AggCallChecker {
396+
fn visit(&mut self, expr: &'a ScalarExpression) -> Result<(), DatabaseError> {
397+
if self.has_agg {
398+
return Ok(());
399+
}
400+
walk_expr(self, expr)
560401
}
561-
| ScalarExpression::NullIf {
562-
left_expr,
563-
right_expr,
564-
..
565-
} => left_expr.has_agg_call() || right_expr.has_agg_call(),
566-
ScalarExpression::CaseWhen {
567-
operand_expr,
568-
expr_pairs,
569-
else_expr,
570-
..
571-
} => {
572-
matches!(
573-
operand_expr.as_ref().map(|expr| expr.has_agg_call()),
574-
Some(true)
575-
) || expr_pairs
576-
.iter()
577-
.any(|(expr_1, expr_2)| expr_1.has_agg_call() || expr_2.has_agg_call())
578-
|| matches!(
579-
else_expr.as_ref().map(|expr| expr.has_agg_call()),
580-
Some(true)
581-
)
402+
fn visit_agg(&mut self,
403+
_distinct: bool,
404+
_kind: &'a AggKind,
405+
args: &'a [ScalarExpression],
406+
_ty: &'a LogicalType) -> Result<(), DatabaseError> {
407+
for arg in args {
408+
self.visit(arg)?;
409+
}
410+
self.has_agg = true;
411+
Ok(())
582412
}
413+
414+
583415
}
416+
let mut checker = AggCallChecker { has_agg: false };
417+
checker.visit(self).unwrap();
418+
checker.has_agg
584419
}
585420

586421
pub fn output_name(&self) -> String {

0 commit comments

Comments
 (0)