-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexecutor.py
More file actions
332 lines (269 loc) · 12.3 KB
/
Copy pathexecutor.py
File metadata and controls
332 lines (269 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""SQL Executor - executes translated queries against Redis."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast
import redis
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
from sql_redis.translator import Translator
if TYPE_CHECKING:
import redis.asyncio as async_redis
SchemaCacheStrategy = Literal["lazy", "load_all"]
def _validate_schema_cache_strategy(
schema_cache_strategy: str,
) -> SchemaCacheStrategy:
"""Validate and normalize the schema cache strategy."""
if schema_cache_strategy not in {"lazy", "load_all"}:
raise ValueError("schema_cache_strategy must be one of: 'lazy', 'load_all'")
return cast(SchemaCacheStrategy, schema_cache_strategy)
def _substitute_params(sql: str, params: dict[str, Any]) -> str:
"""Substitute parameter placeholders in SQL with actual values.
This is a pure function with no I/O operations, shared by both
sync and async executors.
Uses token-based approach: splits SQL on :param patterns, then rebuilds
with substituted values. This approach solves two critical bugs:
1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
by treating each :identifier as a complete token
2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
using SQL standard (single quote -> double single quote)
Args:
sql: The SQL string with :param placeholders.
params: Dictionary mapping parameter names to values.
Returns:
SQL string with parameters substituted.
Implementation Details:
- Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
- Keeps delimiters (the :param tokens) in the split result
- Iterates through tokens, substituting matched parameters
- String values are wrapped in single quotes with proper escaping
- Numeric values are converted to strings
- Bytes values (e.g., vectors) are NOT substituted here
Known Limitations:
- Colons in string literals: SQL like "WHERE x = 'test:value'" would
theoretically match :value as a parameter. However, this is not a
practical issue because:
1. Users pass values via parameters, not hardcoded in SQL
2. The translator has its own handling of string literals
3. No real-world use cases have been identified
- Parameter names are case-sensitive (:id != :ID)
- Only handles int, float, str types; other types keep placeholder
"""
if not params:
return sql
# Split SQL on :param patterns, keeping the delimiters
# Pattern matches : followed by valid identifier:
# [a-zA-Z_] - First char must be letter or underscore
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
# This prevents partial matching: :id and :product_id are separate tokens
tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)
result = []
for token in tokens:
if token.startswith(":"):
# This is a parameter placeholder
key = token[1:] # Remove leading :
if key in params:
value = params[key]
if isinstance(value, (int, float)):
# Numeric values: convert to string
result.append(str(value))
elif isinstance(value, str):
# String values: wrap in quotes and escape single quotes
# SQL standard: ' -> '' (double single quote)
# This fixes the quote escaping bug
escaped = value.replace("'", "''")
result.append(f"'{escaped}'")
else:
# Other types (bytes, None, bool, list, etc.):
# Keep placeholder as-is (handled elsewhere or unsupported)
result.append(token)
else:
# Parameter not provided: keep placeholder as-is
result.append(token)
else:
# Not a parameter: keep as-is
result.append(token)
return "".join(result)
@dataclass
class QueryResult:
    """Outcome of one executed SQL query.

    Attributes:
        rows: One dictionary per returned record.
        count: Result count associated with the query.
    """

    rows: list[dict]
    count: int
