Skip to content

Commit fbc4315

Browse files
authored
Merge pull request #415 from mindsdb/window-fix
Parser fixes #1
2 parents fbb913d + fce61dc commit fbc4315

7 files changed

Lines changed: 164 additions & 10 deletions

File tree

mindsdb_sql/parser/ast/select/case.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55

66
class Case(ASTNode):
7-
def __init__(self, rules, default=None, arg=None, *args, **kwargs):
    """Create a CASE expression node.

    :param rules: list of ``[condition, result]`` pairs, one per branch
    :param default: optional node rendered after ``ELSE``; ``None`` omits it
    :param arg: optional node rendered right after the ``CASE`` keyword
        (simple form ``CASE <arg> ...``); ``None`` omits it
    Remaining positional and keyword arguments are passed to the base class.
    """
    super().__init__(*args, **kwargs)

    self.arg = arg
    # rules layout: [[condition, result], ...]
    self.rules = rules
    self.default = default
1617

@@ -36,7 +37,12 @@ def to_tree(self, *args, level=0, **kwargs):
3637
if self.default is not None:
3738
default_str = f'{ind1}default => {self.default.to_string()}\n'
3839

40+
arg_str = ''
41+
if self.arg is not None:
42+
arg_str = f'{ind1}arg => {self.arg.to_string()}\n'
43+
3944
return f'{ind}Case(\n' \
45+
f'{arg_str}'\
4046
f'{rules_str}\n' \
4147
f'{default_str}' \
4248
f'{ind})'
@@ -53,4 +59,8 @@ def get_string(self, *args, alias=True, **kwargs):
5359
default_str = ''
5460
if self.default is not None:
5561
default_str = f' ELSE {self.default.to_string()}'
56-
return f"CASE {rules_str}{default_str} END"
62+
63+
arg_str = ''
64+
if self.arg is not None:
65+
arg_str = f'{self.arg.to_string()} '
66+
return f"CASE {arg_str}{rules_str}{default_str} END"

mindsdb_sql/parser/ast/select/operation.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ def get_string(self, *args, **kwargs):
9898

9999

100100
class WindowFunction(ASTNode):
101-
def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None):
    """Create a window-function node (``<function> over(...)``).

    :param function: node of the expression the window is applied to
    :param partition: optional partition specification
    :param order_by: optional ordering specification
    :param alias: optional alias node
    :param modifier: optional plain-text clause that ``to_string`` appends
        inside the ``over(...)`` parentheses after the ordering part
    """
    super().__init__()

    # what is being windowed
    self.function = function

    # content of the over(...) clause
    self.partition = partition
    self.order_by = order_by

    self.alias = alias
    self.modifier = modifier
107108

108109
def to_tree(self, *args, level=0, **kwargs):
109110
fnc_str = self.function.to_tree(level=level+2)
@@ -143,7 +144,8 @@ def to_string(self, *args, **kwargs):
143144
alias_str = self.alias.to_string()
144145
else:
145146
alias_str = ''
146-
return f'{fnc_str} over({partition_str} {order_str}) {alias_str}'
147+
modifier_str = ' ' + self.modifier if self.modifier else ''
148+
return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}'
147149

148150

149151
class Object(ASTNode):
@@ -177,7 +179,12 @@ def __init__(self, info):
177179
super().__init__(op='interval', args=[info, ])
178180

179181
def get_string(self, *args, **kwargs):
    """Render the interval as SQL with its leading value quoted.

    ``self.args[0]`` holds the raw interval text (for example ``1 day``).
    The first whitespace-delimited token is wrapped in single quotes and
    the remainder, if any, is appended after one space, which yields
    ``INTERVAL '1' day``.

    Positional and keyword arguments are accepted for signature
    compatibility and are not used.
    """
    # Split on any run of whitespace rather than on a single space, so that
    # leading or repeated blanks in the stored text cannot produce an empty
    # quoted value such as ``INTERVAL '' 1 day``.
    items = self.args[0].split(maxsplit=1)

    # Blank text yields no tokens at all; keep rendering an empty literal.
    value = items[0] if items else ''

    return " ".join(["INTERVAL", f"'{value}'", *items[1:]])
