Skip to content

Commit cc6db9c

Browse files
authored
Support projection in subquery with IN clause (#35)
1 parent 181a132 commit cc6db9c

2 files changed

Lines changed: 122 additions & 40 deletions

File tree

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 93 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030
use datafusion_common::{internal_err, plan_err, Column, Result};
3131
use datafusion_expr::expr::{Exists, InSubquery};
3232
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33-
use datafusion_expr::logical_plan::{JoinType, Subquery};
33+
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
3434
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
3535
use datafusion_expr::{
3636
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
@@ -66,54 +66,82 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6666
})?
6767
.data;
6868

69-
let LogicalPlan::Filter(filter) = plan else {
70-
return Ok(Transformed::no(plan));
71-
};
72-
73-
if !has_subquery(&filter.predicate) {
74-
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
75-
}
69+
match plan {
70+
LogicalPlan::Filter(filter) => {
71+
if !has_subquery(&filter.predicate) {
72+
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
73+
}
7674

77-
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
78-
split_conjunction_owned(filter.predicate)
79-
.into_iter()
80-
.partition(has_subquery);
75+
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
76+
split_conjunction_owned(filter.predicate)
77+
.into_iter()
78+
.partition(has_subquery);
8179

82-
if with_subqueries.is_empty() {
83-
return internal_err!(
84-
"can not find expected subqueries in DecorrelatePredicateSubquery"
85-
);
86-
}
80+
if with_subqueries.is_empty() {
81+
return internal_err!(
82+
"can not find expected subqueries in DecorrelatePredicateSubquery"
83+
);
84+
}
8785

88-
// iterate through all exists clauses in predicate, turning each into a join
89-
let mut cur_input = Arc::unwrap_or_clone(filter.input);
90-
for subquery_expr in with_subqueries {
91-
match extract_subquery_info(subquery_expr) {
92-
// The subquery expression is at the top level of the filter
93-
SubqueryPredicate::Top(subquery) => {
94-
match build_join_top(&subquery, &cur_input, config.alias_generator())?
95-
{
96-
Some(plan) => cur_input = plan,
97-
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
98-
None => other_exprs.push(subquery.expr()),
86+
// iterate through all subquery predicates (IN / EXISTS) in the filter, turning each into a join
87+
let mut cur_input = Arc::unwrap_or_clone(filter.input);
88+
for subquery_expr in with_subqueries {
89+
match extract_subquery_info(subquery_expr) {
90+
// The subquery expression is at the top level of the filter
91+
SubqueryPredicate::Top(subquery) => {
92+
match build_join_top(
93+
&subquery,
94+
&cur_input,
95+
config.alias_generator(),
96+
)? {
97+
Some(plan) => cur_input = plan,
98+
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99+
None => other_exprs.push(subquery.expr()),
100+
}
101+
}
102+
// The subquery expression is embedded within another expression
103+
SubqueryPredicate::Embedded(expr) => {
104+
let (plan, expr_without_subqueries) =
105+
rewrite_inner_subqueries(cur_input, expr, config)?;
106+
cur_input = plan;
107+
other_exprs.push(expr_without_subqueries);
108+
}
99109
}
100110
}
101-
// The subquery expression is embedded within another expression
102-
SubqueryPredicate::Embedded(expr) => {
103-
let (plan, expr_without_subqueries) =
104-
rewrite_inner_subqueries(cur_input, expr, config)?;
105-
cur_input = plan;
106-
other_exprs.push(expr_without_subqueries);
111+
112+
let expr = conjunction(other_exprs);
113+
let mut new_plan = cur_input;
114+
if let Some(expr) = expr {
115+
let new_filter = Filter::try_new(expr, Arc::new(new_plan))?;
116+
new_plan = LogicalPlan::Filter(new_filter);
107117
}
118+
Ok(Transformed::yes(new_plan))
108119
}
109-
}
120+
LogicalPlan::Projection(proj) => {
121+
// Only proceed if any projection expression contains a subquery
122+
if !proj.expr.iter().any(has_subquery) {
123+
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
124+
}
110125

111-
let expr = conjunction(other_exprs);
112-
if let Some(expr) = expr {
113-
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
114-
cur_input = LogicalPlan::Filter(new_filter);
126+
let mut cur_input = Arc::unwrap_or_clone(proj.input);
127+
let mut new_exprs = Vec::with_capacity(proj.expr.len());
128+
for e in proj.expr {
129+
let old_name = e.schema_name().to_string();
130+
let (plan_after, rewritten) =
131+
rewrite_inner_subqueries(cur_input, e, config)?;
132+
cur_input = plan_after;
133+
let new_name = rewritten.schema_name().to_string();
134+
if new_name != old_name {
135+
new_exprs.push(rewritten.alias(old_name));
136+
} else {
137+
new_exprs.push(rewritten);
138+
}
139+
}
140+
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
141+
Ok(Transformed::yes(LogicalPlan::Projection(new_proj)))
142+
}
143+
other => Ok(Transformed::no(other)),
115144
}
116-
Ok(Transformed::yes(cur_input))
117145
}
118146

119147
fn name(&self) -> &str {
@@ -529,6 +557,31 @@ mod tests {
529557
assert_optimized_plan_equal(plan, expected)
530558
}
531559

560+
/// An `IN (subquery)` expression inside a Projection should be decorrelated into a LeftMark join beneath that Projection
561+
#[test]
562+
fn projection_in_subquery_simple() -> Result<()> {
563+
// Build outer values t(a) = (1),(2)
564+
let outer = LogicalPlanBuilder::values(vec![vec![lit(1_i32)], vec![lit(2_i32)]])?
565+
.project(vec![col("column1").alias("a")])?
566+
.build()?;
567+
568+
// Build subquery u(ua) = (2)
569+
let sub = Arc::new(
570+
LogicalPlanBuilder::values(vec![vec![lit(2_i32)]])?
571+
.project(vec![col("column1").alias("ua")])?
572+
.build()?,
573+
);
574+
575+
let plan = LogicalPlanBuilder::from(outer)
576+
.project(vec![col("a"), in_subquery(col("a"), sub).alias("flag")])?
577+
.build()?;
578+
579+
// We expect a LeftMark join to be inserted below the projection, and the projection to keep its output columns (`a`, `flag`)
580+
let expected = "Projection: a, __correlated_sq_1.mark AS flag [a:Int32;N, flag:Boolean]\n LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]\n Projection: column1 AS a [a:Int32;N]\n Values: (Int32(1)), (Int32(2)) [column1:Int32;N]\n SubqueryAlias: __correlated_sq_1 [ua:Int32;N]\n Projection: column1 AS ua [ua:Int32;N]\n Values: (Int32(2)) [column1:Int32;N]";
581+
582+
assert_optimized_plan_equal(plan, expected)
583+
}
584+
532585
/// Test multiple correlated subqueries
533586
/// See subqueries.rs where_in_multiple()
534587
#[test]

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,3 +1482,32 @@ logical_plan
14821482

14831483
statement count 0
14841484
drop table person;
1485+
1486+
1487+
# Projection IN (subquery) decorrelation
1488+
query IB rowsort
1489+
WITH t(a) AS (VALUES (1),(2),(3),(4),(5)),
1490+
u(a) AS (VALUES (2),(4),(6))
1491+
SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t;
1492+
----
1493+
1 false
1494+
2 true
1495+
3 false
1496+
4 true
1497+
5 false
1498+
1499+
query TT
1500+
EXPLAIN WITH t(a) AS (VALUES (1),(2),(3),(4),(5)),
1501+
u(a) AS (VALUES (2),(4),(6))
1502+
SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t;
1503+
----
1504+
logical_plan
1505+
01)Projection: t.a, __correlated_sq_1.mark AS flag
1506+
02)--LeftMark Join: t.a = __correlated_sq_1.a
1507+
03)----SubqueryAlias: t
1508+
04)------Projection: column1 AS a
1509+
05)--------Values: (Int64(1)), (Int64(2)), (Int64(3)), (Int64(4)), (Int64(5))
1510+
06)----SubqueryAlias: __correlated_sq_1
1511+
07)------SubqueryAlias: u
1512+
08)--------Projection: column1 AS a
1513+
09)----------Values: (Int64(2)), (Int64(4)), (Int64(6))

0 commit comments

Comments
 (0)