class Executor:
    """Executes SQL queries against Redis.

    ``execute`` runs four steps: inline ``:name`` parameters into the SQL
    text, translate the SQL into a Redis command, bind a bytes parameter to
    the ``$vector`` placeholder, and reshape the raw reply into a
    ``QueryResult``.
    """

    # Fragments of a Redis error message that, together with an ismissing()
    # call in the translated query, suggest ismissing() is the cause.
    _ISMISSING_SIGNATURES = (
        "Unknown function",
        "No such function",
        "Syntax error",
        "INDEXMISSING",
    )

    def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry) -> None:
        """Initialize executor with Redis client and schema registry."""
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Values for the placeholders. The first ``bytes`` value,
                if any, is bound to the ``$vector`` command argument instead
                of being inlined into the SQL.

        Returns:
            The parsed rows and the count reported by Redis.

        Raises:
            redis.ResponseError: If Redis rejects the command. When the query
                uses ``ismissing()`` and the message matches a known
                signature, the error is re-raised with an explanatory hint.
        """
        params = params or {}

        # Inline int/float/str params; bytes params keep their placeholder.
        sql = _substitute_params(sql, params)
        translated = self._translator.translate(sql)
        cmd = self._bind_vector(list(translated.to_command_list()), params)

        try:
            raw_result = self._client.execute_command(*cmd)
        except redis.ResponseError as e:
            hint = self._ismissing_hint(translated.query_string, str(e))
            if hint is not None:
                raise redis.ResponseError(hint) from e
            raise

        return self._parse_result(translated.command, raw_result)

    @staticmethod
    def _bind_vector(cmd: list[str | bytes], params: dict) -> list[str | bytes]:
        """Replace every ``$vector`` argument with the first bytes param."""
        vector = next((v for v in params.values() if isinstance(v, bytes)), None)
        # Compare against None so that an empty bytes value is still bound.
        if vector is None:
            return cmd
        return [vector if arg == "$vector" else arg for arg in cmd]

    @classmethod
    def _ismissing_hint(cls, query_string: str, error_msg: str) -> str | None:
        """Return an extended error message if ismissing() is the likely cause."""
        if "ismissing(@" not in query_string:
            return None
        if not any(sig in error_msg for sig in cls._ISMISSING_SIGNATURES):
            return None
        return (
            f"{error_msg}. This error may be caused by use of the "
            "ismissing() function. ismissing() requires Redis 7.4+ "
            "(RediSearch 2.10+) and the field must have INDEXMISSING "
            "declared in the schema."
        )

    @staticmethod
    def _parse_result(command: str, raw_result: Any) -> QueryResult:
        """Convert a raw FT.SEARCH / FT.AGGREGATE reply into a QueryResult."""
        if not raw_result:
            # Empty or missing reply: nothing to index into.
            return QueryResult(rows=[], count=0)

        if command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            # Field lists sit at indices 2, 4, 6, ...; document keys are skipped.
            records = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            records = raw_result[1:]

        # Each record is a flat [name1, value1, name2, value2, ...] list.
        rows = [dict(zip(record[::2], record[1::2])) for record in records]
        return QueryResult(rows=rows, count=raw_result[0])
class AsyncExecutor:
    """Async version of Executor for use with redis.asyncio clients.

    ``execute`` inlines ``:name`` parameters, parses the SQL, awaits the
    schema for the referenced index, translates the parsed query, binds a
    bytes parameter to the ``$vector`` placeholder, and reshapes the raw
    reply into a ``QueryResult``.
    """

    # Fragments of a Redis error message that, together with an ismissing()
    # call in the translated query, suggest ismissing() is the cause.
    _ISMISSING_SIGNATURES = (
        "Unknown function",
        "No such function",
        "Syntax error",
        "INDEXMISSING",
    )

    def __init__(
        self,
        client: "async_redis.Redis",
        schema_registry: AsyncSchemaRegistry,
    ) -> None:
        """Initialize async executor with Redis client and schema registry.

        Args:
            client: An async Redis client (redis.asyncio.Redis).
            schema_registry: An AsyncSchemaRegistry instance.
        """
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query asynchronously and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Values for the placeholders. The first ``bytes`` value,
                if any, is bound to the ``$vector`` command argument instead
                of being inlined into the SQL.

        Returns:
            The parsed rows and the count reported by Redis.

        Raises:
            redis.ResponseError: If Redis rejects the command. When the query
                uses ``ismissing()`` and the message matches a known
                signature, the error is re-raised with an explanatory hint.
        """
        params = params or {}

        # Inline int/float/str params; bytes params keep their placeholder.
        sql = _substitute_params(sql, params)

        # Parse once, ensure the schema is loaded (async lazy-load), then
        # translate from the pre-parsed result to avoid double-parsing.
        parsed = self._translator.parse(sql)
        if parsed.index:
            await self._schema_registry.ensure_schema(parsed.index)
        translated = self._translator.translate_parsed(parsed)

        cmd = self._bind_vector(list(translated.to_command_list()), params)

        try:
            raw_result = await self._client.execute_command(*cmd)
        except redis.ResponseError as e:
            hint = self._ismissing_hint(translated.query_string, str(e))
            if hint is not None:
                raise redis.ResponseError(hint) from e
            raise

        return self._parse_result(translated.command, raw_result)

    @staticmethod
    def _bind_vector(cmd: list[str | bytes], params: dict) -> list[str | bytes]:
        """Replace every ``$vector`` argument with the first bytes param."""
        vector = next((v for v in params.values() if isinstance(v, bytes)), None)
        # Compare against None so that an empty bytes value is still bound.
        if vector is None:
            return cmd
        return [vector if arg == "$vector" else arg for arg in cmd]

    @classmethod
    def _ismissing_hint(cls, query_string: str, error_msg: str) -> str | None:
        """Return an extended error message if ismissing() is the likely cause."""
        if "ismissing(@" not in query_string:
            return None
        if not any(sig in error_msg for sig in cls._ISMISSING_SIGNATURES):
            return None
        return (
            f"{error_msg}. This error may be caused by use of the "
            "ismissing() function. ismissing() requires Redis 7.4+ "
            "(RediSearch 2.10+) and the field must have INDEXMISSING "
            "declared in the schema."
        )

    @staticmethod
    def _parse_result(command: str, raw_result: Any) -> QueryResult:
        """Convert a raw FT.SEARCH / FT.AGGREGATE reply into a QueryResult."""
        if not raw_result:
            # Empty or missing reply: nothing to index into.
            return QueryResult(rows=[], count=0)

        if command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            # Field lists sit at indices 2, 4, 6, ...; document keys are skipped.
            records = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            records = raw_result[1:]

        # Each record is a flat [name1, value1, name2, value2, ...] list.
        rows = [dict(zip(record[::2], record[1::2])) for record in records]
        return QueryResult(rows=rows, count=raw_result[0])
