Skip to content

Commit 6562ea1

Browse files
sundy-li, Kould, forsaken628
authored
fix(sql): handle correlated subqueries over union (#19607)
* fix(sql): preserve correlated union keys (#19574) * fix(sql): add missing optimizer physical goldens (#19607) * refine --------- Co-authored-by: kould <kould2333@gmail.com> Co-authored-by: coldWater <forsaken628@gmail.com>
1 parent b72171a commit 6562ea1

File tree

12 files changed

+506
-39
lines changed

12 files changed

+506
-39
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2021 Datafuse Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use databend_common_exception::Result;
16+
use databend_common_expression::DataBlock;
17+
use databend_query::interpreters::InterpreterFactory;
18+
use databend_query::sql::Planner;
19+
use databend_query::test_kits::TestFixture;
20+
use futures_util::TryStreamExt;
21+
22+
async fn execute_query_rows(sql: &str) -> Result<usize> {
23+
let fixture = TestFixture::setup().await?;
24+
let ctx = fixture.new_query_ctx().await?;
25+
26+
let mut planner = Planner::new(ctx.clone());
27+
let (plan, _) = planner.plan_sql(sql).await?;
28+
let interpreter = InterpreterFactory::get(ctx.clone(), &plan).await?;
29+
let stream = interpreter.execute(ctx).await?;
30+
let blocks: Vec<DataBlock> = stream.try_collect().await?;
31+
Ok(DataBlock::concat(&blocks)?.num_rows())
32+
}
33+
34+
#[tokio::test(flavor = "multi_thread")]
35+
async fn correlated_exists_subquery_over_union_regression() -> anyhow::Result<()> {
36+
let sql = r"
37+
SELECT *
38+
FROM (VALUES (1)) t(f1)
39+
WHERE EXISTS (
40+
SELECT 1
41+
UNION
42+
SELECT 2 WHERE f1 = 1
43+
);
44+
";
45+
46+
let rows = execute_query_rows(sql).await?;
47+
assert_eq!(rows, 1);
48+
49+
Ok(())
50+
}
51+
52+
#[tokio::test(flavor = "multi_thread")]
53+
async fn correlated_exists_subquery_over_union_all_regression() -> anyhow::Result<()> {
54+
let sql = r"
55+
SELECT *
56+
FROM (VALUES (1)) t(f1)
57+
WHERE EXISTS (
58+
SELECT 1
59+
UNION ALL
60+
SELECT 2 WHERE f1 = 1
61+
);
62+
";
63+
64+
let rows = execute_query_rows(sql).await?;
65+
assert_eq!(rows, 1);
66+
67+
Ok(())
68+
}

src/query/service/tests/it/sql/exec/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ pub async fn test_snapshot_consistency() -> anyhow::Result<()> {
167167
Ok(())
168168
}
169169

170+
mod correlated_subquery_regression;
170171
mod get_table_bind_test;
171172
mod range_join;
172173
mod window;

src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_plan.rs

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -855,59 +855,42 @@ impl SubqueryDecorrelatorOptimizer {
855855
need_cross_join = true;
856856
}
857857

858+
let rel_expr = RelExpr::with_s_expr(subquery);
859+
let left_prop = rel_expr.derive_relational_prop_child(0)?;
860+
let right_prop = rel_expr.derive_relational_prop_child(1)?;
861+
let left_need_cross_join =
862+
need_cross_join || !correlated_columns.is_subset(&left_prop.outer_columns);
863+
let right_need_cross_join =
864+
need_cross_join || !correlated_columns.is_subset(&right_prop.outer_columns);
865+
858866
let mut union_all = union_all.clone();
859867
let left_flatten_plan = self.flatten_plan(
860868
outer,
861869
subquery.left_child(),
862870
correlated_columns,
863871
flatten_info,
864-
need_cross_join,
872+
left_need_cross_join,
873+
)?;
874+
let left_derived = std::mem::take(&mut self.derived_columns);
875+
Self::rewrite_union_branch_outputs(
876+
&mut union_all.left_outputs,
877+
correlated_columns,
878+
&left_derived,
865879
)?;
866-
867-
union_all.left_outputs = union_all
868-
.left_outputs
869-
.drain(..)
870-
.map(|(old, mut expr)| {
871-
let Some(&new) = self.derived_columns.get(&old) else {
872-
return Ok((old, expr));
873-
};
874-
if let Some(expr) = &mut expr {
875-
expr.replace_column(old, new)?;
876-
};
877-
Ok((new, expr))
878-
})
879-
.chain(correlated_columns.iter().copied().map(|old| {
880-
let new = *self.derived_columns.get(&old).unwrap();
881-
Ok((new, None))
882-
}))
883-
.collect::<Result<_>>()?;
884-
self.derived_columns.clear();
885880

886881
let right_flatten_plan = self.flatten_plan(
887882
outer,
888883
subquery.right_child(),
889884
correlated_columns,
890885
flatten_info,
891-
need_cross_join,
886+
right_need_cross_join,
887+
)?;
888+
let right_derived = std::mem::take(&mut self.derived_columns);
889+
Self::rewrite_union_branch_outputs(
890+
&mut union_all.right_outputs,
891+
correlated_columns,
892+
&right_derived,
892893
)?;
893-
union_all.right_outputs = union_all
894-
.right_outputs
895-
.drain(..)
896-
.map(|(old, mut expr)| {
897-
let Some(&new) = self.derived_columns.get(&old) else {
898-
return Ok((old, expr));
899-
};
900-
if let Some(expr) = &mut expr {
901-
expr.replace_column(old, new)?;
902-
};
903-
Ok((new, expr))
904-
})
905-
.chain(correlated_columns.iter().map(|old| {
906-
let new = *self.derived_columns.get(old).unwrap();
907-
Ok((new, None))
908-
}))
909-
.collect::<Result<_>>()?;
910-
self.derived_columns.clear();
911894

912895
let mut metadata = self.metadata.write();
913896
union_all
@@ -928,6 +911,33 @@ impl SubqueryDecorrelatorOptimizer {
928911
))
929912
}
930913

