Skip to content

Commit 86cb815

Browse files
Dandandan and claude authored
[Minor] Remove redundant ProjectionExec nodes in sort-based plans (#20780)
## Which issue does this PR close?

- Closes #.

## Rationale for this change

ClickBench queries (Q7, Q15, Q16, Q18) have some redundant projections when sorting on a count. Removing them is probably not a measurable improvement, but the plan looks better (in the non-TopK case it could probably be measurable).

## What changes are included in this PR?

## Are these changes tested?

Existing tests.

## Are there any user-facing changes?

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 48199b9 commit 86cb815

File tree

4 files changed

+80
-84
lines changed

4 files changed

+80
-84
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,24 +3001,22 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
30013001
assert_snapshot!(
30023002
pretty_format_batches(&sql_results).unwrap(),
30033003
@r"
3004-
+---------------+------------------------------------------------------------------------------------------------------------+
3005-
| plan_type | plan |
3006-
+---------------+------------------------------------------------------------------------------------------------------------+
3007-
| logical_plan | Projection: t1.b, count(*) |
3008-
| | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST |
3009-
| | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) |
3010-
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] |
3011-
| | TableScan: t1 projection=[b] |
3012-
| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] |
3013-
| | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] |
3014-
| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |
3015-
| | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |
3016-
| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |
3017-
| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 |
3018-
| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] |
3019-
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3020-
| | |
3021-
+---------------+------------------------------------------------------------------------------------------------------------+
3004+
+---------------+------------------------------------------------------------------------------------+
3005+
| plan_type | plan |
3006+
+---------------+------------------------------------------------------------------------------------+
3007+
| logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST |
3008+
| | Projection: t1.b, count(Int64(1)) AS count(*) |
3009+
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] |
3010+
| | TableScan: t1 projection=[b] |
3011+
| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |
3012+
| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |
3013+
| | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*)] |
3014+
| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |
3015+
| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 |
3016+
| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] |
3017+
| | DataSourceExec: partitions=1, partition_sizes=[1] |
3018+
| | |
3019+
+---------------+------------------------------------------------------------------------------------+
30223020
"
30233021
);
30243022

@@ -3028,7 +3026,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
30283026
+---------------+----------------------------------------------------------------------------+
30293027
| plan_type | plan |
30303028
+---------------+----------------------------------------------------------------------------+
3031-
| logical_plan | Sort: count(*) ASC NULLS LAST |
3029+
| logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST |
30323030
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |
30333031
| | TableScan: t1 projection=[b] |
30343032
| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |

datafusion/expr/src/expr_rewriter/order_by.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ fn rewrite_in_terms_of_projection(
7777
// assumption is that each item in exprs, such as "b + c" is
7878
// available as an output column named "b + c"
7979
expr.transform(|expr| {
80-
// search for unnormalized names first such as "c1" (such as aliases)
81-
if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
80+
// search for unnormalized names first such as "c1" (such as aliases).
81+
// Also look inside aliases so e.g. `count(Int64(1))` matches
82+
// `count(Int64(1)) AS count(*)`.
83+
if let Some(found) = proj_exprs.iter().find(|a| expr_match(&expr, a)) {
8284
let (qualifier, field_name) = found.qualified_name();
8385
let col = Expr::Column(Column::new(qualifier, field_name));
8486
return Ok(Transformed::yes(col));
@@ -235,18 +237,22 @@ mod test {
235237
TestCase {
236238
desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
237239
input: sort(min(col("c2"))),
238-
expected: sort(col("min(t.c2)")),
240+
expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
239241
},
240242
TestCase {
241243
desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
242244
input: sort(col("c1") + min(col("c2"))),
243245
// should be "c1" not t.c1
244-
expected: sort(col("c1") + col("min(t.c2)")),
246+
expected: sort(
247+
col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")),
248+
),
245249
},
246250
TestCase {
247251
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
248252
input: sort(avg(col("c3"))),
249-
expected: sort(col("avg(t.c3)").alias("average")),
253+
expected: sort(
254+
Expr::Column(Column::new_unqualified("avg(t.c3)")).alias("average"),
255+
),
250256
},
251257
];
252258

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1984,7 +1984,7 @@ fn test_complex_order_by_with_grouping() -> Result<()> {
19841984
}, {
19851985
assert_snapshot!(
19861986
sql,
1987-
@r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string)) ORDER BY lochierarchy DESC NULLS FIRST, CASE WHEN (("grouping(j1.j1_id)" + "grouping(j1.j1_string)") = 0) THEN j1.j1_id END ASC NULLS LAST LIMIT 100"#
1987+
@"SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY lochierarchy DESC NULLS FIRST, CASE WHEN (lochierarchy = 0) THEN j1.j1_id END ASC NULLS LAST LIMIT 100"
19881988
);
19891989
});
19901990

0 commit comments

Comments
 (0)