Skip to content

Commit 1409ec7

Browse files
committed
Implement correlated subqueries using the RBO (rule-based optimization) method
1 parent b710db4 commit 1409ec7

9 files changed

Lines changed: 395 additions & 9 deletions

File tree

src/binder/expr.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use sqlparser::ast::{
77
BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident,
88
Query, UnaryOperator, Value,
99
};
10-
use std::collections::HashMap;
10+
use std::collections::{HashMap, HashSet};
1111
use std::slice;
1212
use std::sync::Arc;
1313

@@ -293,6 +293,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
293293
self.args,
294294
Some(self),
295295
);
296+
296297
let mut sub_query = binder.bind_query(subquery)?;
297298
let sub_query_schema = sub_query.output_schema();
298299

@@ -368,7 +369,15 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
368369
try_default!(&full_name.0, full_name.1);
369370
}
370371
if let Some(table) = full_name.0.or(bind_table_name) {
371-
let source = self.context.bind_source(&table)?;
372+
let (source, is_parent) = self.context.bind_source::<A>(self.parent, &table, false)?;
373+
374+
if is_parent {
375+
self.parent_table_col
376+
.entry(Arc::new(table.clone()))
377+
.or_default()
378+
.insert(full_name.1.clone());
379+
}
380+
372381
let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default();
373382

374383
Ok(ScalarExpression::ColumnRef(

src/binder/mod.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
2626
use std::sync::Arc;
2727

2828
use crate::catalog::view::View;
29-
use crate::catalog::{ColumnRef, TableCatalog, TableName};
29+
use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName};
3030
use crate::db::{ScalaFunctions, TableFunctions};
3131
use crate::errors::DatabaseError;
3232
use crate::expression::ScalarExpression;
@@ -276,12 +276,19 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
276276
Ok(source)
277277
}
278278

279-
pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> {
279+
pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]>>(
280+
&self,
281+
parent: Option<&'a Binder<'a, 'b, T, A>>,
282+
table_name: &str,
283+
is_parent: bool,
284+
) -> Result<(&'b Source, bool), DatabaseError> {
280285
if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| {
281286
t.as_str() == table_name
282287
|| matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true))
283288
}) {
284-
Ok(source.1)
289+
Ok((source.1, is_parent))
290+
} else if let Some(binder) = parent {
291+
binder.context.bind_source(binder.parent, table_name, true)
285292
} else {
286293
Err(DatabaseError::InvalidTable(table_name.into()))
287294
}
@@ -323,6 +330,7 @@ pub struct Binder<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>>
323330
args: &'a A,
324331
with_pk: Option<TableName>,
325332
pub(crate) parent: Option<&'b Binder<'a, 'b, T, A>>,
333+
pub(crate) parent_table_col: HashMap<TableName, HashSet<String>>,
326334
}
327335

328336
impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> {
@@ -337,6 +345,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '
337345
args,
338346
with_pk: None,
339347
parent,
348+
parent_table_col: Default::default(),
340349
}
341350
}
342351

