diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs
index 59502da987904..8c09b90191717 100644
--- a/datafusion/core/benches/sql_planner.rs
+++ b/datafusion/core/benches/sql_planner.rs
@@ -41,11 +41,37 @@ const BENCHMARKS_PATH_1: &str = "../../benchmarks/";
 const BENCHMARKS_PATH_2: &str = "./benchmarks/";
 const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/";
 
-/// Create a logical plan from the specified sql
+/// Create a logical plan from the specified sql (parse + plan only; the optimizer is NOT run)
 fn logical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) {
     black_box(rt.block_on(ctx.sql(sql)).unwrap());
 }
 
+/// Parse SQL, then explicitly run the analyzer to get an analyzed (but unoptimized)
+/// LogicalPlan. Meant to be called once in benchmark setup; panics on any error.
+fn analyzed_plan(
+    ctx: &SessionContext,
+    rt: &Runtime,
+    sql: &str,
+) -> datafusion_expr::LogicalPlan {
+    let state = ctx.state();
+    let plan = rt.block_on(state.create_logical_plan(sql)).unwrap();
+    state
+        .analyzer()
+        .execute_and_check(plan, &state.config().options(), |_, _| {})
+        .unwrap()
+}
+
+/// Run ONLY the optimizer on a pre-analyzed plan (timing also covers `ctx.state()` and `plan.clone()`).
+fn optimize_plan(ctx: &SessionContext, plan: &datafusion_expr::LogicalPlan) {
+    let state = ctx.state();
+    black_box(
+        state
+            .optimizer()
+            .optimize(plan.clone(), &state, |_, _| {})
+            .unwrap(),
+    );
+}
+
 /// Create a physical ExecutionPlan (by way of logical plan)
 fn physical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) {
     black_box(rt.block_on(async {
@@ -634,6 +660,455 @@ fn criterion_benchmark(c: &mut Criterion) {
     c.bench_function("with_param_values_many_columns", |b| {
         benchmark_with_param_values_many_columns(&ctx, &rt, b);
     });
+
+    // ==========================================================================
+    // Logical-planning benchmarks over optimizer-stressing query shapes
+    // These call `logical_plan` (no optimizer run) on queries with varying plan
+    // sizes, expression counts, and node type distributions.
+ // ========================================================================== + + // --- Deep join trees (many plan nodes, few expressions) --- + // Tests optimizer traversal cost as plan node count grows. + // Each join adds ~3 nodes (Join, TableScan, CrossJoin/Filter). + + // Register additional tables for join benchmarks + for i in 3..=16 { + ctx.register_table( + &format!("j{i}"), + create_table_provider("x", 10), + ) + .unwrap(); + } + + c.bench_function("logical_join_chain_4", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0", + ) + }) + }); + + c.bench_function("logical_join_chain_8", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0", + ) + }) + }); + + c.bench_function("logical_join_chain_16", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0 \ + JOIN j11 ON j10.x0 = j11.x0 \ + JOIN j12 ON j11.x0 = j12.x0 \ + JOIN j13 ON j12.x0 = j13.x0 \ + JOIN j14 ON j13.x0 = j14.x0 \ + JOIN j15 ON j14.x0 = j15.x0 \ + JOIN j16 ON j15.x0 = j16.x0 \ + JOIN j3 AS j3b ON j16.x0 = j3b.x0 \ + JOIN j4 AS j4b ON j3b.x0 = j4b.x0", + ) + }) + }); + + // --- Wide expressions (few plan nodes, many expressions) --- + // Tests expression processing overhead in optimizer rules like + // SimplifyExpressions, CommonSubexprEliminate, OptimizeProjections. 
+ + // Many WHERE clauses (filter expressions) + { + let predicates: Vec = (0..50) + .map(|i| format!("a{i} > 0")) + .collect(); + let query = format!( + "SELECT a0 FROM t1 WHERE {}", + predicates.join(" AND ") + ); + c.bench_function("logical_wide_filter_50_predicates", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + { + let predicates: Vec = (0..200) + .map(|i| format!("a{i} > 0")) + .collect(); + let query = format!( + "SELECT a0 FROM t1 WHERE {}", + predicates.join(" AND ") + ); + c.bench_function("logical_wide_filter_200_predicates", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // Many aggregate expressions + { + let aggs: Vec = (0..50) + .map(|i| format!("SUM(a{i}), AVG(a{i})")) + .collect(); + let query = format!("SELECT {} FROM t1", aggs.join(", ")); + c.bench_function("logical_wide_aggregate_100_exprs", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // Many CASE WHEN expressions (complex expressions) + { + let cases: Vec = (0..50) + .map(|i| { + format!( + "CASE WHEN a{i} > 0 THEN a{i} * 2 ELSE a{i} + 1 END AS r{i}" + ) + }) + .collect(); + let query = format!("SELECT {} FROM t1", cases.join(", ")); + c.bench_function("logical_wide_case_50_exprs", |b| { + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + } + + // --- Mixed: deep plan + wide expressions --- + // This is the worst case for optimizer: many nodes AND many expressions. 
+ + c.bench_function("logical_join_4_with_agg_and_filter", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0, SUM(j4.x1), AVG(j5.x2), COUNT(j6.x3), \ + MIN(j3.x4), MAX(j4.x5) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + WHERE j3.x1 > 0 AND j4.x2 < 100 AND j5.x3 != j6.x4 \ + GROUP BY j3.x0 \ + HAVING SUM(j4.x1) > 10 \ + ORDER BY j3.x0", + ) + }) + }); + + c.bench_function("logical_join_8_with_agg_sort_limit", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT j3.x0, j4.x1, j5.x2, \ + SUM(j6.x3), AVG(j7.x4), COUNT(j8.x5), \ + MIN(j9.x6), MAX(j10.x7) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0 \ + WHERE j3.x1 > 0 AND j5.x2 < 100 \ + GROUP BY j3.x0, j4.x1, j5.x2 \ + ORDER BY j3.x0 DESC \ + LIMIT 100", + ) + }) + }); + + // --- Subqueries (trigger decorrelation rules) --- + // Tests rules like DecorrelatePredicateSubquery, ScalarSubqueryToJoin. 
+ + c.bench_function("logical_correlated_subquery_exists", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0)", + ) + }) + }); + + c.bench_function("logical_correlated_subquery_in", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE a0 IN (SELECT b0 FROM t2 WHERE t2.b1 = t1.a1)", + ) + }) + }); + + c.bench_function("logical_scalar_subquery", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, (SELECT MAX(b1) FROM t2 WHERE t2.b0 = t1.a0) AS max_b \ + FROM t1", + ) + }) + }); + + c.bench_function("logical_multiple_subqueries", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE a0 IN (SELECT b0 FROM t2 WHERE b1 > 0) \ + AND EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0 AND t2.b1 < 100) \ + AND a1 > (SELECT AVG(b1) FROM t2)", + ) + }) + }); + + // --- UNION queries (test OptimizeUnions, PropagateEmptyRelation) --- + + c.bench_function("logical_union_4_branches", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 WHERE a0 > 0 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 10 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 20 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 30", + ) + }) + }); + + c.bench_function("logical_union_8_branches", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 WHERE a0 > 0 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 10 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 20 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 30 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 40 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 50 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 60 \ + UNION ALL SELECT a0, a1 FROM t1 WHERE a0 > 70", + ) + }) + }); + + // --- DISTINCT (test ReplaceDistinctWithAggregate) --- + + c.bench_function("logical_distinct_many_columns", |b| { + let cols: Vec = (0..50).map(|i| format!("a{i}")).collect(); + let query = 
format!("SELECT DISTINCT {} FROM t1", cols.join(", ")); + b.iter(|| logical_plan(&ctx, &rt, &query)) + }); + + // --- Nested views / CTEs (deeper plan trees) --- + + c.bench_function("logical_nested_cte_4_levels", |b| { + b.iter(|| { + logical_plan( + &ctx, + &rt, + "WITH \ + cte1 AS (SELECT a0, a1, a2 FROM t1 WHERE a0 > 0), \ + cte2 AS (SELECT a0, a1 FROM cte1 WHERE a1 > 0), \ + cte3 AS (SELECT a0 FROM cte2 WHERE a0 < 100), \ + cte4 AS (SELECT a0, COUNT(*) AS cnt FROM cte3 GROUP BY a0) \ + SELECT * FROM cte4 ORDER BY a0 LIMIT 10", + ) + }) + }); + + // --- TPC-H logical plans (uncommented from existing code) --- + // These test real-world query patterns with moderate plan complexity. + + c.bench_function("logical_plan_tpch_all", |b| { + b.iter(|| { + for sql in &all_tpch_sql_queries { + logical_plan(&tpch_ctx, &rt, sql) + } + }) + }); + + c.bench_function("logical_plan_tpcds_all", |b| { + b.iter(|| { + for sql in &all_tpcds_sql_queries { + logical_plan(&tpcds_ctx, &rt, sql) + } + }) + }); + + // ========================================================================== + // Optimizer-only benchmarks + // These measure ONLY the optimizer, not SQL parsing or analysis. + // Plans are pre-parsed and pre-analyzed in setup, then only optimization + // is measured in the benchmark loop. 
+ // ========================================================================== + + // Simple select (baseline: few nodes, few expressions) + { + let plan = analyzed_plan(&ctx, &rt, "SELECT c1 FROM t700"); + c.bench_function("optimizer_select_one_from_700", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide select (many expressions, few nodes) + { + let plan = analyzed_plan(&ctx, &rt, "SELECT * FROM t1000"); + c.bench_function("optimizer_select_all_from_1000", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Deep join chains (many nodes, few expressions) + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0", + ); + c.bench_function("optimizer_join_chain_4", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0 FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + JOIN j7 ON j6.x0 = j7.x0 \ + JOIN j8 ON j7.x0 = j8.x0 \ + JOIN j9 ON j8.x0 = j9.x0 \ + JOIN j10 ON j9.x0 = j10.x0", + ); + c.bench_function("optimizer_join_chain_8", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide filter (many expressions) + { + let predicates: Vec = (0..200) + .map(|i| format!("a{i} > 0")) + .collect(); + let query = format!( + "SELECT a0 FROM t1 WHERE {}", + predicates.join(" AND ") + ); + let plan = analyzed_plan(&ctx, &rt, &query); + c.bench_function("optimizer_wide_filter_200", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Wide aggregate (many expressions) + { + let aggs: Vec = (0..50) + .map(|i| format!("SUM(a{i}), AVG(a{i})")) + .collect(); + let query = format!("SELECT {} FROM t1", aggs.join(", ")); + let plan = analyzed_plan(&ctx, &rt, &query); + c.bench_function("optimizer_wide_aggregate_100", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Subquery (tests decorrelation rules) + { + 
let plan = analyzed_plan( + &ctx, + &rt, + "SELECT a0, a1 FROM t1 \ + WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b0 = t1.a0)", + ); + c.bench_function("optimizer_correlated_exists", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // Mixed: joins + aggregates + filter + { + let plan = analyzed_plan( + &ctx, + &rt, + "SELECT j3.x0, SUM(j4.x1), AVG(j5.x2), COUNT(j6.x3), \ + MIN(j3.x4), MAX(j4.x5) \ + FROM j3 \ + JOIN j4 ON j3.x0 = j4.x0 \ + JOIN j5 ON j4.x0 = j5.x0 \ + JOIN j6 ON j5.x0 = j6.x0 \ + WHERE j3.x1 > 0 AND j4.x2 < 100 AND j5.x3 != j6.x4 \ + GROUP BY j3.x0 \ + HAVING SUM(j4.x1) > 10 \ + ORDER BY j3.x0", + ); + c.bench_function("optimizer_join_4_with_agg_filter", |b| { + b.iter(|| optimize_plan(&ctx, &plan)) + }); + } + + // TPC-H all queries (optimizer only) + { + let plans: Vec<_> = all_tpch_sql_queries + .iter() + .map(|sql| analyzed_plan(&tpch_ctx, &rt, sql)) + .collect(); + c.bench_function("optimizer_tpch_all", |b| { + b.iter(|| { + for plan in &plans { + optimize_plan(&tpch_ctx, plan) + } + }) + }); + } + + // TPC-DS all queries (optimizer only) + { + let plans: Vec<_> = all_tpcds_sql_queries + .iter() + .map(|sql| analyzed_plan(&tpcds_ctx, &rt, sql)) + .collect(); + c.bench_function("optimizer_tpcds_all", |b| { + b.iter(|| { + for plan in &plans { + optimize_plan(&tpcds_ctx, plan) + } + }) + }); + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index a1285510da569..1cf162b0d6477 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -52,6 +52,7 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{Result, internal_err}; +use std::sync::Arc; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -393,6 +394,119 @@ macro_rules! 
handle_transform_recursion {
 }
 
 impl LogicalPlan {
+    /// Applies `f` to each child (input) of this plan node in place,
+    /// using [`Arc::make_mut`] for copy-on-write semantics.
+    ///
+    /// When the `Arc` refcount is 1 (the common case in the optimizer),
+    /// `Arc::make_mut` returns a `&mut` reference without cloning.
+    /// When the refcount is >1, it clones the inner value first, whether
+    /// or not `f` ends up changing anything.
+    ///
+    /// Children are visited in order and every child is visited; the only
+    /// early exit is an error.
+    ///
+    /// [`LogicalPlan::Extension`] nodes do not expose their inputs mutably:
+    /// `f` runs on a clone of each input and the node is rebuilt with
+    /// `with_exprs_and_inputs` only if `f` reported a change.
+    ///
+    /// Returns `Ok(true)` if any child was modified by `f`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error from `f` (or from rebuilding an extension
+    /// node). For non-extension nodes, children visited before the failure
+    /// keep whatever modifications `f` already applied to them.
+    pub fn map_children_mut<F: FnMut(&mut Self) -> Result<bool>>(
+        &mut self,
+        mut f: F,
+    ) -> Result<bool> {
+        Ok(match self {
+            LogicalPlan::Projection(Projection { input, .. })
+            | LogicalPlan::Filter(Filter { input, .. })
+            | LogicalPlan::Repartition(Repartition { input, .. })
+            | LogicalPlan::Window(Window { input, .. })
+            | LogicalPlan::Aggregate(Aggregate { input, .. })
+            | LogicalPlan::Sort(Sort { input, .. })
+            | LogicalPlan::Limit(Limit { input, .. })
+            | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. })
+            | LogicalPlan::Analyze(Analyze { input, .. })
+            | LogicalPlan::Dml(DmlStatement { input, .. })
+            | LogicalPlan::Copy(CopyTo { input, .. })
+            | LogicalPlan::Unnest(Unnest { input, .. }) => {
+                f(Arc::make_mut(input))?
+            }
+            LogicalPlan::Subquery(Subquery { subquery, .. }) => {
+                f(Arc::make_mut(subquery))?
+            }
+            LogicalPlan::Join(Join { left, right, .. }) => {
+                // Both sides are always visited; `||` only combines the flags.
+                let l = f(Arc::make_mut(left))?;
+                let r = f(Arc::make_mut(right))?;
+                l || r
+            }
+            LogicalPlan::Union(Union { inputs, .. }) => {
+                let mut changed = false;
+                for input in inputs {
+                    changed |= f(Arc::make_mut(input))?;
+                }
+                changed
+            }
+            LogicalPlan::Distinct(Distinct::All(input)) => {
+                f(Arc::make_mut(input))?
+            }
+            LogicalPlan::Distinct(Distinct::On(DistinctOn { input, .. })) => {
+                f(Arc::make_mut(input))?
+            }
+            LogicalPlan::Explain(Explain { plan, .. }) => {
+                f(Arc::make_mut(plan))?
+            }
+            LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable {
+                input, ..
+            }))
+            | LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
+                input, ..
+            })) => f(Arc::make_mut(input))?,
+            LogicalPlan::RecursiveQuery(RecursiveQuery {
+                static_term,
+                recursive_term,
+                ..
+            }) => {
+                let s = f(Arc::make_mut(static_term))?;
+                let r = f(Arc::make_mut(recursive_term))?;
+                s || r
+            }
+            LogicalPlan::Statement(Statement::Prepare(p)) => {
+                f(Arc::make_mut(&mut p.input))?
+            }
+            LogicalPlan::Extension(Extension { node }) => {
+                // Extension nodes don't expose mutable children, so run `f`
+                // on a clone of each input and fall back to the
+                // ownership-based API to install the results.
+                let mut changed = false;
+                let new_inputs = node
+                    .inputs()
+                    .into_iter()
+                    .map(|input| {
+                        let mut plan = input.clone();
+                        changed |= f(&mut plan)?;
+                        Ok(plan)
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                if changed {
+                    // The expressions are only needed for the rebuild, so
+                    // they are not collected on the unchanged path.
+                    let exprs = node.expressions();
+                    *node = node.with_exprs_and_inputs(exprs, new_inputs)?;
+                }
+                changed
+            }
+            // plans without inputs
+            LogicalPlan::TableScan { .. }
+            | LogicalPlan::EmptyRelation { .. }
+            | LogicalPlan::Values { .. }
+            | LogicalPlan::DescribeTable(_)
+            | LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_))
+            | LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_))
+            | LogicalPlan::Ddl(DdlStatement::CreateCatalog(_))
+            | LogicalPlan::Ddl(DdlStatement::CreateIndex(_))
+            | LogicalPlan::Ddl(DdlStatement::DropTable(_))
+            | LogicalPlan::Ddl(DdlStatement::DropView(_))
+            | LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_))
+            | LogicalPlan::Ddl(DdlStatement::CreateFunction(_))
+            | LogicalPlan::Ddl(DdlStatement::DropFunction(_))
+            | LogicalPlan::Statement(_) => false,
+        })
+    }
+
     /// Calls `f` on all expressions in the current `LogicalPlan` node.
     ///
     /// # Notes