914+
fn rewrite_union_branch_outputs(
915+
branch_outputs: &mut Vec<(Symbol, Option<ScalarExpr>)>,
916+
correlated_columns: &ColumnSet,
917+
derived: &HashMap<Symbol, Symbol>,
918+
) -> Result<()> {
919+
*branch_outputs = branch_outputs
920+
.drain(..)
921+
.map(|(old, mut expr)| {
922+
let Some(&new) = derived.get(&old) else {
923+
return Ok((old, expr));
924+
};
925+
if let Some(expr) = &mut expr {
926+
expr.replace_column(old, new)?;
927+
};
928+
Ok((new, expr))
929+
})
930+
.chain(correlated_columns.iter().copied().map(|old| {
931+
let new = derived
932+
.get(&old)
933+
.copied()
934+
.ok_or_else(|| ErrorCode::Internal(format!("Missing derived column {old}")))?;
935+
Ok((new, None))
936+
}))
937+
.collect::<Result<_>>()?;
938+
Ok(())
939+
}
940+
931941
fn flatten_sub_expression_scan(
932942
&mut self,
933943
subquery: &SExpr,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: "19574_correlated_exists_union"
2+
description: "Correlated EXISTS over UNION should decorrelate without panicking"
3+
4+
sql: |
5+
SELECT *
6+
FROM (VALUES (1)) t(f1)
7+
WHERE EXISTS (
8+
SELECT 1
9+
UNION
10+
SELECT 2 WHERE f1 = 1
11+
);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: "19574_correlated_exists_union_all"
2+
description: "Correlated EXISTS over UNION ALL should decorrelate without panicking"
3+
4+
sql: |
5+
SELECT *
6+
FROM (VALUES (1)) t(f1)
7+
WHERE EXISTS (
8+
SELECT 1
9+
UNION ALL
10+
SELECT 2 WHERE f1 = 1
11+
);
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
EvalScalar
2+
├── scalars: [f1 (#0) AS (#0), marker (#7) AS (#8)]
3+
└── Filter
4+
├── filters: [is_true(marker (#7))]
5+
└── Join(LeftMark)
6+
├── build keys: [f1 (#0)]
7+
├── probe keys: [f1 (#6)]
8+
├── other filters: []
9+
├── Exchange(Merge)
10+
│ └── ConstantTableScan
11+
│ ├── columns: [f1 (#0)]
12+
│ └── num_rows: [1]
13+
└── UnionAll
14+
├── output: [1 (#3), f1 (#6)]
15+
├── left: [1 (#1), f1 (#4)]
16+
├── right: [2 (#2), f1 (#5)]
17+
├── cte_scan_names: []
18+
├── logical_recursive_cte_id: None
19+
├── Join(Cross)
20+
│ ├── build keys: []
21+
│ ├── probe keys: []
22+
│ ├── other filters: []
23+
│ ├── EvalScalar
24+
│ │ ├── scalars: [1 AS (#1)]
25+
│ │ └── DummyTableScan(DummyTableScan { source_table_indexes: [] })
26+
│ └── Exchange(Merge)
27+
│ └── Aggregate(Final)
28+
│ ├── group items: [f1 (#4) AS (#4)]
29+
│ ├── aggregate functions: []
30+
│ └── Aggregate(Partial)
31+
│ ├── group items: [f1 (#4) AS (#4)]
32+
│ ├── aggregate functions: []
33+
│ └── Exchange(Hash)
34+
│ ├── Exchange(Hash): keys: [f1 (#4)]
35+
│ └── ConstantTableScan
36+
│ ├── columns: [f1 (#4)]
37+
│ └── num_rows: [1]
38+
└── EvalScalar
39+
├── scalars: [2 AS (#2), f1 (#5) AS (#5)]
40+
└── Join(Cross)
41+
├── build keys: []
42+
├── probe keys: []
43+
├── other filters: []
44+
├── DummyTableScan(DummyTableScan { source_table_indexes: [] })
45+
└── Exchange(Merge)
46+
└── Aggregate(Final)
47+
├── group items: [f1 (#5) AS (#5)]
48+
├── aggregate functions: []
49+
└── Aggregate(Partial)
50+
├── group items: [f1 (#5) AS (#5)]
51+
├── aggregate functions: []
52+
└── Exchange(Hash)
53+
├── Exchange(Hash): keys: [f1 (#5)]
54+
└── Filter
55+
├── filters: [eq(f1 (#5), 1)]
56+
└── ConstantTableScan
57+
├── columns: [f1 (#5)]
58+
└── num_rows: [1]
59+
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
Filter
2+
├── output columns: [f1 (#0)]
3+
├── filters: [is_true(7 (#7))]
4+
├── estimated rows: 0.20
5+
└── HashJoin
6+
├── output columns: [f1 (#0), marker (#7)]
7+
├── join type: LEFT MARK
8+
├── build keys: [f1 (#0)]
9+
├── probe keys: [f1 (#6)]
10+
├── keys is null equal: [false]
11+
├── filters: []
12+
├── estimated rows: 1.00
13+
├── Exchange(Build)
14+
│ ├── output columns: [f1 (#0)]
15+
│ ├── exchange type: Merge
16+
│ └── ConstantTableScan
17+
│ ├── output columns: [f1 (#0)]
18+
│ └── column 0: [1]
19+
└── UnionAll(Probe)
20+
├── output columns: [f1 (#6)]
21+
├── estimated rows: 2.00
22+
├── HashJoin
23+
│ ├── output columns: [f1 (#4)]
24+
│ ├── join type: CROSS
25+
│ ├── build keys: []
26+
│ ├── probe keys: []
27+
│ ├── keys is null equal: []
28+
│ ├── filters: []
29+
│ ├── estimated rows: 1.00
30+
│ ├── DummyTableScan(Build)
31+
│ └── Exchange(Probe)
32+
│ ├── output columns: [f1 (#4)]
33+
│ ├── exchange type: Merge
34+
│ └── AggregateFinal
35+
│ ├── output columns: [f1 (#4)]
36+
│ ├── group by: [f1]
37+
│ ├── aggregate functions: []
38+
│ ├── estimated rows: 1.00
39+
│ └── Exchange
40+
│ ├── output columns: [f1 (#4)]
41+
│ ├── exchange type: Hash(0)
42+
│ └── AggregatePartial
43+
│ ├── group by: [f1]
44+
│ ├── aggregate functions: []
45+
│ ├── estimated rows: 1.00
46+
│ └── ConstantTableScan
47+
│ ├── output columns: [f1 (#4)]
48+
│ └── column 0: [1]
49+
└── HashJoin
50+
├── output columns: [f1 (#5)]
51+
├── join type: CROSS
52+
├── build keys: []
53+
├── probe keys: []
54+
├── keys is null equal: []
55+
├── filters: []
56+
├── estimated rows: 1.00
57+
├── DummyTableScan(Build)
58+
└── Exchange(Probe)
59+
├── output columns: [f1 (#5)]
60+
├── exchange type: Merge
61+
└── AggregateFinal
62+
├── output columns: [f1 (#5)]
63+
├── group by: [f1]
64+
├── aggregate functions: []
65+
├── estimated rows: 1.00
66+
└── Exchange
67+
├── output columns: [f1 (#5)]
68+
├── exchange type: Hash(0)
69+
└── AggregatePartial
70+
├── group by: [f1]
71+
├── aggregate functions: []
72+
├── estimated rows: 1.00
73+
└── Filter
74+
├── output columns: [f1 (#5)]
75+
├── filters: [outer.f1 (#5) = 1]
76+
├── estimated rows: 1.00
77+
└── ConstantTableScan
78+
├── output columns: [f1 (#5)]
79+
└── column 0: [1]
80+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
EvalScalar
2+
├── scalars: [f1 (#0) AS (#0)]
3+
└── Filter
4+
├── filters: [SUBQUERY AS (#3)]
5+
├── subquerys
6+
│ └── Subquery (Exists)
7+
│ ├── output_column: 1 (#3)
8+
│ └── UnionAll
9+
│ ├── output: [1 (#3)]
10+
│ ├── left: [1 (#1)]
11+
│ ├── right: [2 (#2)]
12+
│ ├── cte_scan_names: []
13+
│ ├── logical_recursive_cte_id: None
14+
│ ├── EvalScalar
15+
│ │ ├── scalars: [1 AS (#1)]
16+
│ │ └── DummyTableScan(DummyTableScan { source_table_indexes: [] })
17+
│ └── EvalScalar
18+
│ ├── scalars: [2 AS (#2)]
19+
│ └── Filter
20+
│ ├── filters: [eq(f1 (#0), 1)]
21+
│ └── DummyTableScan(DummyTableScan { source_table_indexes: [] })
22+
└── ConstantTableScan
23+
├── columns: [f1 (#0)]
24+
└── num_rows: [1]
25+

0 commit comments

Comments
 (0)