Skip to content

Commit 6669e8b

Browse files
committed
fix: apply the left side schema on the right side in set expressions
1 parent d138c36 commit 6669e8b

File tree

7 files changed

+209
-10
lines changed

7 files changed

+209
-10
lines changed

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,23 @@ impl LogicalPlanBuilder {
593593
self,
594594
expr: Vec<(impl Into<SelectExpr>, bool)>,
595595
) -> Result<Self> {
596-
project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new)
596+
project_with_validation(Arc::unwrap_or_clone(self.plan), expr, None)
597+
.map(Self::new)
598+
}
599+
600+
/// Apply a projection, aliasing non-Column/non-Alias expressions to
601+
/// match the field names from the provided schema.
602+
pub fn project_with_validation_and_schema(
603+
self,
604+
expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
605+
schema: &DFSchemaRef,
606+
) -> Result<Self> {
607+
project_with_validation(
608+
Arc::unwrap_or_clone(self.plan),
609+
expr.into_iter().map(|e| (e, true)),
610+
Some(schema),
611+
)
612+
.map(Self::new)
597613
}
598614

599615
/// Select the given column indices
@@ -1916,7 +1932,7 @@ pub fn project(
19161932
plan: LogicalPlan,
19171933
expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
19181934
) -> Result<LogicalPlan> {
1919-
project_with_validation(plan, expr.into_iter().map(|e| (e, true)))
1935+
project_with_validation(plan, expr.into_iter().map(|e| (e, true)), None)
19201936
}
19211937

19221938
/// Create Projection. Similar to project except that the expressions
@@ -1929,6 +1945,7 @@ pub fn project(
19291945
fn project_with_validation(
19301946
plan: LogicalPlan,
19311947
expr: impl IntoIterator<Item = (impl Into<SelectExpr>, bool)>,
1948+
schema: Option<&DFSchemaRef>,
19321949
) -> Result<LogicalPlan> {
19331950
let mut projected_expr = vec![];
19341951
for (e, validate) in expr {
@@ -1984,6 +2001,17 @@ fn project_with_validation(
19842001
}
19852002
}
19862003
}
2004+
2005+
// When inside a set expression, alias non-Column/non-Alias expressions
2006+
// to match the left side's field names, avoiding duplicate name errors.
2007+
if let Some(schema) = &schema {
2008+
for (expr, field) in projected_expr.iter_mut().zip(schema.fields()) {
2009+
if !matches!(expr, Expr::Column(_) | Expr::Alias(_)) {
2010+
*expr = expr.clone().alias(field.name());
2011+
}
2012+
}
2013+
}
2014+
19872015
validate_unique_names("Projections", projected_expr.iter())?;
19882016

19892017
Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection)

datafusion/sql/src/planner.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ pub struct PlannerContext {
270270
outer_from_schema: Option<DFSchemaRef>,
271271
/// The query schema defined by the table
272272
create_table_schema: Option<DFSchemaRef>,
273+
/// When planning non-first queries in a set expression
274+
/// (UNION/INTERSECT/EXCEPT), holds the schema of the left-most query.
275+
/// Used to alias duplicate expressions to match the left side's field names.
276+
set_expr_left_schema: Option<DFSchemaRef>,
273277
}
274278

275279
impl Default for PlannerContext {
@@ -287,6 +291,7 @@ impl PlannerContext {
287291
outer_queries_schemas_stack: vec![],
288292
outer_from_schema: None,
289293
create_table_schema: None,
294+
set_expr_left_schema: None,
290295
}
291296
}
292297

@@ -400,6 +405,14 @@ impl PlannerContext {
400405
pub(super) fn remove_cte(&mut self, cte_name: &str) {
401406
self.ctes.remove(cte_name);
402407
}
408+
409+
/// Sets the left-most set expression schema, returning the previous value
410+
pub(super) fn set_set_expr_left_schema(
411+
&mut self,
412+
schema: Option<DFSchemaRef>,
413+
) -> Option<DFSchemaRef> {
414+
std::mem::replace(&mut self.set_expr_left_schema, schema)
415+
}
403416
}
404417

405418
/// SQL query planner and binder

datafusion/sql/src/query.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
152152
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
153153
let select_exprs =
154154
self.prepare_select_exprs(&plan, exprs, empty_from, planner_context)?;
155-
self.project(plan, select_exprs)
155+
self.project(plan, select_exprs, None)
156156
}
157157
PipeOperator::Extend { exprs } => {
158158
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
@@ -162,7 +162,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
162162
std::iter::once(SelectExpr::Wildcard(WildcardOptions::default()))
163163
.chain(extend_exprs)
164164
.collect();
165-
self.project(plan, all_exprs)
165+
self.project(plan, all_exprs, None)
166166
}
167167
PipeOperator::As { alias } => self.apply_table_alias(
168168
plan,

datafusion/sql/src/select.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::utils::{
2929

3030
use datafusion_common::error::DataFusionErrorBuilder;
3131
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
32-
use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err};
32+
use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, not_impl_err, plan_err};
3333
use datafusion_common::{RecursionUnnestOption, UnnestOptions};
3434
use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions};
3535
use datafusion_expr::expr_rewriter::{
@@ -90,6 +90,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
9090
return not_impl_err!("SORT BY");
9191
}
9292

