|
| 1 | +use crate::catalog::{ColumnRef, TableName}; |
| 2 | +use crate::errors::DatabaseError; |
| 3 | +use crate::expression::visitor::Visitor; |
| 4 | +use crate::expression::HasCountStar; |
| 5 | +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; |
| 6 | +use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; |
| 7 | +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; |
| 8 | +use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; |
| 9 | +use crate::planner::operator::table_scan::TableScanOperator; |
| 10 | +use crate::planner::operator::Operator; |
| 11 | +use crate::planner::operator::Operator::{Join, TableScan}; |
| 12 | +use crate::types::index::IndexInfo; |
| 13 | +use crate::types::ColumnId; |
| 14 | +use itertools::Itertools; |
| 15 | +use std::collections::BTreeMap; |
| 16 | +use std::collections::{HashMap, HashSet}; |
| 17 | +use std::sync::{Arc, LazyLock}; |
| 18 | + |
| 19 | +static CORRELATED_SUBQUERY_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { |
| 20 | + predicate: |op| matches!(op, Join(_)), |
| 21 | + children: PatternChildrenPredicate::None, |
| 22 | +}); |
| 23 | + |
/// Normalization rule that resolves correlated-subquery column references by
/// splicing a cross join with a fresh scan of the outer table into the plan,
/// so the referenced outer columns become locally available.
#[derive(Clone)]
pub struct CorrelatedSubquery;
| 26 | + |
| 27 | +macro_rules! trans_references { |
| 28 | + ($columns:expr) => {{ |
| 29 | + let mut column_references = HashSet::with_capacity($columns.len()); |
| 30 | + for column in $columns { |
| 31 | + column_references.insert(column); |
| 32 | + } |
| 33 | + column_references |
| 34 | + }}; |
| 35 | +} |
| 36 | + |
impl CorrelatedSubquery {
    /// Walks the plan subtree rooted at `node_id`, tracking the metadata of
    /// every `TableScan` encountered so far (`scan_columns`) and the set of
    /// columns referenced by ancestor operators (`column_references`).
    ///
    /// When a `TableScan` is reached that does not itself produce a referenced
    /// column, but a previously seen (outer) scan does, the correlated
    /// reference is resolved by splicing a cross join with a fresh scan of
    /// that outer table into the graph.
    ///
    /// Returns the accumulated scan metadata:
    /// table name -> (primary keys, column id -> column position, index infos).
    fn _apply(
        column_references: HashSet<&ColumnRef>,
        scan_columns: HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        node_id: HepNodeId,
        graph: &mut HepGraph,
    ) -> Result<
        HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        DatabaseError,
    > {
        // Cloned so the graph can be mutated below while we still hold the
        // operator being matched on.
        let operator = &graph.operator(node_id).clone();

        match operator {
            Operator::Aggregate(op) => {
                let is_distinct = op.is_distinct;
                let referenced_columns = operator.referenced_columns(false);
                let mut new_column_references = trans_references!(&referenced_columns);
                // on distinct: additionally keep the parent's references
                // alive through the aggregation.
                if is_distinct {
                    for summary in column_references {
                        new_column_references.insert(summary);
                    }
                }

                Self::recollect_apply(new_column_references, scan_columns, node_id, graph)
            }
            Operator::Project(op) => {
                // NOTE(review): `has_count_star` is traversed only for the
                // error side effect of `visit`; the resulting flag is never
                // read afterwards — confirm whether that is intentional.
                let mut has_count_star = HasCountStar::default();
                for expr in &op.exprs {
                    has_count_star.visit(expr)?;
                }
                let referenced_columns = operator.referenced_columns(false);
                let new_column_references = trans_references!(&referenced_columns);

                Self::recollect_apply(new_column_references, scan_columns, node_id, graph)
            }
            Operator::TableScan(op) => {
                // Columns produced by this scan itself.
                let table_column: HashSet<&ColumnRef> = op.columns.values().collect();
                // Record this scan's metadata so sibling/deeper subtrees can
                // resolve correlated references against it.
                let mut new_scan_columns = scan_columns.clone();
                new_scan_columns.insert(
                    op.table_name.clone(),
                    (
                        op.primary_keys.clone(),
                        op.columns
                            .iter()
                            .map(|(num, col)| (col.id().unwrap(), *num))
                            .collect(),
                        op.index_infos.clone(),
                    ),
                );
                // Group the referenced columns that this scan does NOT provide
                // by the outer table that does provide them.
                let mut parent_col = HashMap::new();
                for col in column_references {
                    match (
                        table_column.contains(col),
                        // A column without a table name falls back to a dummy
                        // empty name for the lookup.
                        scan_columns.get(col.table_name().unwrap_or(&Arc::new("".to_string()))),
                    ) {
                        (false, Some(..)) => {
                            parent_col
                                .entry(col.table_name().unwrap())
                                .or_insert(HashSet::new())
                                .insert(col);
                        }
                        _ => continue,
                    }
                }
                // For each outer table, splice `scan CROSS JOIN outer_scan`
                // into the graph so the correlated columns become available.
                for (table_name, table_columns) in parent_col {
                    let table_columns = table_columns.into_iter().collect_vec();
                    let (primary_keys, columns, index_infos) =
                        scan_columns.get(table_name).unwrap();
                    // Rebuild the outer scan's projection restricted to the
                    // referenced columns, keyed by their original positions.
                    let map: BTreeMap<usize, ColumnRef> = table_columns
                        .into_iter()
                        .map(|col| (*columns.get(&col.id().unwrap()).unwrap(), col.clone()))
                        .collect();
                    let left_operator = graph.operator(node_id).clone();
                    let right_operator = TableScan(TableScanOperator {
                        table_name: table_name.clone(),
                        primary_keys: primary_keys.clone(),
                        columns: map,
                        limit: (None, None),
                        index_infos: index_infos.clone(),
                        with_pk: false,
                    });
                    let join_operator = Join(JoinOperator {
                        on: JoinCondition::None,
                        join_type: JoinType::Cross,
                    });

                    match &left_operator {
                        // First outer table: replace this scan node with a
                        // cross join over the old scan and the outer scan.
                        TableScan(_) => {
                            graph.replace_node(node_id, join_operator);
                            graph.add_node(node_id, None, left_operator);
                            graph.add_node(node_id, None, right_operator);
                        }
                        // Subsequent outer tables: `node_id` is already a join
                        // from a previous iteration; chain another cross join
                        // in front of its eldest child.
                        Join(_) => {
                            let left_id = graph.eldest_child_at(node_id).unwrap();
                            let left_id = graph.add_node(node_id, Some(left_id), join_operator);
                            graph.add_node(left_id, None, right_operator);
                        }
                        _ => unreachable!(),
                    }
                }
                Ok(new_scan_columns)
            }
            Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => {
                let mut new_scan_columns = scan_columns.clone();
                let temp_columns = operator.referenced_columns(false);
                // NOTE(review): unlike Project/Aggregate, the parent's
                // references are merged with (not replaced by) this operator's
                // own — presumably so correlated columns survive pass-through
                // operators; confirm the intent (original comment: "why?").
                let mut column_references = column_references;
                for column in temp_columns.iter() {
                    column_references.insert(column);
                }
                for child_id in graph.children_at(node_id).collect_vec() {
                    let copy_references = column_references.clone();
                    let copy_scan = scan_columns.clone();
                    // NOTE(review): a failing child contributes nothing — the
                    // error is silently swallowed here.
                    if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                        new_scan_columns.extend(scan);
                    };
                }
                Ok(new_scan_columns)
            }
            Operator::Join(_) => {
                let mut new_scan_columns = scan_columns.clone();
                for child_id in graph.children_at(node_id).collect_vec() {
                    let copy_references = column_references.clone();
                    // Note: later children see the scans gathered from earlier
                    // ones (`new_scan_columns`), unlike the arm above which
                    // passes the unmodified `scan_columns`.
                    let copy_scan = new_scan_columns.clone();
                    if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                        new_scan_columns.extend(scan);
                    };
                }
                Ok(new_scan_columns)
            }
            // Last Operator: leaves that neither scan tables nor have children.
            Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(scan_columns),
            Operator::Explain => {
                // Explain wraps exactly one child plan.
                if let Some(child_id) = graph.eldest_child_at(node_id) {
                    Self::_apply(column_references, scan_columns, child_id, graph)
                } else {
                    unreachable!()
                }
            }
            // DDL Based on Other Plan: recurse into the single source plan
            // with this operator's own references.
            Operator::Insert(_)
            | Operator::Update(_)
            | Operator::Delete(_)
            | Operator::Analyze(_) => {
                let referenced_columns = operator.referenced_columns(false);
                let new_column_references = trans_references!(&referenced_columns);

                if let Some(child_id) = graph.eldest_child_at(node_id) {
                    Self::recollect_apply(new_column_references, scan_columns, child_id, graph)
                } else {
                    unreachable!();
                }
            }
            // DDL Single Plan: no children and no columns to track.
            Operator::CreateTable(_)
            | Operator::CreateIndex(_)
            | Operator::CreateView(_)
            | Operator::DropTable(_)
            | Operator::DropView(_)
            | Operator::DropIndex(_)
            | Operator::Truncate(_)
            | Operator::ShowTable
            | Operator::ShowView
            | Operator::CopyFromFile(_)
            | Operator::CopyToFile(_)
            | Operator::AddColumn(_)
            | Operator::DropColumn(_)
            | Operator::Describe(_) => Ok(scan_columns),
        }
    }

    /// Runs [`Self::_apply`] on every child of `node_id` with the given
    /// reference set, merging each child's reported scan metadata into a
    /// single map.
    fn recollect_apply(
        referenced_columns: HashSet<&ColumnRef>,
        scan_columns: HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        node_id: HepNodeId,
        graph: &mut HepGraph,
    ) -> Result<
        HashMap<TableName, (Vec<ColumnId>, HashMap<ColumnId, usize>, Vec<IndexInfo>)>,
        DatabaseError,
    > {
        let mut new_scan_columns = scan_columns.clone();
        for child_id in graph.children_at(node_id).collect_vec() {
            let copy_references = referenced_columns.clone();
            let copy_scan = scan_columns.clone();

            // NOTE(review): child errors are swallowed; only successful
            // subtrees contribute scan metadata.
            if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) {
                new_scan_columns.extend(scan);
            };
        }
        Ok(new_scan_columns)
    }
}
| 230 | + |
impl MatchPattern for CorrelatedSubquery {
    /// The rule anchors on `Join` nodes; see `CORRELATED_SUBQUERY_RULE`.
    fn pattern(&self) -> &Pattern {
        &CORRELATED_SUBQUERY_RULE
    }
}
| 236 | + |
impl NormalizationRule for CorrelatedSubquery {
    /// Entry point: rewrites correlated subqueries in the subtree rooted at
    /// `node_id`, starting with empty reference/scan state, then bumps the
    /// graph version.
    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
        Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?;
        // mark changed to skip this rule batch
        graph.version += 1;

        Ok(())
    }
}
0 commit comments