181188

182189
def to_tree(self, *args, level=0, **kwargs):
183190
return self.get_string( *args, **kwargs)

mindsdb_sql/parser/dialects/mindsdb/parser.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,15 @@ def column_list(self, p):
13521352
def case(self, p):
13531353
return Case(rules=p.case_conditions, default=getattr(p, 'expr', None))
13541354

1355+
@_('CASE expr case_conditions ELSE expr END',
   'CASE expr case_conditions END')
def case(self, p):
    # Simple CASE form: an expression directly follows the CASE keyword.
    # A production with a single `expr` exposes it as `p.expr` (no ELSE
    # branch); with two of them they are read as `p.expr0` (the CASE
    # argument) and `p.expr1` (the ELSE value).
    has_else = not hasattr(p, 'expr')
    if has_else:
        operand = p.expr0
        fallback = p.expr1
    else:
        operand = p.expr
        fallback = None
    return Case(rules=p.case_conditions, default=fallback, arg=operand)
1363+
13551364
@_('case_condition',
13561365
'case_conditions case_condition')
13571366
def case_conditions(self, p):
@@ -1364,13 +1373,18 @@ def case_condition(self, p):
13641373
return [p.expr0, p.expr1]
13651374

13661375
# Window function
1367-
@_('expr OVER LPAREN window RPAREN',
   'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN')
def window_function(self, p):
    # The second production accepts a trailing clause shaped like
    # `<id> BETWEEN <id> <id> AND <id> <id>`; it is kept as plain text,
    # with the BETWEEN / AND keywords always written in upper case.
    if hasattr(p, 'BETWEEN'):
        frame = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}'
    else:
        frame = None

    window = p.window
    return WindowFunction(
        function=p.expr,
        order_by=window.get('order_by'),
        partition=window.get('partition'),
        modifier=frame,
    )
13751389

13761390
@_('window PARTITION_BY expr_list')

mindsdb_sql/planner/query_planner.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,19 @@ def find_objects(node, is_table, **kwargs):
229229
mdb_entities.append(node)
230230

231231
query_traversal(query, find_objects)
232+
233+
# cte names are not mdb objects
234+
if query.cte:
235+
cte_names = [
236+
cte.name.parts[-1]
237+
for cte in query.cte
238+
]
239+
mdb_entities = [
240+
item
241+
for item in mdb_entities
242+
if '.'.join(item.parts) not in cte_names
243+
]
244+
232245
return {
233246
'mdb_entities': mdb_entities,
234247
'integrations': integrations,
@@ -672,6 +685,16 @@ def plan_delete(self, query: Delete):
672685
))
673686

674687
def plan_cte(self, query):
688+
query_info = self.get_query_info(query)
689+
690+
if (
691+
len(query_info['integrations']) == 1
692+
and len(query_info['mdb_entities']) == 0
693+
and len(query_info['user_functions']) == 0
694+
):
695+
# single integration, will be planned later
696+
return
697+
675698
for cte in query.cte:
676699
step = self.plan_select(cte.query)
677700
name = cte.name.parts[-1]

mindsdb_sql/render/sqlalchemy_render.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,15 @@ def prepare_case(self, t: ast.Case):
293293
conditions.append(
294294
(self.to_expression(condition), self.to_expression(result))
295295
)
296+
default = None
296297
if t.default is not None:
297-
conditions.append(self.to_expression(t.default))
298+
default = self.to_expression(t.default)
298299

299-
return sa.case(*conditions)
300+
value = None
301+
if t.arg is not None:
302+
value = self.to_expression(t.arg)
303+
304+
return sa.case(*conditions, else_=default, value=value)
300305

301306
def to_function(self, t):
302307
op = getattr(sa.func, t.op)