93+
// Capture and clear set expression schema so it doesn't leak
94+
// into subqueries planned during FROM clause handling.
95+
let set_expr_left_schema = planner_context.set_set_expr_left_schema(None);
96+
9397
// Process `from` clause
9498
let plan = self.plan_from_tables(select.from, planner_context)?;
9599
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
@@ -110,7 +114,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
110114
)?;
111115

112116
// Having and group by clause may reference aliases defined in select projection
113-
let projected_plan = self.project(base_plan.clone(), select_exprs)?;
117+
let projected_plan =
118+
self.project(base_plan.clone(), select_exprs, set_expr_left_schema)?;
114119
let select_exprs = projected_plan.expressions();
115120

116121
let order_by =
@@ -879,18 +884,29 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
879884
&self,
880885
input: LogicalPlan,
881886
expr: Vec<SelectExpr>,
887+
set_expr_left_schema: Option<DFSchemaRef>,
882888
) -> Result<LogicalPlan> {
883889
// convert to Expr for validate_schema_satisfies_exprs
884-
let exprs = expr
890+
let plain_exprs = expr
885891
.iter()
886892
.filter_map(|e| match e {
887893
SelectExpr::Expression(expr) => Some(expr.to_owned()),
888894
_ => None,
889895
})
890896
.collect::<Vec<_>>();
891-
self.validate_schema_satisfies_exprs(input.schema(), &exprs)?;
892-
893-
LogicalPlanBuilder::from(input).project(expr)?.build()
897+
self.validate_schema_satisfies_exprs(input.schema(), &plain_exprs)?;
898+
899+
// When inside a set expression, pass the left-most schema so
900+
// that expressions get aliased to match, avoiding duplicate
901+
// name errors from expressions like `count(*), count(*)`.
902+
let builder = LogicalPlanBuilder::from(input);
903+
if let Some(left_schema) = set_expr_left_schema {
904+
builder
905+
.project_with_validation_and_schema(expr, &left_schema)?
906+
.build()
907+
} else {
908+
builder.project(expr)?.build()
909+
}
894910
}
895911

896912
/// Create an aggregate plan.

