Skip to content

Commit c34fefc

Browse files
committed
planning render combining queries
1 parent 1aea939 commit c34fefc

5 files changed

Lines changed: 55 additions & 51 deletions

File tree

mindsdb_sql/parser/dialects/mindsdb/parser.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class MindsDBParser(Parser):
7070
'drop_dataset',
7171
'select',
7272
'insert',
73+
'union',
7374
'update',
7475
'delete',
7576
'evaluate',
@@ -614,10 +615,13 @@ def update(self, p):
614615

615616
# INSERT
616617
@_('INSERT INTO identifier LPAREN column_list RPAREN select',
617-
'INSERT INTO identifier select')
618+
'INSERT INTO identifier LPAREN column_list RPAREN union',
619+
'INSERT INTO identifier select',
620+
'INSERT INTO identifier union')
618621
def insert(self, p):
619622
columns = getattr(p, 'column_list', None)
620-
return Insert(table=p.identifier, columns=columns, from_select=p.select)
623+
query = p.select if hasattr(p, 'select') else p.union
624+
return Insert(table=p.identifier, columns=columns, from_select=query)
621625

622626
@_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set',
623627
'INSERT INTO identifier VALUES expr_list_set')
@@ -999,20 +1003,28 @@ def database_engine(self, p):
9991003
return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty}
10001004

10011005
# Combining
1002-
@_('select UNION select')
1003-
@_('select UNION ALL select')
1004-
def select(self, p):
1005-
return Union(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL'))
1006-
1007-
@_('select INTERSECT select')
1008-
@_('select INTERSECT ALL select')
1009-
def select(self, p):
1010-
return Intersect(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL'))
1011-
1012-
@_('select EXCEPT select')
1013-
@_('select EXCEPT ALL select')
1014-
def select(self, p):
1015-
return Except(left=p.select0, right=p.select1, unique=not hasattr(p, 'ALL'))
1006+
@_('select UNION select',
1007+
'union UNION select',
1008+
'select UNION ALL select',
1009+
'union UNION ALL select')
1010+
def union(self, p):
1011+
unique = not hasattr(p, 'ALL')
1012+
return Union(left=p[0], right=p[2] if unique else p[3], unique=unique)
1013+
1014+
@_('select INTERSECT select',
1015+
'union INTERSECT select',
1016+
'select INTERSECT ALL select',
1017+
'union INTERSECT ALL select')
1018+
def union(self, p):
1019+
unique = not hasattr(p, 'ALL')
1020+
return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique)
1021+
@_('select EXCEPT select',
1022+
'union EXCEPT select',
1023+
'select EXCEPT ALL select',
1024+
'union EXCEPT ALL select')
1025+
def union(self, p):
1026+
unique = not hasattr(p, 'ALL')
1027+
return Except(left=p[0], right=p[2] if unique else p[3], unique=unique)
10161028

10171029
# tableau
10181030
@_('LPAREN select RPAREN')

mindsdb_sql/planner/query_planner.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mindsdb_sql.exceptions import PlanningException
44
from mindsdb_sql.parser import ast
55
from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, Union, CreateTable,
6-
Function, Insert,
6+
Function, Insert, Except, Intersect,
77
Update, NativeQuery, Parameter, Delete)
88
from mindsdb_sql.planner import utils
99
from mindsdb_sql.planner.query_plan import QueryPlan
@@ -678,7 +678,7 @@ def plan_cte(self, query):
678678
self.cte_results[name] = step.result
679679

680680
def plan_select(self, query, integration=None):
681-
if isinstance(query, Union):
681+
if isinstance(query, (Union, Except, Intersect)):
682682
return self.plan_union(query, integration=integration)
683683

684684
if query.cte is not None:
@@ -734,14 +734,15 @@ def plan_sub_select(self, query, prev_step, add_absent_cols=False):
734734
return prev_step
735735

736736
def plan_union(self, query, integration=None):
737-
if isinstance(query.left, Union):
738-
step1 = self.plan_union(query.left, integration=integration)
739-
else:
740-
# it is select
741-
step1 = self.plan_select(query.left, integration=integration)
737+
step1 = self.plan_select(query.left, integration=integration)
742738
step2 = self.plan_select(query.right, integration=integration)
739+
operation = 'union'
740+
if isinstance(query, Except):
741+
operation = 'except'
742+
elif isinstance(query, Intersect):
743+
operation = 'intersect'
743744

