@@ -16,10 +16,10 @@ def __init__(self):
1616
1717 def _reset (self ):
1818 """Set the filter attributes to its default values"""
19- self ._in_declare = False
20- self ._in_case = False
19+ self ._block_stack = []
20+ self ._parenthesis_level = 0
21+ self ._unconfirmed_start = None
2122 self ._is_create = False
22- self ._begin_depth = 0
2323 self ._seen_begin = False
2424
2525 self .consume_ws = False
@@ -29,37 +29,44 @@ def _reset(self):
2929 def _change_splitlevel (self , ttype , value ):
3030 """Get the new split level (increase, decrease or remain equal)"""
3131
32+ # Semicolon resets unconfirmed loop starters
33+ if ttype is T .Punctuation and value == ';' :
34+ self ._unconfirmed_start = None
35+
3236 # parenthesis increase/decrease a level
3337 if ttype is T .Punctuation and value == '(' :
38+ self ._parenthesis_level += 1
3439 return 1
3540 elif ttype is T .Punctuation and value == ')' :
41+ self ._parenthesis_level = max (0 , self ._parenthesis_level - 1 )
3642 return - 1
3743 elif ttype not in T .Keyword : # if normal token return
3844 return 0
3945
4046 # Everything after here is ttype = T.Keyword
41- # Also to note, once entered an If statement you are done and basically
42- # returning
4347 unified = value .upper ()
4448
45- # three keywords begin with CREATE, but only one of them is DDL
4649 # DDL Create though can contain more words such as "or replace"
4750 if ttype is T .Keyword .DDL and unified .startswith ('CREATE' ):
4851 self ._is_create = True
4952 return 0
5053
51- # can have nested declare inside of being...
52- if unified == 'DECLARE' and self ._is_create and self ._begin_depth == 0 :
53- self ._in_declare = True
54+ # Handle DECLARE block start (only for CREATE statements)
55+ if unified == 'DECLARE' and self ._is_create and not self ._block_stack :
56+ self ._block_stack . append ( 'DECLARE' )
5457 return 1
5558
59+ # Handle BEGIN block start
5660 if unified == 'BEGIN' :
57- self ._begin_depth += 1
5861 self ._seen_begin = True
59- if self ._is_create :
60- # FIXME(andi): This makes no sense. ## this comment neither
62+ # Transition DECLARE to BEGIN if present
63+ if self ._block_stack and self ._block_stack [- 1 ] == 'DECLARE' :
64+ self ._block_stack .pop ()
65+ self ._block_stack .append ('BEGIN' )
66+ return 0
67+ else :
68+ self ._block_stack .append ('BEGIN' )
6169 return 1
62- return 0
6370
6471 # Issue826: If we see a transaction keyword after BEGIN,
6572 # it's a transaction statement, not a block.
@@ -68,28 +75,72 @@ def _change_splitlevel(self, ttype, value):
6875 unified in ('TRANSACTION' , 'WORK' , 'TRAN' ,
6976 'DISTRIBUTED' , 'DEFERRED' ,
7077 'IMMEDIATE' , 'EXCLUSIVE' ):
71- self ._begin_depth = max (0 , self ._begin_depth - 1 )
7278 self ._seen_begin = False
79+ if self ._block_stack and self ._block_stack [- 1 ] == 'BEGIN' :
80+ self ._block_stack .pop ()
81+ return - 1
7382 return 0
7483
75- # BEGIN and CASE/WHEN both end with END
76- if unified == 'END' :
77- if not self ._in_case :
78- self ._begin_depth = max (0 , self ._begin_depth - 1 )
79- else :
80- self ._in_case = False
81- return - 1
82-
83- if (unified in ('IF' , 'FOR' , 'WHILE' , 'CASE' )
84- and self ._is_create and self ._begin_depth > 0 ):
85- if unified == 'CASE' :
86- self ._in_case = True
87- return 1
84+ # Inside a block, check for nested loop or control structures
85+ if 'BEGIN' in self ._block_stack :
86+ if unified == 'FOR' :
87+ self ._unconfirmed_start = 'FOR'
88+ return 0
89+ elif unified == 'WHILE' :
90+ self ._unconfirmed_start = 'WHILE'
91+ return 0
92+ elif unified == 'LOOP' :
93+ if self ._unconfirmed_start in ('FOR' , 'WHILE' ):
94+ self ._block_stack .append (self ._unconfirmed_start )
95+ self ._unconfirmed_start = None
96+ return 1
97+ else :
98+ self ._block_stack .append ('LOOP' )
99+ return 1
100+ elif unified == 'DO' :
101+ if self ._unconfirmed_start in ('FOR' , 'WHILE' ):
102+ self ._block_stack .append (self ._unconfirmed_start )
103+ self ._unconfirmed_start = None
104+ return 1
105+ elif unified == 'IF' :
106+ self ._block_stack .append ('IF' )
107+ return 1
108+ elif unified == 'CASE' :
109+ self ._block_stack .append ('CASE' )
110+ return 1
88111
89- if unified in ('END IF' , 'END FOR' , 'END WHILE' ):
90- return - 1
112+ # Handle closing keywords
113+ if unified == 'END IF' :
114+ if self ._block_stack and self ._block_stack [- 1 ] == 'IF' :
115+ self ._block_stack .pop ()
116+ return - 1
117+ elif unified == 'END FOR' :
118+ if self ._block_stack and self ._block_stack [- 1 ] == 'FOR' :
119+ self ._block_stack .pop ()
120+ return - 1
121+ elif unified == 'END WHILE' :
122+ if self ._block_stack and self ._block_stack [- 1 ] == 'WHILE' :
123+ self ._block_stack .pop ()
124+ return - 1
125+ elif unified == 'END LOOP' :
126+ if self ._block_stack and self ._block_stack [- 1 ] in ('LOOP' , 'FOR' , 'WHILE' ):
127+ self ._block_stack .pop ()
128+ return - 1
129+ elif unified == 'END CASE' :
130+ if self ._block_stack and self ._block_stack [- 1 ] == 'CASE' :
131+ self ._block_stack .pop ()
132+ return - 1
133+ elif unified == 'END' :
134+ if self ._block_stack :
135+ if self ._block_stack [- 1 ] in ('CASE' , 'BEGIN' ):
136+ self ._block_stack .pop ()
137+ return - 1
138+ else :
139+ self ._block_stack .pop ()
140+ return - 1
141+ else :
142+ return - 1
91143
92- # Default
93144 return 0
94145
95146 def process (self , stream ):
@@ -125,10 +176,12 @@ def process(self, stream):
125176 # If we just saw BEGIN; then this is a transaction BEGIN,
126177 # not a BEGIN...END block, so decrement depth
127178 if self ._seen_begin :
128- self ._begin_depth = max (0 , self ._begin_depth - 1 )
179+ if self ._block_stack and self ._block_stack [- 1 ] == 'BEGIN' :
180+ self ._block_stack .pop ()
181+ self .level = max (0 , self .level - 1 )
129182 self ._seen_begin = False
130183 # Split on semicolon if not inside a BEGIN...END block
131- if self .level <= 0 and self . _begin_depth == 0 :
184+ if self .level <= 0 and 'BEGIN' not in self . _block_stack :
132185 self .consume_ws = True
133186 elif ttype is T .Keyword and value .split ()[0 ] == 'GO' :
134187 self .consume_ws = True
0 commit comments