src/db.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ impl<S: Storage> State<S> {
173173

174174
pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer {
175175
HepOptimizer::new(source_plan)
176+
.batch(
177+
"Correlated Subquery".to_string(),
178+
HepBatchStrategy::once_topdown(),
179+
vec![NormalizationRuleImpl::CorrelateSubquery],
180+
)
176181
.batch(
177182
"Column Pruning".to_string(),
178183
HepBatchStrategy::once_topdown(),

src/optimizer/heuristic/graph.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl HepGraph {
7979
source_id: HepNodeId,
8080
children_option: Option<HepNodeId>,
8181
new_node: Operator,
82-
) {
82+
) -> HepNodeId {
8383
let new_index = self.graph.add_node(new_node);
8484
let mut order = self.graph.edges(source_id).count();
8585

@@ -95,6 +95,7 @@ impl HepGraph {
9595

9696
self.graph.add_edge(source_id, new_index, order);
9797
self.version += 1;
98+
new_index
9899
}
99100

100101
pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) {
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
use crate::catalog::{ColumnRef, TableName};
2+
use crate::errors::DatabaseError;
3+
use crate::expression::visitor::Visitor;
4+
use crate::expression::HasCountStar;
5+
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
6+
use crate::optimizer::core::rule::{MatchPattern, NormalizationRule};
7+
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
8+
use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType};
9+
use crate::planner::operator::table_scan::TableScanOperator;
10+
use crate::planner::operator::Operator;
11+
use crate::planner::operator::Operator::{Join, TableScan};
12+
use crate::types::index::IndexInfo;
13+
use crate::types::ColumnId;
14+
use itertools::Itertools;
15+
use std::collections::BTreeMap;
16+
use std::collections::{HashMap, HashSet};
17+
use std::sync::{Arc, LazyLock};
18+
19+
/// Pattern for this rule: match any `Join` operator. Children are not
/// constrained here (`PatternChildrenPredicate::None`) — the rule walks the
/// subtree itself inside `_apply`.
static CORRELATED_SUBQUERY_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern {
    predicate: |op| matches!(op, Join(_)),
    children: PatternChildrenPredicate::None,
});
23+
24+
/// Normalization rule that handles correlated subqueries: when a column is
/// referenced by an operator but produced by a table scan elsewhere in the
/// plan, the referenced table's scan is joined in (cross join) so the column
/// becomes locally available. See `_apply` for the traversal.
#[derive(Clone)]
pub struct CorrelatedSubquery;
26+
27+
/// Builds a `HashSet` of column references from any iterable of columns.
///
/// Pre-sizes the set from `$columns.len()` and inserts every element, so
/// duplicates collapse. Evaluates to the resulting set.
macro_rules! trans_references {
    ($columns:expr) => {{
        let mut reference_set = HashSet::with_capacity($columns.len());
        for item in $columns {
            reference_set.insert(item);
        }
        reference_set
    }};
}
36+
37+
impl CorrelatedSubquery {
    /// Walks the plan rooted at `node_id`, collecting per-table scan metadata
    /// and detecting columns that an operator references but whose producing
    /// table scan lives elsewhere in the plan (correlated references).
    ///
    /// When such a reference is detected at a `TableScan` node, a new scan of
    /// the referenced parent table is attached via a `Cross` join so the
    /// correlated column becomes available locally.
    ///
    /// * `column_references` — columns referenced by ancestor operators seen so far.
    /// * `scan_columns` — for every table scan already visited: its primary keys,
    ///   a column-id → scan-position map, and its index infos.
    ///
    /// Returns `scan_columns` extended with any table scans found in this subtree.
    fn _apply(
        column_references: HashSet<&ColumnRef>,
        scan_columns: HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        node_id: HepNodeId,
        graph: &mut HepGraph,
    ) -> Result<
        HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        DatabaseError,
    > {
        // Clone the operator out of the graph: the match arms below mutate
        // `graph`, so a borrow of the node cannot be held across them.
        let operator = &graph.operator(node_id).clone();

        match operator {
            Operator::Aggregate(op) => {
                let is_distinct = op.is_distinct;
                let referenced_columns = operator.referenced_columns(false);
                let mut new_column_references = trans_references!(&referenced_columns);
                // On DISTINCT, the incoming references must be preserved in
                // addition to this aggregate's own referenced columns.
                if is_distinct {
                    for summary in column_references {
                        new_column_references.insert(summary);
                    }
                }

                Self::recollect_apply(new_column_references, scan_columns, node_id, graph)
            }
            Operator::Project(op) => {
                // NOTE(review): `has_count_star` is populated but its result is
                // never read — this loop currently only propagates visitor
                // errors via `?`. Confirm whether COUNT(*) handling is pending.
                let mut has_count_star = HasCountStar::default();
                for expr in &op.exprs {
                    has_count_star.visit(expr)?;
                }
                // A projection resets the reference set to the columns its own
                // expressions use; ancestor references are intentionally dropped.
                let referenced_columns = operator.referenced_columns(false);
                let new_column_references = trans_references!(&referenced_columns);

                Self::recollect_apply(new_column_references, scan_columns, node_id, graph)
            }
            Operator::TableScan(op) => {
                let table_column: HashSet<&ColumnRef> = op.columns.values().collect();
                // Record this scan's metadata so descendants (and siblings via
                // the Join arm) can resolve references against it.
                let mut new_scan_columns = scan_columns.clone();
                new_scan_columns.insert(
                    op.table_name.clone(),
                    (
                        op.primary_keys.clone(),
                        op.columns
                            .iter()
                            .map(|(num, col)| (col.id().unwrap(), *num))
                            .collect(),
                        op.index_infos.clone(),
                    ),
                );
                // Group, per parent table, the referenced columns that this
                // scan does NOT produce but some previously-seen scan does.
                let mut parent_col = HashMap::new();
                for col in column_references {
                    match (
                        table_column.contains(col),
                        // NOTE(review): the `unwrap_or` fallback allocates a
                        // fresh `Arc<String>` eagerly even when the column has
                        // a table name; the "" key is only a never-matching
                        // placeholder for columns without a table.
                        scan_columns.get(col.table_name().unwrap_or(&Arc::new("".to_string()))),
                    ) {
                        (false, Some(..)) => {
                            parent_col
                                // unwrap: this arm is only reached when the
                                // lookup above hit a real table entry, which
                                // presumes the column carried a table name.
                                .entry(col.table_name().unwrap())
                                .or_insert(HashSet::new())
                                .insert(col);
                        }
                        _ => continue,
                    }
                }
                for (table_name, table_columns) in parent_col {
                    let table_columns = table_columns.into_iter().collect_vec();
                    let (primary_keys, columns, index_infos) =
                        scan_columns.get(table_name).unwrap();
                    // Rebuild the parent scan's column map restricted to the
                    // columns actually referenced, keyed by scan position.
                    let map: BTreeMap<usize, ColumnRef> = table_columns
                        .into_iter()
                        .map(|col| (*columns.get(&col.id().unwrap()).unwrap(), col.clone()))
                        .collect();
                    let left_operator = graph.operator(node_id).clone();
                    let right_operator = TableScan(TableScanOperator {
                        table_name: table_name.clone(),
                        primary_keys: primary_keys.clone(),
                        columns: map,
                        limit: (None, None),
                        index_infos: index_infos.clone(),
                        with_pk: false,
                    });
                    // Cross join with no condition: predicates connecting the
                    // two sides are expected to be handled by later rules.
                    let join_operator = Join(JoinOperator {
                        on: JoinCondition::None,
                        join_type: JoinType::Cross,
                    });

                    match &left_operator {
                        TableScan(_) => {
                            // First parent table: turn this scan node into a
                            // join of (original scan) x (parent scan).
                            graph.replace_node(node_id, join_operator);
                            graph.add_node(node_id, None, left_operator);
                            graph.add_node(node_id, None, right_operator);
                        }
                        Join(_) => {
                            // Node was already rewritten by a previous loop
                            // iteration: chain another join above the eldest child.
                            let left_id = graph.eldest_child_at(node_id).unwrap();
                            let left_id = graph.add_node(node_id, Some(left_id), join_operator);
                            graph.add_node(left_id, None, right_operator);
                        }
                        _ => unreachable!(),
                    }
                }
                Ok(new_scan_columns)
            }
            Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => {
                let mut new_scan_columns = scan_columns.clone();
                let temp_columns = operator.referenced_columns(false);
                // Pass-through operators add their own referenced columns to
                // the incoming set before recursing, so descendant scans see
                // every column referenced above them.
                // (Original comment here was "// why?" — presumably questioning
                // the re-binding; the `mut` re-binding just allows insertion.)
                let mut column_references = column_references;
                for column in temp_columns.iter() {
                    column_references.insert(column);
                }
                for child_id in graph.children_at(node_id).collect_vec() {
                    let copy_references = column_references.clone();
                    let copy_scan = scan_columns.clone();
                    // NOTE(review): child errors are silently discarded here
                    // (`if let Ok`), unlike the `?` propagation elsewhere —
                    // confirm this best-effort behavior is intended.
                    if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                        new_scan_columns.extend(scan);
                    };
                }
                Ok(new_scan_columns)
            }
            Operator::Join(_) => {
                let mut new_scan_columns = scan_columns.clone();
                for child_id in graph.children_at(node_id).collect_vec() {
                    let copy_references = column_references.clone();
                    // Unlike the pass-through arm above, each child sees the
                    // scans accumulated by earlier siblings (left-to-right).
                    let copy_scan = new_scan_columns.clone();
                    if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                        new_scan_columns.extend(scan);
                    };
                }
                Ok(new_scan_columns)
            }
            // Leaf operators: nothing to collect below them.
            Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(scan_columns),
            Operator::Explain => {
                // Explain always wraps exactly one child plan.
                if let Some(child_id) = graph.eldest_child_at(node_id) {
                    Self::_apply(column_references, scan_columns, child_id, graph)
                } else {
                    unreachable!()
                }
            }
            // DML operators that wrap another plan: restart the reference set
            // from their own referenced columns and recurse into the child.
            Operator::Insert(_)
            | Operator::Update(_)
            | Operator::Delete(_)
            | Operator::Analyze(_) => {
                let referenced_columns = operator.referenced_columns(false);
                let new_column_references = trans_references!(&referenced_columns);

                if let Some(child_id) = graph.eldest_child_at(node_id) {
                    Self::recollect_apply(new_column_references, scan_columns, child_id, graph)
                } else {
                    unreachable!();
                }
            }
            // DDL operators with no child plan: nothing to do.
            Operator::CreateTable(_)
            | Operator::CreateIndex(_)
            | Operator::CreateView(_)
            | Operator::DropTable(_)
            | Operator::DropView(_)
            | Operator::DropIndex(_)
            | Operator::Truncate(_)
            | Operator::ShowTable
            | Operator::ShowView
            | Operator::CopyFromFile(_)
            | Operator::CopyToFile(_)
            | Operator::AddColumn(_)
            | Operator::DropColumn(_)
            | Operator::Describe(_) => Ok(scan_columns),
        }
    }

    /// Runs `_apply` on every child of `node_id`, giving each child a fresh
    /// copy of `referenced_columns` and of the incoming `scan_columns`, and
    /// merges the scan metadata every child reports into one map.
    ///
    /// NOTE(review): per-child errors are swallowed (`if let Ok`) — only the
    /// successful children contribute to the result.
    fn recollect_apply(
        referenced_columns: HashSet<&ColumnRef>,
        scan_columns: HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        node_id: HepNodeId,
        graph: &mut HepGraph,
    ) -> Result<
        HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        DatabaseError,
    > {
        let mut new_scan_columns = scan_columns.clone();
        for child_id in graph.children_at(node_id).collect_vec() {
            let copy_references = referenced_columns.clone();
            let copy_scan = scan_columns.clone();

            if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                new_scan_columns.extend(scan);
            };
        }
        Ok(new_scan_columns)
    }
}
230+
231+
/// Hooks the rule into the optimizer's pattern matcher: fire on `Join` nodes
/// (see `CORRELATED_SUBQUERY_RULE` above).
impl MatchPattern for CorrelatedSubquery {
    fn pattern(&self) -> &Pattern {
        &CORRELATED_SUBQUERY_RULE
    }
}
236+
237+
impl NormalizationRule for CorrelatedSubquery {
    /// Entry point invoked by the heuristic optimizer on each matched node.
    ///
    /// Starts the traversal with empty reference and scan-column sets;
    /// `_apply` recurses from `node_id` downwards and performs any rewrites.
    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
        Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?;
        // mark changed to skip this rule batch
        graph.version += 1;