@@ -831,6 +945,32 @@ impl LogicalPlan {
         })
     }
 
+    /// Returns true if any expression in this node contains a subquery
+    /// (Exists, InSubquery, SetComparison, or ScalarSubquery).
+    fn has_subquery_expressions(&self) -> bool {
+        let mut found = false;
+        // Neither closure below ever returns `Err`, so the traversal result
+        // carries nothing beyond `found` and is deliberately discarded.
+        let _ = self.apply_expressions(|expr| {
+            if found {
+                return Ok(TreeNodeRecursion::Stop);
+            }
+            expr.apply(|e| {
+                Ok(match e {
+                    Expr::Exists(_)
+                    | Expr::InSubquery(_)
+                    | Expr::SetComparison(_)
+                    | Expr::ScalarSubquery(_) => {
+                        // First hit is enough: record it and end the walk.
+                        found = true;
+                        TreeNodeRecursion::Stop
+                    }
+                    _ => TreeNodeRecursion::Continue,
+                })
+            })
+        });
+        found
+    }
+
     /// Similarly to [`Self::map_children`], rewrites all subqueries that may
     /// appear in expressions such as `IN (SELECT ...)` using `f`.
     ///
@@ -839,6 +979,14 @@ impl LogicalPlan {
         self,
         mut f: F,
     ) -> Result<Transformed<Self>> {
+        // Fast path: skip the expensive ownership-based expression traversal
+        // when this node has no subquery expressions. This avoids
+        // map_expressions → transform_down walking every expression node
+        // via consume+recreate just to find no subqueries.
+        if !self.has_subquery_expressions() {
+            return Ok(Transformed::no(self));
+        }
+
         self.map_expressions(|expr| {
             expr.transform_down(|expr| match expr {
                 Expr::Exists(Exists { subquery, negated }) => {
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index bdea6a83072cd..e9267bdc9a2bd 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -28,8 +28,11 @@ use log::{debug, warn};
 use datafusion_common::alias::AliasGenerator;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::instant::Instant;
-use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
+use datafusion_common::tree_node::{
+    Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
+};
 use datafusion_common::{DFSchema, DataFusionError, HashSet, Result, internal_err};
+use datafusion_expr::Expr;
 use datafusion_expr::logical_plan::LogicalPlan;
 
 use crate::common_subexpr_eliminate::CommonSubexprEliminate;
@@ -357,6 +360,104 @@ impl TreeNodeRewriter for Rewriter<'_> {
     }
 }
 
+/// A cheap placeholder [`LogicalPlan`] for
use with [`std::mem::replace`].
+///
+/// Used to temporarily extract an owned `LogicalPlan` from a `&mut` reference
+/// so it can be passed to the ownership-based rule API (`rule.rewrite(plan)`).
+fn placeholder_plan() -> LogicalPlan {
+    LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation {
+        produce_one_row: false,
+        schema: Arc::new(DFSchema::empty()),
+    })
+}
+
+/// Rewrites a plan tree in place using `Arc::make_mut` for
+/// copy-on-write semantics on `Arc` children.
+///
+/// This avoids the `Arc::unwrap_or_clone` + `Arc::new` cycle that the
+/// ownership-based `TreeNode::rewrite` performs at every child node.
+/// When the `Arc` refcount is 1 (the common case in the optimizer),
+/// `Arc::make_mut` returns the inner value without cloning.
+///
+/// The `rule.rewrite()` API still takes ownership, so we bridge via
+/// `mem::replace` with a cheap placeholder.
+///
+/// The [`TreeNodeRecursion`] returned by the rule is honored the same way
+/// the ownership-based traversal honors it:
+///
+/// * `Continue` keeps traversing.
+/// * `Jump` from a top-down rule skips the children of that node and
+///   carries on with its siblings; from a bottom-up rule it skips the
+///   ancestors of that node until a sibling subtree is entered.
+/// * `Stop` ends the whole traversal, not just the current subtree.
+///
+/// Returns `Ok(true)` if the rule reported a transformation anywhere.
+///
+/// # Errors
+///
+/// Returns the first error from `rule.rewrite`. In that case `plan` is left
+/// partially rewritten, with the placeholder in the position of the node
+/// whose rewrite failed; callers that need the original must keep a copy.
+fn rewrite_plan_in_place(
+    plan: &mut LogicalPlan,
+    apply_order: ApplyOrder,
+    rule: &dyn OptimizerRule,
+    config: &dyn OptimizerConfig,
+) -> Result<bool> {
+    let mut tnr = TreeNodeRecursion::Continue;
+    rewrite_node_in_place(plan, apply_order, rule, config, &mut tnr)
+}
+
+/// Runs `rule` on the node behind `plan`, writing the result back in place.
+///
+/// Returns the rule's `transformed` flag and its recursion directive.
+fn apply_rule_in_place(
+    plan: &mut LogicalPlan,
+    rule: &dyn OptimizerRule,
+    config: &dyn OptimizerConfig,
+) -> Result<(bool, TreeNodeRecursion)> {
+    let owned = std::mem::replace(plan, placeholder_plan());
+    let result = rule.rewrite(owned, config)?;
+    *plan = result.data;
+    Ok((result.transformed, result.tnr))
+}
+
+/// Recursive worker for [`rewrite_plan_in_place`].
+///
+/// `tnr` is the traversal state shared by the whole walk: on return it holds
+/// the directive left behind by the last node that was processed, so callers
+/// can tell whether to keep going, skip a parent, or stop altogether.
+#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
+fn rewrite_node_in_place(
+    plan: &mut LogicalPlan,
+    apply_order: ApplyOrder,
+    rule: &dyn OptimizerRule,
+    config: &dyn OptimizerConfig,
+    tnr: &mut TreeNodeRecursion,
+) -> Result<bool> {
+    let mut changed = false;
+    // Entering a node starts fresh; a `Jump` left by a previous sibling
+    // only affects that sibling's ancestors, not this subtree.
+    *tnr = TreeNodeRecursion::Continue;
+
+    // f_down phase
+    if apply_order == ApplyOrder::TopDown {
+        let (transformed, next) = apply_rule_in_place(plan, rule, config)?;
+        changed |= transformed;
+        *tnr = next;
+    }
+
+    if *tnr == TreeNodeRecursion::Continue {
+        // Recurse into children using Arc::make_mut (no clone when refcount == 1).
+        // Once any child requests `Stop`, the remaining siblings are left untouched.
+        changed |= plan.map_children_mut(|child| {
+            if *tnr == TreeNodeRecursion::Stop {
+                Ok(false)
+            } else {
+                rewrite_node_in_place(child, apply_order, rule, config, &mut *tnr)
+            }
+        })?;
+    } else if *tnr == TreeNodeRecursion::Jump {
+        // Top-down `Jump`: children are pruned, traversal resumes afterwards.
+        *tnr = TreeNodeRecursion::Continue;
+    }
+
+    // f_up phase: skipped when a child asked to `Jump` over its ancestors
+    // or to `Stop` the traversal.
+    if apply_order == ApplyOrder::BottomUp && *tnr == TreeNodeRecursion::Continue {
+        let (transformed, next) = apply_rule_in_place(plan, rule, config)?;
+        changed |= transformed;
+        *tnr = next;
+    }
+
+    Ok(changed)
+}
+
+/// Returns true if the plan contains any subquery expressions
+/// (EXISTS, IN subquery, scalar
subquery, set comparison).
+///
+/// Used to determine whether the more expensive `rewrite_with_subqueries`
+/// traversal is needed. When the plan has no subqueries, the optimizer
+/// uses the in-place `rewrite_plan_in_place` traversal instead, which only
+/// follows direct children.
+fn plan_has_subqueries(plan: &LogicalPlan) -> bool {
+    let mut found = false;
+    // None of the closures below return `Err`, so the traversal result
+    // carries nothing beyond `found` and is deliberately discarded.
+    let _ = plan.apply(|node| {
+        if found {
+            return Ok(TreeNodeRecursion::Stop);
+        }
+        node.apply_expressions(|expr| {
+            if found {
+                return Ok(TreeNodeRecursion::Stop);
+            }
+            expr.apply(|e| {
+                Ok(match e {
+                    Expr::Exists(_)
+                    | Expr::InSubquery(_)
+                    | Expr::SetComparison(_)
+                    | Expr::ScalarSubquery(_) => {
+                        // First hit is enough: record it and end the walk.
+                        found = true;
+                        TreeNodeRecursion::Stop
+                    }
+                    _ => TreeNodeRecursion::Continue,
+                })
+            })
+        })?;
+        // Stop visiting further plan nodes as soon as one has a subquery.
+        Ok(if found {
+            TreeNodeRecursion::Stop
+        } else {
+            TreeNodeRecursion::Continue
+        })
+    });
+    found
+}
+
 impl Optimizer {
     /// Optimizes the logical plan by applying optimizer rules, and
     /// invoking observer function after each call
@@ -386,6 +487,14 @@ impl Optimizer {
         while i < options.optimizer.max_passes {
             log_plan(&format!("Optimizer input (pass {i})"), &new_plan);
 
+            // Check once per pass whether the plan contains subquery
+            // expressions. When there are no subqueries, we use the
+            // in-place `rewrite_plan_in_place` traversal instead of
+            // `rewrite_with_subqueries`, avoiding the per-node
+            // map_subqueries call that walks all expression trees
+            // via ownership-based transform_down.
+ let has_subqueries = plan_has_subqueries(&new_plan); + for rule in &self.rules { // If skipping failed rules, copy plan before attempting to rewrite // as rewriting is destructive @@ -398,9 +507,39 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite_with_subqueries( - &mut Rewriter::new(apply_order, rule.as_ref(), config), - ), + Some(apply_order) => { + if has_subqueries { + // Plans with subqueries need the full + // rewrite_with_subqueries traversal to + // recurse into subquery plans. + new_plan.rewrite_with_subqueries( + &mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + ), + ) + } else { + // No subqueries: use in-place rewriting + // with Arc::make_mut for zero-cost CoW on + // children, avoiding Arc unwrap/rewrap. + rewrite_plan_in_place( + &mut new_plan, + apply_order, + rule.as_ref(), + config, + ) + .map(|transformed| { + Transformed::new_transformed( + std::mem::replace( + &mut new_plan, + placeholder_plan(), + ), + transformed, + ) + }) + } + } // rule handles recursion itself None => { rule.rewrite(new_plan, config)