Skip to content

Commit c17c87c

Browse files
Dandandan and claude authored
Eliminate redundant ProjectionExecs (#21333)
## Which issue does this PR close?

## Rationale for this change

## What changes are included in this PR?

## How are these changes tested?

- All existing sqllogictests pass (419 files) including TPC-H
- Updated test expectations to reflect eliminated `ProjectionExec` nodes
- Net reduction of 72 lines across test files (fewer plan operators)

## Are there any user-facing changes?

No, only more efficient plans

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 422129c commit c17c87c

22 files changed

+470
-542
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3429,31 +3429,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
34293429
assert_snapshot!(
34303430
pretty_format_batches(&sql_results).unwrap(),
34313431
@r"
3432-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3433-
| plan_type | plan |
3434-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3435-
| logical_plan | Projection: t1.a, t1.b |
3436-
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
3437-
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
3438-
| | Left Join: t1.a = __scalar_sq_1.a |
3439-
| | TableScan: t1 projection=[a, b] |
3440-
| | SubqueryAlias: __scalar_sq_1 |
3441-
| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |
3442-
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |
3443-
| | TableScan: t2 projection=[a] |
3444-
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
3445-
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
3446-
| | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] |
3447-
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] |
3448-
| | CoalescePartitionsExec |
3449-
| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |
3450-
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |
3451-
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
3452-
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |
3453-
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3454-
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3455-
| | |
3456-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3432+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
3433+
| plan_type | plan |
3434+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
3435+
| logical_plan | Projection: t1.a, t1.b |
3436+
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
3437+
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
3438+
| | Left Join: t1.a = __scalar_sq_1.a |
3439+
| | TableScan: t1 projection=[a, b] |
3440+
| | SubqueryAlias: __scalar_sq_1 |
3441+
| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |
3442+
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |
3443+
| | TableScan: t2 projection=[a] |
3444+
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
3445+
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
3446+
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] |
3447+
| | CoalescePartitionsExec |
3448+
| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |
3449+
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |
3450+
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
3451+
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |
3452+
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3453+
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3454+
| | |
3455+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
34573456
"
34583457
);
34593458

@@ -3485,31 +3484,30 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
34853484
assert_snapshot!(
34863485
pretty_format_batches(&df_results).unwrap(),
34873486
@r"
3488-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3489-
| plan_type | plan |
3490-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3491-
| logical_plan | Projection: t1.a, t1.b |
3492-
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
3493-
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
3494-
| | Left Join: t1.a = __scalar_sq_1.a |
3495-
| | TableScan: t1 projection=[a, b] |
3496-
| | SubqueryAlias: __scalar_sq_1 |
3497-
| | Projection: count(*), t2.a, Boolean(true) AS __always_true |
3498-
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |
3499-
| | TableScan: t2 projection=[a] |
3500-
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
3501-
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
3502-
| | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] |
3503-
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] |
3504-
| | CoalescePartitionsExec |
3505-
| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |
3506-
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |
3507-
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
3508-
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |
3509-
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3510-
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3511-
| | |
3512-
+---------------+----------------------------------------------------------------------------------------------------------------------------+
3487+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
3488+
| plan_type | plan |
3489+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
3490+
| logical_plan | Projection: t1.a, t1.b |
3491+
| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |
3492+
| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |
3493+
| | Left Join: t1.a = __scalar_sq_1.a |
3494+
| | TableScan: t1 projection=[a, b] |
3495+
| | SubqueryAlias: __scalar_sq_1 |
3496+
| | Projection: count(*), t2.a, Boolean(true) AS __always_true |
3497+
| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |
3498+
| | TableScan: t2 projection=[a] |
3499+
| physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |
3500+
| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |
3501+
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] |
3502+
| | CoalescePartitionsExec |
3503+
| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |
3504+
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |
3505+
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
3506+
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |
3507+
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3508+
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3509+
| | |
3510+
+---------------+--------------------------------------------------------------------------------------------------------------------------+
35133511
"
35143512
);
35153513

datafusion/core/tests/physical_optimizer/projection_pushdown.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,8 +1306,8 @@ fn test_hash_join_after_projection() -> Result<()> {
13061306
assert_snapshot!(
13071307
actual,
13081308
@r"
1309-
ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]
1310-
HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]
1309+
ProjectionExec: expr=[c@0 as c_from_left, b@1 as b_from_left, a@2 as a_from_left, c@3 as c_from_right]
1310+
HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[c@2, b@1, a@0, c@7]
13111311
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
13121312
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false
13131313
"