        Ok(())
    }
}

src/optimizer/rule/normalization/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::optimizer::rule::normalization::combine_operators::{
1010
use crate::optimizer::rule::normalization::compilation_in_advance::{
1111
EvaluatorBind, ExpressionRemapper,
1212
};
13+
use crate::optimizer::rule::normalization::correlated_subquery::CorrelatedSubquery;
1314
use crate::optimizer::rule::normalization::pushdown_limit::{
1415
LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
1516
};
@@ -21,6 +22,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter;
2122
mod column_pruning;
2223
mod combine_operators;
2324
mod compilation_in_advance;
25+
mod correlated_subquery;
2426
mod pushdown_limit;
2527
mod pushdown_predicates;
2628
mod simplification;
@@ -32,6 +34,7 @@ pub enum NormalizationRuleImpl {
3234
CollapseProject,
3335
CollapseGroupByAgg,
3436
CombineFilter,
37+
CorrelateSubquery,
3538
// PushDown limit
3639
LimitProjectTranspose,
3740
PushLimitThroughJoin,
@@ -55,6 +58,7 @@ impl MatchPattern for NormalizationRuleImpl {
5558
NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(),
5659
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(),
5760
NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(),
61+
NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.pattern(),
5862
NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(),
5963
NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(),
6064
NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(),
@@ -75,6 +79,7 @@ impl NormalizationRule for NormalizationRuleImpl {
7579
NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph),
7680
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph),
7781
NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph),
82+
NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.apply(node_id, graph),
7883
NormalizationRuleImpl::LimitProjectTranspose => {
7984
LimitProjectTranspose.apply(node_id, graph)
8085
}

0 commit comments

Comments
 (0)