Skip to content

Commit 2dbc722

Browse files
committed
feat: Join Using will follow the direction of join type and select the displayed using column
1 parent 1ac3632 commit 2dbc722

6 files changed

Lines changed: 181 additions & 154 deletions

File tree

src/binder/create_view.rs

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,36 @@ impl<T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'_, '_, T, A>
2424
let view_name = Arc::new(lower_case_name(name)?);
2525
let mut plan = self.bind_query(query)?;
2626

27-
if !columns.is_empty() {
28-
let mapping_schema = plan.output_schema();
29-
let exprs = columns
30-
.iter()
31-
.enumerate()
32-
.map(|(i, ident)| {
33-
let mapping_column = &mapping_schema[i];
34-
let mut column = ColumnCatalog::new(
35-
lower_ident(ident),
36-
mapping_column.nullable(),
37-
mapping_column.desc().clone(),
38-
);
39-
column.set_ref_table(view_name.clone(), Ulid::new(), true);
27+
let mapping_schema = plan.output_schema();
4028

41-
ScalarExpression::Alias {
42-
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())),
43-
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(
44-
ColumnRef::from(column),
45-
))),
46-
}
47-
})
48-
.collect_vec();
49-
plan = self.bind_project(plan, exprs)?;
29+
let exprs = if columns.is_empty() {
30+
Box::new(
31+
mapping_schema
32+
.iter()
33+
.map(|column| column.name().to_string()),
34+
) as Box<dyn Iterator<Item = String>>
35+
} else {
36+
Box::new(columns.iter().map(lower_ident)) as Box<dyn Iterator<Item = String>>
5037
}
38+
.enumerate()
39+
.map(|(i, column_name)| {
40+
let mapping_column = &mapping_schema[i];
41+
let mut column = ColumnCatalog::new(
42+
column_name,
43+
mapping_column.nullable(),
44+
mapping_column.desc().clone(),
45+
);
46+
column.set_ref_table(view_name.clone(), Ulid::new(), true);
47+
48+
ScalarExpression::Alias {
49+
expr: Box::new(ScalarExpression::ColumnRef(mapping_column.clone())),
50+
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
51+
column,
52+
)))),
53+
}
54+
})
55+
.collect_vec();
56+
plan = self.bind_project(plan, exprs)?;
5157

