Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion sqlit/domains/query/app/multi_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,33 @@
def _iter_sql_chars(sql: str) -> Iterator[tuple[int, str, bool]]:
"""Iterate through SQL characters, tracking string literal context.

Handles escape sequences (backslash) and SQL-style doubled quotes.
Handles escape sequences (backslash), SQL-style doubled quotes,
and PostgreSQL dollar-quoted strings ($$ or $tag$).

Yields:
(index, char, outside_string) tuples where outside_string is True
when the character is not inside a string literal.
"""
in_single_quote = False
in_double_quote = False
in_dollar_tag: str | None = None
i = 0

while i < len(sql):
# If inside a dollar-quoted string, check for the closing tag
if in_dollar_tag is not None:
if sql[i:].startswith(in_dollar_tag):
# Yield the closing tag's characters as still inside the string (outside_string=False)
for offset in range(len(in_dollar_tag)):
yield (i + offset, sql[i + offset], False)
i += len(in_dollar_tag)
in_dollar_tag = None
continue
else:
yield (i, sql[i], False)
i += 1
continue

char = sql[i]

# Handle escape sequences in strings
Expand All @@ -51,6 +67,18 @@ def _iter_sql_chars(sql: str) -> Iterator[tuple[int, str, bool]]:
i += 2
continue

# Check for PostgreSQL dollar-quoted string start
if char == "$" and not in_single_quote and not in_double_quote:
# Match $[a-zA-Z_][a-zA-Z0-9_]*$ or $$
match = re.match(r"^\$([a-zA-Z_][a-zA-Z0-9_]*)?\$", sql[i:])
if match:
delimiter = match.group(0)
in_dollar_tag = delimiter
for offset in range(len(delimiter)):
yield (i + offset, sql[i + offset], False)
i += len(delimiter)
continue

# Toggle quote state and yield
if char == "'" and not in_double_quote:
in_single_quote = not in_single_quote
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test_multi_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,55 @@ def test_handles_multiline_statements(self):

assert len(statements) == 2

def test_preserves_semicolons_in_dollar_quoted_strings(self):
    """A ``$$ ... $$`` body containing semicolons must stay in one statement."""
    from sqlit.domains.query.app.multi_statement import split_statements

    # The function body holds semicolons both bare (END;) and inside a
    # single-quoted literal ('a;b'); none of them may act as a separator.
    sql_script = """
CREATE OR REPLACE FUNCTION example()
RETURNS void AS $$
BEGIN
INSERT INTO t (x) VALUES ('a;b');
END;
$$ LANGUAGE plpgsql;
SELECT 1;
"""
    parts = split_statements(sql_script)

    # Exactly two statements: the function definition, then the SELECT.
    assert len(parts) == 2
    function_stmt, select_stmt = parts[0], parts[1]
    assert "CREATE OR REPLACE FUNCTION" in function_stmt
    assert "SELECT 1" in select_stmt

def test_preserves_semicolons_in_named_dollar_quoted_strings(self):
    """A ``$func_tag$ ... $func_tag$`` body with semicolons must not be split."""
    from sqlit.domains.query.app.multi_statement import split_statements

    # Same scenario as the anonymous ``$$`` case, but the delimiter carries
    # a tag, so the closing delimiter must match the opening one.
    sql_script = """
CREATE OR REPLACE FUNCTION example()
RETURNS void AS $func_tag$
BEGIN
INSERT INTO t (x) VALUES ('a;b');
END;
$func_tag$ LANGUAGE plpgsql;
SELECT 1;
"""
    parts = split_statements(sql_script)

    # Exactly two statements: the function definition, then the SELECT.
    assert len(parts) == 2
    function_stmt, select_stmt = parts[0], parts[1]
    assert "CREATE OR REPLACE FUNCTION" in function_stmt
    assert "SELECT 1" in select_stmt

def test_dollar_quotes_inside_standard_strings_are_ignored(self):
    """A ``$$`` inside a single-quoted literal must not open a dollar-quoted string."""
    from sqlit.domains.query.app.multi_statement import split_statements

    # If '$$' were taken as a dollar-quote opener, the semicolon after the
    # INSERT would be swallowed and only one statement would come back.
    sql_script = "INSERT INTO t (x) VALUES ('$$'); SELECT 1"
    parts = split_statements(sql_script)

    assert len(parts) == 2
    insert_stmt, select_stmt = parts[0], parts[1]
    assert "INSERT" in insert_stmt
    assert "SELECT 1" in select_stmt


class TestMultiStatementResult:
"""Tests for MultiStatementResult data structure."""
Expand Down
Loading