-
Notifications
You must be signed in to change notification settings - Fork 124
Expand file tree
/
Copy pathquery_service.py
More file actions
213 lines (167 loc) · 7.19 KB
/
Copy pathquery_service.py
File metadata and controls
213 lines (167 loc) · 7.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""Query execution service for sqlit.
This module provides a unified query execution service used by both
the TUI and CLI to ensure consistent behavior.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from sqlit.domains.connections.domain.config import ConnectionConfig
from sqlit.shared.core.protocols import HistoryStoreProtocol, QueryExecutorProtocol
# Query types that return result sets (SELECT-like queries)
SELECT_KEYWORDS = frozenset(["SELECT", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "PRAGMA"])
# DML statements that may carry a RETURNING clause — when present they produce a result set.
_DML_KEYWORDS = frozenset(["INSERT", "UPDATE", "DELETE", "MERGE"])
_RETURNING_RE = re.compile(r"(?is)\bRETURNING\b\s+\S")
# Regex for parsing USE database statements
# Matches: USE dbname, USE [dbname], USE `dbname`, USE "dbname"
_USE_PATTERN = re.compile(
r"^\s*USE\s+"
r"(?:"
r"\[([^\]]+)\]" # [bracketed] - SQL Server style
r"|`([^`]+)`" # `backtick` - MySQL style
r"|\"([^\"]+)\"" # "quoted" - standard SQL style
r"|(\w+)" # unquoted identifier
r")"
r"\s*;?\s*$",
re.IGNORECASE,
)
def parse_use_statement(query: str) -> str | None:
"""Parse a USE database statement and return the database name.
Supports various quoting styles:
- USE mydb
- USE [mydb] (SQL Server)
- USE `mydb` (MySQL)
- USE "mydb"
Args:
query: The SQL query string.
Returns:
The database name if this is a USE statement, None otherwise.
"""
match = _USE_PATTERN.match(query)
if not match:
return None
# Return first non-None group (the captured database name)
return next((g for g in match.groups() if g is not None), None)
class QueryKind(Enum):
RETURNS_ROWS = "returns_rows"
NON_QUERY = "non_query"
class QueryAnalyzer(Protocol):
def classify(self, query: str) -> QueryKind: ...
class KeywordQueryAnalyzer:
def classify(self, query: str) -> QueryKind:
"""Classify query based on keyword of the last statement.
For multi-statement queries like 'BEGIN; INSERT...; SELECT * FROM t;',
we check the last statement to determine if results should be returned.
Uses the same splitting logic as multi_statement.split_statements.
"""
from sqlit.domains.query.editing.comments import (
is_comment_line,
is_comment_only_statement,
)
from .multi_statement import split_statements
statements = split_statements(query)
if not statements:
return QueryKind.NON_QUERY
# Filter out comment-only statements and find the last actual SQL statement
for stmt in reversed(statements):
if is_comment_only_statement(stmt):
continue
# Found a statement with actual SQL - get first non-comment line
lines = [line.strip() for line in stmt.split("\n") if line.strip()]
non_comment_lines = [line for line in lines if not is_comment_line(line)]
if non_comment_lines:
first_line = non_comment_lines[0].upper()
first_word = first_line.split()[0] if first_line else ""
if first_word in SELECT_KEYWORDS:
return QueryKind.RETURNS_ROWS
# DML with a RETURNING clause produces a result set too.
if first_word in _DML_KEYWORDS and _RETURNING_RE.search(stmt):
return QueryKind.RETURNS_ROWS
return QueryKind.NON_QUERY
return QueryKind.NON_QUERY
class DialectQueryAnalyzer:
def __init__(self, dialect: Any, fallback: QueryAnalyzer | None = None) -> None:
self._dialect = dialect
self._fallback = fallback or KeywordQueryAnalyzer()
def classify(self, query: str) -> QueryKind:
classifier = getattr(self._dialect, "classify_query", None)
if callable(classifier):
result = classifier(query)
if isinstance(result, QueryKind):
return result
if isinstance(result, bool):
return QueryKind.RETURNS_ROWS if result else QueryKind.NON_QUERY
return self._fallback.classify(query)
@dataclass
class QueryResult:
"""Result of a SELECT-type query execution."""
columns: list[str]
rows: list[tuple]
row_count: int
truncated: bool
@dataclass
class NonQueryResult:
"""Result of a non-SELECT query execution (INSERT, UPDATE, DELETE, etc.)."""
rows_affected: int
class QueryService:
"""Service for executing database queries.
This service provides a unified interface for query execution,
handling query type detection, execution, and optional history saving.
Args:
history_store: History store for saving queries.
analyzer: Query analyzer strategy for selecting execution behavior.
"""
def __init__(self, history_store: HistoryStoreProtocol | None = None, analyzer: QueryAnalyzer | None = None):
if history_store is None:
from sqlit.domains.query.store.memory import InMemoryHistoryStore
history_store = InMemoryHistoryStore()
self._history_store = history_store
self._analyzer = analyzer or KeywordQueryAnalyzer()
def execute(
self,
connection: Any,
executor: QueryExecutorProtocol,
query: str,
config: ConnectionConfig | None = None,
max_rows: int | None = None,
save_to_history: bool = True,
) -> QueryResult | NonQueryResult:
"""Execute a query and optionally save to history.
Args:
connection: The database connection object.
executor: The query executor to use for execution.
query: The SQL query string to execute.
config: Optional connection config (needed for history saving).
max_rows: Optional maximum rows to fetch for SELECT queries.
save_to_history: Whether to save the query to history.
Returns:
QueryResult for SELECT-type queries, NonQueryResult otherwise.
Raises:
Any exceptions raised by the underlying database driver.
"""
result: QueryResult | NonQueryResult
if self._analyzer.classify(query) == QueryKind.RETURNS_ROWS:
columns, rows, truncated = executor.execute_query(connection, query, max_rows)
result = QueryResult(
columns=columns,
rows=list(rows),
row_count=len(rows),
truncated=truncated,
)
else:
affected = executor.execute_non_query(connection, query)
result = NonQueryResult(rows_affected=affected)
# Save to history if requested and config is available
if save_to_history and config:
self._save_to_history(config.name, query)
return result
def _save_to_history(self, connection_name: str, query: str) -> None:
"""Save a query to history.
Args:
connection_name: The name of the connection.
query: The query string to save.
"""
self._history_store.save_query(connection_name, query)