Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sqlit/domains/connections/providers/sqlite/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> t
else:
rows = cursor.fetchall()
truncated = False
# DML with RETURNING produces a result set but also writes — persist it.
if conn.in_transaction:
conn.commit()
return columns, [tuple(row) for row in rows], truncated
return [], [], False

Expand Down
11 changes: 10 additions & 1 deletion sqlit/domains/query/app/query_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
# Query types that return result sets (SELECT-like queries)
SELECT_KEYWORDS = frozenset(["SELECT", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "PRAGMA"])

# DML statements that may carry a RETURNING clause — when present they produce a result set.
_DML_KEYWORDS = frozenset(["INSERT", "UPDATE", "DELETE", "MERGE"])
_RETURNING_RE = re.compile(r"(?is)\bRETURNING\b\s+\S")

# Regex for parsing USE database statements
# Matches: USE dbname, USE [dbname], USE `dbname`, USE "dbname"
_USE_PATTERN = re.compile(
Expand Down Expand Up @@ -93,7 +97,12 @@ def classify(self, query: str) -> QueryKind:
if non_comment_lines:
first_line = non_comment_lines[0].upper()
first_word = first_line.split()[0] if first_line else ""
return QueryKind.RETURNS_ROWS if first_word in SELECT_KEYWORDS else QueryKind.NON_QUERY
if first_word in SELECT_KEYWORDS:
return QueryKind.RETURNS_ROWS
# DML with a RETURNING clause produces a result set too.
if first_word in _DML_KEYWORDS and _RETURNING_RE.search(stmt):
return QueryKind.RETURNS_ROWS
return QueryKind.NON_QUERY

return QueryKind.NON_QUERY

Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_sqlite_returning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Regression tests for issue #147: SQLite UPDATE ... RETURNING crashes on commit."""

from __future__ import annotations

import sqlite3
from pathlib import Path

import pytest

from sqlit.domains.connections.providers.sqlite.adapter import SQLiteAdapter
from sqlit.domains.query.app.query_service import KeywordQueryAnalyzer, QueryKind


@pytest.fixture
def jobs_db(tmp_path: Path) -> Path:
"""A tiny SQLite DB with a `jobs` table for RETURNING tests."""
db = tmp_path / "jobs.db"
conn = sqlite3.connect(str(db))
conn.execute("CREATE TABLE jobs (id INTEGER PRIMARY KEY, status TEXT)")
conn.executemany("INSERT INTO jobs (id, status) VALUES (?, ?)", [(1, "new"), (2, "new")])
conn.commit()
conn.close()
return db


def test_classifier_recognizes_update_returning_as_returns_rows():
"""`UPDATE ... RETURNING` produces a result set, so the analyzer must classify it as RETURNS_ROWS."""
analyzer = KeywordQueryAnalyzer()
sql = "UPDATE jobs SET status = status WHERE id = 1 RETURNING id"
assert analyzer.classify(sql) == QueryKind.RETURNS_ROWS


def test_classifier_recognizes_insert_returning_as_returns_rows():
analyzer = KeywordQueryAnalyzer()
sql = "INSERT INTO jobs (id, status) VALUES (3, 'new') RETURNING id"
assert analyzer.classify(sql) == QueryKind.RETURNS_ROWS


def test_classifier_recognizes_delete_returning_as_returns_rows():
analyzer = KeywordQueryAnalyzer()
sql = "DELETE FROM jobs WHERE id = 1 RETURNING id"
assert analyzer.classify(sql) == QueryKind.RETURNS_ROWS


def test_classifier_plain_update_is_non_query():
"""Plain DML without RETURNING must still be NON_QUERY (sanity check we don't over-correct)."""
analyzer = KeywordQueryAnalyzer()
assert analyzer.classify("UPDATE jobs SET status = 'done'") == QueryKind.NON_QUERY


def test_sqlite_execute_query_runs_update_returning_and_persists(jobs_db: Path):
"""UPDATE ... RETURNING via execute_query must return the row AND persist the change."""
adapter = SQLiteAdapter()
conn = sqlite3.connect(str(jobs_db))
try:
columns, rows, _ = adapter.execute_query(
conn,
"UPDATE jobs SET status = 'done' WHERE id = 1 RETURNING id, status",
)
assert columns == ["id", "status"]
assert rows == [(1, "done")]
finally:
conn.close()

# Verify the write was actually committed by opening a fresh connection.
verify = sqlite3.connect(str(jobs_db))
try:
result = verify.execute("SELECT status FROM jobs WHERE id = 1").fetchone()
assert result == ("done",), f"UPDATE ... RETURNING did not persist; got {result!r}"
finally:
verify.close()
Loading