datafusion/physical-plan/src/projection.rs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ use datafusion_execution::TaskContext;
5050
use datafusion_expr::ExpressionPlacement;
5151
use datafusion_physical_expr::equivalence::ProjectionMapping;
5252
use datafusion_physical_expr::projection::Projector;
53-
use datafusion_physical_expr::utils::collect_columns;
5453
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
5554
use datafusion_physical_expr_common::sort_expr::{
5655
LexOrdering, LexRequirement, PhysicalSortExpr,
@@ -602,13 +601,7 @@ pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
602601
return Ok(None);
603602
};
604603

605-
// If the projection indices is the same as the input columns, we don't need to embed the projection to hash join.
606-
// Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields.
607-
if projection_index.len() == projection_index.last().unwrap() + 1
608-
&& projection_index.len() == execution_plan.schema().fields().len()
609-
{
610-
return Ok(None);
611-
}
604+
let columns_reduced = projection_index.len() < execution_plan.schema().fields().len();
612605

613606
let new_execution_plan =
614607
Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);
@@ -643,9 +636,16 @@ pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
643636
Arc::clone(&new_execution_plan) as _,
644637
)?);
645638
if is_projection_removable(&new_projection) {
639+
// Residual is identity — embedding fully absorbed the projection.
646640
Ok(Some(new_execution_plan))
647-
} else {
641+
} else if columns_reduced {
642+
// Embedding reduced columns even though a residual is still needed
643+
// for renames or expressions — worth keeping.
648644
Ok(Some(new_projection))
645+
} else {
646+
// No columns eliminated and residual still needed — embedding just
647+
// adds an unnecessary column reorder inside the operator.
648+
Ok(None)
649649
}
650650
}
651651

@@ -1074,15 +1074,37 @@ fn try_unifying_projections(
10741074

10751075
/// Collect all column indices from the given projection expressions.
10761076
fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
1077-
// Collect indices and remove duplicates.
1078-
let mut indices = exprs
1079-
.iter()
1080-
.flat_map(|proj_expr| collect_columns(&proj_expr.expr))
1081-
.map(|x| x.index())
1082-
.collect::<std::collections::HashSet<_>>()
1083-
.into_iter()
1084-
.collect::<Vec<_>>();
1085-
indices.sort();
1077+
// Collect column indices in a deterministic order that preserves the
1078+
// projection's column ordering. For simple Column expressions, we use
1079+
// the column index directly. For complex expressions, we walk the
1080+
// expression tree to collect column references in traversal order.
1081+
// This allows the embedded projection to match the desired output
1082+
// column order, avoiding a residual ProjectionExec.
1083+
let mut seen = std::collections::HashSet::new();
1084+
let mut indices = Vec::new();
1085+
for proj_expr in exprs {
1086+
if let Some(col) = proj_expr.expr.as_any().downcast_ref::<Column>() {
1087+
// Simple column reference: preserve projection order.
1088+
if seen.insert(col.index()) {
1089+
indices.push(col.index());
1090+
}
1091+
} else {
1092+
// Complex expression: collect all referenced columns in
1093+
// expression tree traversal order (deterministic) to preserve
1094+
// the natural ordering of column references.
1095+
proj_expr
1096+
.expr
1097+
.apply(|expr| {
1098+
if let Some(col) = expr.as_any().downcast_ref::<Column>()
1099+
&& seen.insert(col.index())
1100+
{
1101+
indices.push(col.index());
1102+
}
1103+
Ok(TreeNodeRecursion::Continue)
1104+
})
1105+
.expect("closure always returns OK");
1106+
}
1107+
}
10861108
indices
10871109
}
10881110

@@ -1196,7 +1218,8 @@ mod tests {
11961218
expr,
11971219
alias: "b-(1+a)".to_string(),
11981220
}]);
1199-
assert_eq!(column_indices, vec![1, 7]);
1221+
// Tree traversal order: b@7 is visited before a@1
1222+
assert_eq!(column_indices, vec![7, 1]);
12001223
Ok(())
12011224
}
12021225

0 commit comments

Comments
 (0)