1010
1111import re
1212from dataclasses import dataclass , field
13- from typing import TYPE_CHECKING , Any
13+ from typing import TYPE_CHECKING , Any , Iterator
1414
1515if TYPE_CHECKING :
1616 from .query_service import NonQueryResult , QueryResult
1717
1818
19- def _has_semicolon_outside_strings (sql : str ) -> bool :
20- """Check if SQL has semicolons outside of string literals."""
19+ def _iter_sql_chars (sql : str ) -> Iterator [tuple [int , str , bool ]]:
20+ """Iterate through SQL characters, tracking string literal context.
21+
22+ Handles escape sequences (backslash) and SQL-style doubled quotes.
23+
24+ Yields:
25+ (index, char, outside_string) tuples where outside_string is True
26+ when the character is not inside a string literal.
27+ """
2128 in_single_quote = False
2229 in_double_quote = False
2330 i = 0
2431
2532 while i < len (sql ):
2633 char = sql [i ]
2734
28- # Handle escape sequences
35+ # Handle escape sequences in strings
2936 if i + 1 < len (sql ) and char == "\\ " and (in_single_quote or in_double_quote ):
37+ yield (i , char , False )
38+ yield (i + 1 , sql [i + 1 ], False )
3039 i += 2
3140 continue
3241
33- # Handle doubled quotes
42+ # Handle doubled quotes (SQL escape for quotes)
3443 if char == "'" and i + 1 < len (sql ) and sql [i + 1 ] == "'" and in_single_quote :
44+ yield (i , "'" , False )
45+ yield (i + 1 , "'" , False )
3546 i += 2
3647 continue
3748 if char == '"' and i + 1 < len (sql ) and sql [i + 1 ] == '"' and in_double_quote :
49+ yield (i , '"' , False )
50+ yield (i + 1 , '"' , False )
3851 i += 2
3952 continue
4053
41- # Toggle quote state
54+ # Toggle quote state and yield
4255 if char == "'" and not in_double_quote :
4356 in_single_quote = not in_single_quote
57+ yield (i , char , False ) # Quote char is part of string syntax
4458 elif char == '"' and not in_single_quote :
4559 in_double_quote = not in_double_quote
46- elif char == ";" and not in_single_quote and not in_double_quote :
47- return True
60+ yield (i , char , False ) # Quote char is part of string syntax
61+ else :
62+ yield (i , char , not in_single_quote and not in_double_quote )
4863
4964 i += 1
5065
66+
67+ def _has_semicolon_outside_strings (sql : str ) -> bool :
68+ """Check if SQL has semicolons outside of string literals."""
69+ for _ , char , outside in _iter_sql_chars (sql ):
70+ if char == ";" and outside :
71+ return True
5172 return False
5273
5374
5475def _split_by_semicolons (sql : str ) -> list [str ]:
5576 """Split SQL by semicolons, respecting string literals."""
5677 statements = []
57- current = []
58- in_single_quote = False
59- in_double_quote = False
60- i = 0
61-
62- while i < len (sql ):
63- char = sql [i ]
64-
65- # Handle escape sequences in strings
66- if i + 1 < len (sql ) and char == "\\ " and (in_single_quote or in_double_quote ):
67- current .append (char )
68- current .append (sql [i + 1 ])
69- i += 2
70- continue
71-
72- # Handle doubled quotes (SQL escape for quotes)
73- if char == "'" and i + 1 < len (sql ) and sql [i + 1 ] == "'" and in_single_quote :
74- current .append ("''" )
75- i += 2
76- continue
78+ current : list [str ] = []
7779
78- if char == '"' and i + 1 < len (sql ) and sql [i + 1 ] == '"' and in_double_quote :
79- current .append ('""' )
80- i += 2
81- continue
82-
83- # Toggle quote state
84- if char == "'" and not in_double_quote :
85- in_single_quote = not in_single_quote
86- current .append (char )
87- elif char == '"' and not in_single_quote :
88- in_double_quote = not in_double_quote
89- current .append (char )
90- elif char == ";" and not in_single_quote and not in_double_quote :
91- # End of statement
80+ for _ , char , outside in _iter_sql_chars (sql ):
81+ if char == ";" and outside :
9282 stmt = "" .join (current ).strip ()
9383 if stmt :
9484 statements .append (stmt )
9585 current = []
9686 else :
9787 current .append (char )
9888
99- i += 1
100-
10189 # Don't forget the last statement (may not end with semicolon)
10290 stmt = "" .join (current ).strip ()
10391 if stmt :
@@ -113,51 +101,20 @@ def _split_by_blank_lines(sql: str) -> list[str]:
113101 This is triggered when there are no semicolons in the query.
114102 """
115103 statements = []
116- current = []
117- in_single_quote = False
118- in_double_quote = False
119- i = 0
104+ current : list [str ] = []
120105 line_start = 0
121106 prev_line_empty = False
122107
123- while i < len (sql ):
124- char = sql [i ]
125-
126- # Handle escape sequences in strings
127- if i + 1 < len (sql ) and char == "\\ " and (in_single_quote or in_double_quote ):
128- current .append (char )
129- current .append (sql [i + 1 ])
130- i += 2
131- continue
132-
133- # Handle doubled quotes (SQL escape for quotes)
134- if char == "'" and i + 1 < len (sql ) and sql [i + 1 ] == "'" and in_single_quote :
135- current .append ("''" )
136- i += 2
137- continue
138-
139- if char == '"' and i + 1 < len (sql ) and sql [i + 1 ] == '"' and in_double_quote :
140- current .append ('""' )
141- i += 2
142- continue
143-
144- # Toggle quote state
145- if char == "'" and not in_double_quote :
146- in_single_quote = not in_single_quote
147- current .append (char )
148- elif char == '"' and not in_single_quote :
149- in_double_quote = not in_double_quote
150- current .append (char )
151- elif char == "\n " and not in_single_quote and not in_double_quote :
152- # Check if this line (from line_start to i) is empty/whitespace
153- line_content = sql [line_start :i ]
108+ for idx , char , outside in _iter_sql_chars (sql ):
109+ if char == "\n " and outside :
110+ line_content = sql [line_start :idx ]
154111 current_line_empty = not line_content .strip ()
155112
156113 if current_line_empty and prev_line_empty :
157- # We have a blank line separator - don't add more newlines
114+ # Consecutive blank lines, skip
158115 pass
159116 elif current_line_empty and current :
160- # This is a blank line after content - split here
117+ # Blank line after content - split here
161118 stmt = "" .join (current ).strip ()
162119 if stmt :
163120 statements .append (stmt )
@@ -167,14 +124,12 @@ def _split_by_blank_lines(sql: str) -> list[str]:
167124 current .append (char )
168125
169126 prev_line_empty = current_line_empty
170- line_start = i + 1
127+ line_start = idx + 1
171128 else :
172129 current .append (char )
173- if char not in " \t " :
130+ if char not in " \t \n " :
174131 prev_line_empty = False
175132
176- i += 1
177-
178133 # Don't forget the last statement
179134 stmt = "" .join (current ).strip ()
180135 if stmt :
@@ -183,6 +138,125 @@ def _split_by_blank_lines(sql: str) -> list[str]:
183138 return statements
184139
185140
141+ def _append_statement_range (
142+ ranges : list [tuple [str , int , int ]], sql : str , stmt_start : int , stmt_end : int
143+ ) -> None :
144+ """Helper to append a statement range, calculating actual positions."""
145+ stmt_full = sql [stmt_start :stmt_end ]
146+ stmt_text = stmt_full .strip ()
147+ if stmt_text :
148+ actual_start = stmt_start + (len (stmt_full ) - len (stmt_full .lstrip ()))
149+ ranges .append ((stmt_text , actual_start , actual_start + len (stmt_text )))
150+
151+
152+ def _get_statement_ranges (sql : str ) -> list [tuple [str , int , int ]]:
153+ """Get statements with their character ranges in the original SQL.
154+
155+ Splitting strategy (matches split_statements):
156+ 1. If query contains semicolons (outside strings) → split by semicolons
157+ 2. If no semicolons but has blank lines → split by blank lines
158+ 3. Otherwise → return as single statement
159+
160+ Returns:
161+ List of (statement_text, start_offset, end_offset) tuples.
162+ Offsets are 0-based character positions in the original SQL string.
163+ """
164+ if not sql or not sql .strip ():
165+ return []
166+
167+ ranges : list [tuple [str , int , int ]] = []
168+
169+ # Strategy 1: If semicolons exist, use semicolon splitting with tracking
170+ if _has_semicolon_outside_strings (sql ):
171+ stmt_start = 0
172+
173+ for idx , char , outside in _iter_sql_chars (sql ):
174+ if char == ";" and outside :
175+ _append_statement_range (ranges , sql , stmt_start , idx )
176+ stmt_start = idx + 1
177+
178+ _append_statement_range (ranges , sql , stmt_start , len (sql ))
179+ return ranges
180+
181+ # Strategy 2: If blank lines exist, use blank line splitting with tracking
182+ if re .search (r"\n\s*\n" , sql ):
183+ stmt_start = 0
184+ line_start = 0
185+ prev_line_empty = False
186+
187+ for idx , char , outside in _iter_sql_chars (sql ):
188+ if char == "\n " and outside :
189+ line_content = sql [line_start :idx ]
190+ current_line_empty = not line_content .strip ()
191+
192+ if current_line_empty and prev_line_empty :
193+ # Consecutive blank lines, skip
194+ pass
195+ elif current_line_empty :
196+ # Blank line after content - this is a statement boundary
197+ _append_statement_range (ranges , sql , stmt_start , idx )
198+ stmt_start = idx + 1
199+
200+ prev_line_empty = current_line_empty
201+ line_start = idx + 1
202+ elif char not in " \t \n " :
203+ prev_line_empty = False
204+
205+ _append_statement_range (ranges , sql , stmt_start , len (sql ))
206+ return ranges
207+
208+ # Strategy 3: Single statement
209+ stripped = sql .strip ()
210+ if stripped :
211+ start_offset = len (sql ) - len (sql .lstrip ())
212+ return [(stripped , start_offset , len (sql ))]
213+
214+ return []
215+
216+
217+ def find_statement_at_cursor (sql : str , row : int , col : int ) -> tuple [str , int , int ] | None :
218+ """Find the SQL statement containing the cursor position.
219+
220+ Args:
221+ sql: Full SQL text (may contain multiple statements).
222+ row: Cursor row (0-based line number).
223+ col: Cursor column (0-based character position within the line).
224+
225+ Returns:
226+ Tuple of (statement_text, start_char_offset, end_char_offset) or None if not found.
227+ """
228+ if not sql :
229+ return None
230+
231+ # Convert (row, col) to absolute character offset
232+ lines = sql .split ("\n " )
233+ if row >= len (lines ):
234+ # Cursor is past end of text, use last position
235+ cursor_offset = len (sql )
236+ else :
237+ # Sum lengths of all previous lines plus newline characters
238+ cursor_offset = sum (len (lines [i ]) + 1 for i in range (row )) + col
239+
240+ ranges = _get_statement_ranges (sql )
241+
242+ if not ranges :
243+ return None
244+
245+ # Find the statement containing the cursor
246+ for stmt_text , start , end in ranges :
247+ if start <= cursor_offset <= end :
248+ return (stmt_text , start , end )
249+
250+ # If cursor is between statements or at the very end,
251+ # return the nearest preceding statement
252+ for stmt_text , start , end in reversed (ranges ):
253+ if cursor_offset >= start :
254+ return (stmt_text , start , end )
255+
256+ # Fallback to first statement
257+ return ranges [0 ] if ranges else None
258+
259+
186260def split_statements (sql : str ) -> list [str ]:
187261 """Split SQL into individual statements.
188262
0 commit comments