5258
Ok(LogicalPlan::new(
5359
Operator::CreateView(CreateViewOperator {

src/binder/mod.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub struct BinderContext<'a, T: Transaction> {
112112
group_by_exprs: Vec<ScalarExpression>,
113113
pub(crate) agg_calls: Vec<ScalarExpression>,
114114
// join
115-
using: HashSet<String>,
115+
using: HashSet<ColumnRef>,
116116

117117
bind_step: QueryBindStep,
118118
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,
@@ -295,8 +295,17 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
295295
}
296296
}
297297

298-
pub fn add_using(&mut self, name: String) {
299-
self.using.insert(name);
298+
pub fn add_using(
299+
&mut self,
300+
join_type: JoinType,
301+
left_expr: &ColumnRef,
302+
right_expr: &ColumnRef,
303+
) {
304+
self.using.insert(if join_type.is_right() {
305+
left_expr.clone()
306+
} else {
307+
right_expr.clone()
308+
});
300309
}
301310

302311
pub fn add_alias(

src/binder/select.rs

Lines changed: 59 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ use crate::types::tuple::{Schema, SchemaRef};
3636
use crate::types::value::Utf8Type;
3737
use crate::types::{ColumnId, LogicalType};
3838
use itertools::Itertools;
39-
use sqlparser::ast::CharLengthUnits::Characters;
4039
use sqlparser::ast::{
4140
CharLengthUnits, Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset,
4241
OrderByExpr, Query, Select, SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier,
@@ -189,14 +188,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
189188
}
190189
}
191190

192-
if left_cast.len() > 0 {
191+
if !left_cast.is_empty() {
193192
left_plan = LogicalPlan::new(
194193
Operator::Project(ProjectOperator { exprs: left_cast }),
195194
Childrens::Only(left_plan),
196195
);
197196
}
198197

199-
if right_cast.len() > 0 {
198+
if !right_cast.is_empty() {
200199
right_plan = LogicalPlan::new(
201200
Operator::Project(ProjectOperator { exprs: right_cast }),
202201
Childrens::Only(right_plan),
@@ -393,7 +392,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
393392
unreachable!()
394393
}
395394
}
396-
_ => unimplemented!(),
395+
table => return Err(DatabaseError::UnsupportedStmt(format!("{:#?}", table))),
397396
};
398397

399398
Ok(plan)
@@ -517,8 +516,6 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
517516
}
518517
continue;
519518
}
520-
let mut join_used = HashSet::with_capacity(self.context.using.len());
521-
522519
for (table_name, alias, _) in self.context.bind_table.keys() {
523520
let schema_buf =
524521
self.table_schema_buf.entry(table_name.clone()).or_default();
@@ -527,7 +524,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
527524
schema_buf,
528525
&mut select_items,
529526
alias.as_ref().unwrap_or(table_name).clone(),
530-
Some(&mut join_used),
527+
false,
531528
)?;
532529
}
533530
}
@@ -540,7 +537,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
540537
schema_buf,
541538
&mut select_items,
542539
table_name,
543-
None,
540+
true,
544541
)?;
545542
}
546543
};
@@ -555,58 +552,48 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
555552
schema_buf: &mut Option<SchemaOutput>,
556553
exprs: &mut Vec<ScalarExpression>,
557554
table_name: TableName,
558-
mut join_used: Option<&mut HashSet<String>>,
555+
is_qualified_wildcard: bool,
559556
) -> Result<(), DatabaseError> {
560-
let mut is_bound_alias = false;
561-
562-
let fn_used =
563-
|column_name: &str, context: &BinderContext<T>, join_used: Option<&HashSet<_>>| {
564-
context.using.contains(column_name)
565-
&& matches!(join_used.map(|used| used.contains(column_name)), Some(true))
566-
};
567-
for (_, alias_expr) in context.expr_aliases.iter().filter(|(_, expr)| {
568-
if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() {
569-
let column_name = col.name();
557+
let fn_not_on_using = |column: &ColumnRef| {
558+
if context.using.is_empty() {
559+
return Some(&table_name) == column.table_name();
560+
}
561+
is_qualified_wildcard
562+
|| Some(&table_name) == column.table_name() && !context.using.contains(column)
563+
};
570564

571-
if Some(&table_name) == col.table_name()
572-
&& !fn_used(column_name, context, join_used.as_deref())
573-
{
574-
if let Some(used) = join_used.as_mut() {
575-
used.insert(column_name.to_string());
565+
let bound_alias = context
566+
.expr_aliases
567+
.iter()
568+
.filter(|(_, expr)| {
569+
if let ScalarExpression::ColumnRef(col) = expr.unpack_alias_ref() {
570+
if fn_not_on_using(col) {
571+
exprs.push(ScalarExpression::clone(expr));
572+
return true;
576573
}
577-
return true;
578574
}
579-
}
580-
false
581-
}) {
582-
is_bound_alias = true;
583-
exprs.push(alias_expr.clone());
584-
}
585-
if is_bound_alias {
575+
false
576+
})
577+
.count()
578+
> 0;
579+
580+
if bound_alias {
586581
return Ok(());
587582
}
588-
589583
let mut source = None;
590584

591585
source = context.table(table_name.clone())?.map(Source::Table);
592586
if source.is_none() {
593-
source = context.view(table_name)?.map(Source::View);
587+
source = context.view(table_name.clone())?.map(Source::View);
594588
}
595589
for column in source
596590
.ok_or(DatabaseError::SourceNotFound)?
597591
.columns(schema_buf)
598592
{
599-
let column_name = column.name();
600-
601-
if fn_used(column_name, context, join_used.as_deref()) {
593+
if !fn_not_on_using(column) {
602594
continue;
603595
}
604-
let expr = ScalarExpression::ColumnRef(column.clone());
605-
606-
if let Some(used) = join_used.as_mut() {
607-
used.insert(column_name.to_string());
608-
}
609-
exprs.push(expr);
596+
exprs.push(ScalarExpression::ColumnRef(column.clone()));
610597
}
611598
Ok(())
612599
}
@@ -654,9 +641,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
654641
self.extend(binder.context);
655642

656643
let on = match joint_condition {
657-
Some(constraint) => {
658-
self.bind_join_constraint(left.output_schema(), right.output_schema(), constraint)?
659-
}
644+
Some(constraint) => self.bind_join_constraint(
645+
join_type,
646+
left.output_schema(),
647+
right.output_schema(),
648+
constraint,
649+
)?,
660650
None => JoinCondition::None,
661651
};
662652

@@ -902,6 +892,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
902892

903893
fn bind_join_constraint<'c>(
904894
&mut self,
895+
join_type: JoinType,
905896
left_schema: &'c SchemaRef,
906897
right_schema: &'c SchemaRef,
907898
constraint: &JoinConstraint,
@@ -938,24 +929,26 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
938929
})
939930
}
940931
JoinConstraint::Using(idents) => {
932+
fn find_column<'a>(schema: &'a Schema, name: &'a str) -> Option<&'a ColumnRef> {
933+
schema.iter().find(|column| column.name() == name)
934+
}
935+
941936
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new();
942-
let fn_column = |schema: &Schema, name: &str| {
943-
schema
944-
.iter()
945-
.find(|column| column.name() == name)
946-
.map(|column| ScalarExpression::ColumnRef(column.clone()))
947-
};
937+
948938
for ident in idents {
949939
let name = lower_ident(ident);
950-
if let (Some(left_column), Some(right_column)) = (
951-
fn_column(left_schema, &name),
952-
fn_column(right_schema, &name),
953-
) {
954-
on_keys.push((left_column, right_column));
955-
} else {
956-
return Err(DatabaseError::InvalidColumn("not found column".to_string()))?;
957-
}
958-
self.context.add_using(name);
940+
let (Some(left_column), Some(right_column)) = (
941+
find_column(left_schema, &name),
942+
find_column(right_schema, &name),
943+
) else {
944+
return Err(DatabaseError::InvalidColumn("not found column".to_string()));
945+
};
946+
self.context
947+
.add_using(join_type, left_column, right_column);
948+
on_keys.push((
949+
ScalarExpression::ColumnRef(left_column.clone()),
950+
ScalarExpression::ColumnRef(right_column.clone()),
951+
));
959952
}
960953
Ok(JoinCondition::On {
961954
on: on_keys,
@@ -970,15 +963,15 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
970963
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new();
971964

972965
for name in fn_names(left_schema).intersection(&fn_names(right_schema)) {
973-
self.context.add_using(name.to_string());
974966
if let (Some(left_column), Some(right_column)) = (
975967
left_schema.iter().find(|column| column.name() == *name),
976968
right_schema.iter().find(|column| column.name() == *name),
977969
) {
978-
on_keys.push((
979-
ScalarExpression::ColumnRef(left_column.clone()),
980-
ScalarExpression::ColumnRef(right_column.clone()),
981-
));
970+
let left_expr = ScalarExpression::ColumnRef(left_column.clone());
971+
let right_expr = ScalarExpression::ColumnRef(right_column.clone());
972+
973+
self.context.add_using(join_type, left_column, right_column);
974+
on_keys.push((left_expr, right_expr));
982975
}
983976
}
984977
Ok(JoinCondition::On {

src/planner/operator/join.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ pub enum JoinType {
1616
Full,
1717
Cross,
1818
}
19+
20+
impl JoinType {
21+
pub fn is_right(&self) -> bool {
22+
matches!(self, JoinType::RightOuter)
23+
}
24+
}
25+
1926
#[derive(Debug, Clone, PartialEq, Eq, Hash, ReferenceSerialization)]
2027
pub enum JoinCondition {
2128
On {

0 commit comments

Comments
 (0)