tests/test_parser/test_base_sql/test_select_structure.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,40 @@ def test_case(self):
10261026
assert ast.to_tree() == expected_ast.to_tree()
10271027
assert str(ast) == str(expected_ast)
10281028

1029+
def test_case_simple_form(self):
    # Simple CASE form: the compared expression follows the CASE keyword
    # and every WHEN branch supplies a value instead of a full condition.
    sql = f'''SELECT
CASE R.DELETE_RULE
WHEN 'CASCADE' THEN 0
WHEN 'SET NULL' THEN 2
ELSE 3
END AS DELETE_RULE
FROM COLLATIONS'''
    parsed = parse_sql(sql)

    when_rules = [
        [Constant('CASCADE'), Constant(0)],
        [Constant('SET NULL'), Constant(2)],
    ]
    expected_case = Case(
        arg=Identifier('R.DELETE_RULE'),
        rules=when_rules,
        default=Constant(3),
        alias=Identifier('DELETE_RULE'),
    )
    expected = Select(
        targets=[expected_case],
        from_table=Identifier('COLLATIONS'),
    )

    # compare both the tree dump and the rendered SQL
    assert parsed.to_tree() == expected.to_tree()
    assert str(parsed) == str(expected)
1062+
10291063
def test_select_left(self):
10301064
sql = f'select left(a, 1) from tab1'
10311065
ast = parse_sql(sql)
@@ -1152,3 +1186,23 @@ def test_table_double_quote(self):
11521186

11531187
ast = parse_sql(sql)
11541188
assert str(ast) == str(expected_ast)
1189+
1190+
def test_window_function_mindsdb(self):
    # A window function whose OVER clause ends with a
    # `rows between ... and ...` modifier.
    query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 "

    windowed_sum = WindowFunction(
        function=Function(op='sum', args=[Identifier('col0')]),
        partition=[Identifier('col1')],
        order_by=[OrderBy(field=Identifier('col2'))],
        # the expected text has BETWEEN / AND in upper case
        modifier='rows BETWEEN unbounded preceding AND current row',
    )
    expected = Select(
        targets=[windowed_sum],
        from_table=Identifier('table1'),
    )

    parsed = parse_sql(query)
    assert str(parsed) == str(expected)
    assert parsed.to_tree() == expected.to_tree()
1208+

tests/test_planner/test_integration_select.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def test_select_from_table_subselect_api_integration(self):
554554
plan = plan_query(
555555
query,
556556
integrations=[{'name': 'int1', 'class_type': 'api', 'type': 'data'}],
557-
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}]
557+
predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}],
558558
)
559559

560560
assert plan.steps == expected_plan.steps
@@ -583,6 +583,47 @@ def test_select_from_table_subselect_sql_integration(self):
583583

584584
assert plan.steps == expected_plan.steps
585585

586+
def test_select_from_single_integration(self):
    # The CTE, the join and the sub-select all read from integration
    # `int1`; the expected plan is one FetchDataframeStep carrying the
    # query with the `int1.` prefixes removed.
    source_sql = '''
with tab2 as (
select * from int1.tabl2
)
select x from tab2
join int1.tab1 on 0=0
where x1 in (select id from int1.tab1)
limit 1
'''

    pushed_down_sql = '''
with tab2 as (
select * from tabl2
)
select x from tab2
join tab1 on 0=0
where x1 in (select id as id from tab1)
limit 1
'''
    query = parse_sql(source_sql, dialect='mindsdb')

    single_fetch = FetchDataframeStep(
        integration='int1',
        query=parse_sql(pushed_down_sql),
    )
    expected_plan = QueryPlan(
        predictor_namespace='mindsdb',
        steps=[single_fetch],
    )

    plan = plan_query(
        query,
        integrations=[{'name': 'int1', 'class_type': 'sql', 'type': 'data'}],
        predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}],
        default_namespace='mindsdb',
    )

    assert plan.steps == expected_plan.steps
626+
586627
def test_delete_from_table_subselect_api_integration(self):
587628
query = parse_sql('''
588629
delete from int1.tab1

0 commit comments

Comments
 (0)