744-
return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique))
745+
return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique, operation=operation))
745746

746747
# method for compatibility
747748
def from_query(self, query=None):
@@ -750,7 +751,7 @@ def from_query(self, query=None):
750751
if query is None:
751752
query = self.query
752753

753-
if isinstance(query, (Select, Union)):
754+
if isinstance(query, (Select, Union, Except, Intersect)):
754755
self.plan_select(query)
755756
elif isinstance(query, CreateTable):
756757
self.plan_create_table(query)

mindsdb_sql/planner/steps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,12 @@ def __init__(self, left, right, query, *args, **kwargs):
7575

7676
class UnionStep(PlanStep):
7777
"""Union of two dataframes, producing a new dataframe"""
78-
def __init__(self, left, right, unique, *args, **kwargs):
78+
def __init__(self, left, right, unique, operation='union', *args, **kwargs):
7979
super().__init__(*args, **kwargs)
8080
self.left = left
8181
self.right = right
8282
self.unique = unique
83+
self.operation = operation
8384

8485

8586
# TODO remove

mindsdb_sql/render/sqlalchemy_render.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def to_table(self, node):
396396
return table
397397

398398
def prepare_select(self, node):
399-
if isinstance(node, ast.Union):
399+
if isinstance(node, (ast.Union, ast.Except, ast.Intersect)):
400400
return self.prepare_union(node)
401401

402402
cols = []
@@ -525,26 +525,17 @@ def prepare_select(self, node):
525525
return query
526526

527527
def prepare_union(self, from_table):
528-
tables = self.extract_union_list(from_table)
528+
step1 = self.prepare_select(from_table.left)
529+
step2 = self.prepare_select(from_table.right)
529530

530-
table1 = tables[0]
531-
tables_x = tables[1:]
532-
533-
return table1.union(*tables_x)
534-
535-
def extract_union_list(self, node):
536-
if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)):
537-
raise NotImplementedError(
538-
f'Unknown UNION {node.left.__class__.__name__}, {node.right.__class__.__name__}')
539-
540-
tables = []
541-
if isinstance(node.left, ast.Union):
542-
tables.extend(self.extract_union_list(node.left))
531+
if isinstance(from_table, ast.Except):
532+
func = sa.except_ if from_table.unique else sa.except_all
533+
elif isinstance(from_table, ast.Intersect):
534+
func = sa.intersect if from_table.unique else sa.intersect_all
543535
else:
544-
tables.append(self.prepare_select(node.left))
545-
tables.append(self.prepare_select(node.right))
546-
return tables
536+
func = sa.union if from_table.unique else sa.union_all
547537

538+
return func(step1, step2)
548539

549540
def prepare_create_table(self, ast_query):
550541
columns = []
@@ -693,7 +684,7 @@ def prepare_delete(self, ast_query: ast.Delete):
693684

694685
def get_query(self, ast_query, with_params=False):
695686
params = None
696-
if isinstance(ast_query, ast.Select):
687+
if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)):
697688
stmt = self.prepare_select(ast_query)
698689
elif isinstance(ast_query, ast.Insert):
699690
stmt, params = self.prepare_insert(ast_query, with_params=with_params)

tests/test_parser/test_base_sql/test_union.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,16 @@ def test_union_alias(self):
6060
from_table=Union(
6161
unique=True,
6262
alias=Identifier('alias'),
63-
left=Select(targets=[Identifier('col1')],
64-
from_table=Identifier(parts=['tab1']),),
65-
right=Union(
63+
left=Union(
6664
unique=True,
6765
left=Select(
6866
targets=[Identifier('col1')],
69-
from_table=Identifier(parts=['tab2']),),
67+
from_table=Identifier(parts=['tab1']),),
7068
right=Select(targets=[Identifier('col1')],
71-
from_table=Identifier(parts=['tab3']),),
69+
from_table=Identifier(parts=['tab2']),),
7270
),
73-
71+
right=Select(targets=[Identifier('col1')],
72+
from_table=Identifier(parts=['tab3']),),
7473
)
7574
)
7675
assert ast.to_tree() == expected_ast.to_tree()

0 commit comments

Comments
 (0)