datafusion/sql/src/set_expr.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1921
use datafusion_common::{
2022
DataFusionError, Diagnostic, Result, Span, not_impl_err, plan_err,
@@ -42,7 +44,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
4244
let left_span = Span::try_from_sqlparser_span(left.span());
4345
let right_span = Span::try_from_sqlparser_span(right.span());
4446
let left_plan = self.set_expr_to_plan(*left, planner_context);
47+
// Store the left plan's schema so that the right side can
48+
// alias duplicate expressions to match. Skip for BY NAME
49+
// operations since those match columns by name, not position.
50+
if let Ok(plan) = &left_plan
51+
&& plan.schema().fields().len() > 1
52+
&& !matches!(
53+
set_quantifier,
54+
SetQuantifier::ByName
55+
| SetQuantifier::AllByName
56+
| SetQuantifier::DistinctByName
57+
)
58+
{
59+
planner_context
60+
.set_set_expr_left_schema(Some(Arc::clone(plan.schema())));
61+
}
4562
let right_plan = self.set_expr_to_plan(*right, planner_context);
63+
planner_context.set_set_expr_left_schema(None);
4664
let (left_plan, right_plan) = match (left_plan, right_plan) {
4765
(Ok(left_plan), Ok(right_plan)) => (left_plan, right_plan),
4866
(Err(left_err), Err(right_err)) => {

datafusion/sql/tests/sql_integration.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2655,6 +2655,106 @@ fn union_all_by_name_same_column_names() {
26552655
);
26562656
}
26572657

2658+
#[test]
2659+
fn union_all_with_duplicate_expressions() {
2660+
let sql = "\
2661+
SELECT 0 a, 0 b \
2662+
UNION ALL SELECT 1, 1 \
2663+
UNION ALL SELECT count(*), count(*) FROM orders";
2664+
let plan = logical_plan(sql).unwrap();
2665+
assert_snapshot!(
2666+
plan,
2667+
@r"
2668+
Union
2669+
Union
2670+
Projection: Int64(0) AS a, Int64(0) AS b
2671+
EmptyRelation: rows=1
2672+
Projection: Int64(1) AS a, Int64(1) AS b
2673+
EmptyRelation: rows=1
2674+
Projection: count(*) AS a, count(*) AS b
2675+
Aggregate: groupBy=[[]], aggr=[[count(*)]]
2676+
TableScan: orders
2677+
"
2678+
);
2679+
}
2680+
2681+
#[test]
2682+
fn union_with_qualified_and_duplicate_expressions() {
2683+
let sql = "\
2684+
SELECT 0 a, id b, price c, 0 d FROM test_decimal \
2685+
UNION SELECT 1, *, 1 FROM test_decimal";
2686+
let plan = logical_plan(sql).unwrap();
2687+
assert_snapshot!(
2688+
plan,
2689+
@"
2690+
Distinct:
2691+
Union
2692+
Projection: Int64(0) AS a, test_decimal.id AS b, test_decimal.price AS c, Int64(0) AS d
2693+
TableScan: test_decimal
2694+
Projection: Int64(1) AS a, test_decimal.id, test_decimal.price, Int64(1) AS d
2695+
TableScan: test_decimal
2696+
"
2697+
);
2698+
}
2699+
2700+
#[test]
2701+
fn intersect_with_duplicate_expressions() {
2702+
let sql = "\
2703+
SELECT 0 a, 0 b \
2704+
INTERSECT SELECT 1, 1 \
2705+
INTERSECT SELECT count(*), count(*) FROM orders";
2706+
let plan = logical_plan(sql).unwrap();
2707+
assert_snapshot!(
2708+
plan,
2709+
@r"
2710+
LeftSemi Join: left.a = right.a, left.b = right.b
2711+
Distinct:
2712+
SubqueryAlias: left
2713+
LeftSemi Join: left.a = right.a, left.b = right.b
2714+
Distinct:
2715+
SubqueryAlias: left
2716+
Projection: Int64(0) AS a, Int64(0) AS b
2717+
EmptyRelation: rows=1
2718+
SubqueryAlias: right
2719+
Projection: Int64(1) AS a, Int64(1) AS b
2720+
EmptyRelation: rows=1
2721+
SubqueryAlias: right
2722+
Projection: count(*) AS a, count(*) AS b
2723+
Aggregate: groupBy=[[]], aggr=[[count(*)]]
2724+
TableScan: orders
2725+
"
2726+
);
2727+
}
2728+
2729+
#[test]
2730+
fn except_with_duplicate_expressions() {
2731+
let sql = "\
2732+
SELECT 0 a, 0 b \
2733+
EXCEPT SELECT 1, 1 \
2734+
EXCEPT SELECT count(*), count(*) FROM orders";
2735+
let plan = logical_plan(sql).unwrap();
2736+
assert_snapshot!(
2737+
plan,
2738+
@r"
2739+
LeftAnti Join: left.a = right.a, left.b = right.b
2740+
Distinct:
2741+
SubqueryAlias: left
2742+
LeftAnti Join: left.a = right.a, left.b = right.b
2743+
Distinct:
2744+
SubqueryAlias: left
2745+
Projection: Int64(0) AS a, Int64(0) AS b
2746+
EmptyRelation: rows=1
2747+
SubqueryAlias: right
2748+
Projection: Int64(1) AS a, Int64(1) AS b
2749+
EmptyRelation: rows=1
2750+
SubqueryAlias: right
2751+
Projection: count(*) AS a, count(*) AS b
2752+
Aggregate: groupBy=[[]], aggr=[[count(*)]]
2753+
TableScan: orders
2754+
"
2755+
);
2756+
}
2757+
26582758
#[test]
26592759
fn empty_over() {
26602760
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";

datafusion/sqllogictest/test_files/union.slt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,30 @@ Bob_new
256256
John
257257
John_new
258258

259+
# Test UNION ALL with unaliased duplicate literal values on the right side.
260+
# The second projection will inherit field names from the first one, and so
261+
# pass the unique projection expression name check.
262+
query TII rowsort
263+
SELECT name, 1 as table, 1 as row FROM t1 WHERE id = 1
264+
UNION ALL
265+
SELECT name, 2, 2 FROM t2 WHERE id = 2
266+
----
267+
Alex 1 1
268+
Bob 2 2
269+
270+
# Test nested UNION, EXCEPT, INTERSECT with duplicate unaliased literals.
271+
# Only the first SELECT has column aliases, which should propagate to all projections.
272+
query III rowsort
273+
SELECT 1 as a, 0 as b, 0 as c
274+
UNION ALL
275+
((SELECT 2, 0, 0 UNION ALL SELECT 3, 0, 0) EXCEPT SELECT 3, 0, 0)
276+
UNION ALL
277+
(SELECT 4, 0, 0 INTERSECT SELECT 4, 0, 0)
278+
----
279+
1 0 0
280+
2 0 0
281+
4 0 0
282+
259283
# Plan is unnested
260284
query TT
261285
EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2)

0 commit comments

Comments
 (0)