def create_executor(
    client: redis.Redis,
    *,
    schema_registry: SchemaRegistry | None = None,
    schema_cache_strategy: SchemaCacheStrategy = "lazy",
) -> Executor:
    """Create a sync SQL executor with the requested schema cache strategy.

    Args:
        client: Redis client used by the executor.
        schema_registry: Optional existing registry to reuse. A new
            ``SchemaRegistry(client)`` is created only when this is ``None``.
        schema_cache_strategy: Schema loading strategy. ``"lazy"`` defers
            ``FT.INFO`` calls until a referenced index is needed. ``"load_all"``
            preserves the historical eager behavior by preloading all schemas.

    Returns:
        An ``Executor`` bound to ``client`` and the chosen registry.

    Raises:
        ValueError: If ``schema_cache_strategy`` is not a supported name.
    """
    strategy = _validate_schema_cache_strategy(schema_cache_strategy)

    # Test identity, not truthiness: a caller-supplied registry must be
    # reused even if it happens to evaluate as falsy.
    registry = SchemaRegistry(client) if schema_registry is None else schema_registry

    if strategy == "load_all":
        registry.load_all()
    return Executor(client, registry)
async def create_async_executor(
    client: "async_redis.Redis",
    *,
    schema_registry: AsyncSchemaRegistry | None = None,
    schema_cache_strategy: SchemaCacheStrategy = "lazy",
) -> AsyncExecutor:
    """Create an async SQL executor with the requested schema cache strategy.

    Args:
        client: Async Redis client used by the executor.
        schema_registry: Optional existing async registry to reuse. A new
            ``AsyncSchemaRegistry(client)`` is created only when this is
            ``None``.
        schema_cache_strategy: Schema loading strategy. ``"lazy"`` defers
            ``FT.INFO`` calls until a referenced index is needed. ``"load_all"``
            preserves the historical eager behavior by preloading all schemas.

    Returns:
        An ``AsyncExecutor`` bound to ``client`` and the chosen registry.

    Raises:
        ValueError: If ``schema_cache_strategy`` is not a supported name.
    """
    strategy = _validate_schema_cache_strategy(schema_cache_strategy)

    # Test identity, not truthiness: a caller-supplied registry must be
    # reused even if it happens to evaluate as falsy.
    registry = (
        AsyncSchemaRegistry(client) if schema_registry is None else schema_registry
    )

    if strategy == "load_all":
        await registry.load_all()
    return AsyncExecutor(client, registry)