Skip to content

Commit 743bbc5

Browse files
committed
Add basic translation->execution stack
1 parent ed2dcdc commit 743bbc5

17 files changed

Lines changed: 3101 additions & 341 deletions

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,15 @@ build-backend = "hatchling.build"
1818

1919
[tool.pytest.ini_options]
2020
testpaths = ["tests"]
21+
addopts = "--cov=sql_redis --cov-report=term-missing"
22+
23+
[tool.coverage.run]
24+
source = ["sql_redis"]
25+
branch = true
26+
27+
[tool.coverage.report]
28+
exclude_lines = [
29+
"pragma: no cover",
30+
"raise NotImplementedError",
31+
]
2132

sql_redis/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""SQL to Redis command translation utility."""
22

3-
from sql_redis.translator import translate_sql
3+
from sql_redis.translator import Translator, TranslatedQuery
44

5-
__all__ = ["translate_sql"]
5+
__all__ = ["Translator", "TranslatedQuery"]
66

sql_redis/analyzer.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,91 @@ def get_conditions_by_type(self, field_type: str) -> list[Condition]:
4040

4141
class Analyzer:
4242
"""Analyzes parsed SQL queries with schema context."""
43-
43+
4444
def __init__(self, schemas: dict[str, dict[str, str]]):
4545
"""Initialize analyzer with schema registry data.
46-
46+
4747
Args:
4848
schemas: Dict mapping index names to field->type dicts.
4949
"""
5050
self._schemas = schemas
51-
51+
5252
def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
5353
"""Analyze a parsed query, resolving field types.
54-
54+
5555
Args:
5656
parsed: The parsed SQL query.
57-
57+
5858
Returns:
5959
An AnalyzedQuery with field types resolved.
60-
60+
6161
Raises:
6262
ValueError: If the index or a field is unknown.
6363
"""
64-
raise NotImplementedError("Analyzer.analyze is not yet implemented")
64+
# Validate index exists
65+
if parsed.index not in self._schemas:
66+
raise ValueError(f"Unknown index: {parsed.index}")
67+
68+
schema = self._schemas[parsed.index]
69+
result = AnalyzedQuery(parsed=parsed)
70+
71+
# Collect all fields referenced in the query
72+
referenced_fields: set[str] = set()
73+
74+
# Fields from SELECT
75+
for field_name in parsed.fields:
76+
if field_name != "*":
77+
referenced_fields.add(field_name)
78+
79+
# Fields from conditions
80+
for condition in parsed.conditions:
81+
referenced_fields.add(condition.field)
82+
83+
# Fields from aggregations
84+
for agg in parsed.aggregations:
85+
if agg.field:
86+
referenced_fields.add(agg.field)
87+
88+
# Fields from computed fields (extract field references from expressions)
89+
for computed in parsed.computed_fields:
90+
# Simple extraction - look for field names in the expression
91+
for field_name in schema.keys():
92+
if field_name in computed.expression:
93+
referenced_fields.add(field_name)
94+
95+
# Fields from vector search
96+
if parsed.vector_search:
97+
referenced_fields.add(parsed.vector_search.field)
98+
99+
# Fields from GROUP BY
100+
for field_name in parsed.groupby_fields:
101+
referenced_fields.add(field_name)
102+
103+
# Resolve field types
104+
for field_name in referenced_fields:
105+
if field_name not in schema:
106+
raise ValueError(f"Unknown field: {field_name}")
107+
result.field_types[field_name] = schema[field_name]
108+
109+
# Copy aggregations and computed fields
110+
result.aggregations = parsed.aggregations
111+
result.computed_fields = parsed.computed_fields
112+
result.groupby_fields = parsed.groupby_fields
113+
114+
# Determine if this is a global aggregation
115+
result.is_global_aggregation = (
116+
len(parsed.aggregations) > 0 and len(parsed.groupby_fields) == 0
117+
)
118+
119+
# Analyze vector search
120+
if parsed.vector_search:
121+
result.vector_search = VectorSearchAnalysis(
122+
field=parsed.vector_search.field,
123+
k=parsed.limit or parsed.vector_search.k or 10,
124+
alias=parsed.vector_search.alias,
125+
)
126+
# Has prefilter if there are conditions
127+
result.has_prefilter = len(parsed.conditions) > 0
128+
129+
return result
65130

sql_redis/executor.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""SQL Executor - executes translated queries against Redis."""
2+
3+
from dataclasses import dataclass
4+
5+
import redis
6+
7+
from sql_redis.schema import SchemaRegistry
8+
from sql_redis.translator import Translator
9+
10+
11+
@dataclass
12+
class QueryResult:
13+
"""Result of executing a SQL query."""
14+
15+
rows: list[dict]
16+
count: int
17+
18+
19+
class Executor:
20+
"""Executes SQL queries against Redis."""
21+
22+
def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry):
23+
"""Initialize executor with Redis client and schema registry."""
24+
self._client = client
25+
self._schema_registry = schema_registry
26+
self._translator = Translator(schema_registry)
27+
28+
def execute(
29+
self, sql: str, *, params: dict | None = None
30+
) -> QueryResult:
31+
"""Execute a SQL query and return results."""
32+
params = params or {}
33+
34+
# Substitute non-bytes params in SQL
35+
for key, value in params.items():
36+
placeholder = f":{key}"
37+
if isinstance(value, (int, float)):
38+
sql = sql.replace(placeholder, str(value))
39+
elif isinstance(value, str):
40+
sql = sql.replace(placeholder, f"'{value}'")
41+
# bytes (vectors) are handled via Redis PARAMS
42+
43+
# Translate SQL to Redis command
44+
translated = self._translator.translate(sql)
45+
46+
# Build command list and substitute vector params
47+
cmd = list(translated.to_command_list())
48+
49+
# Find any bytes params (vectors) to substitute
50+
vector_param = None
51+
for value in params.values():
52+
if isinstance(value, bytes):
53+
vector_param = value
54+
break
55+
56+
# Replace $vector placeholder with actual bytes
57+
if vector_param:
58+
for i, arg in enumerate(cmd):
59+
if arg == "$vector":
60+
cmd[i] = vector_param
61+
62+
# Execute command
63+
raw_result = self._client.execute_command(*cmd)
64+
65+
# Parse result based on command type
66+
count = raw_result[0] if raw_result else 0
67+
rows = []
68+
69+
if translated.command == "FT.SEARCH":
70+
# FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
71+
# Skip document keys (odd indices), take field lists (even indices after count)
72+
for i in range(2, len(raw_result), 2):
73+
row_data = raw_result[i]
74+
row = dict(zip(row_data[::2], row_data[1::2]))
75+
rows.append(row)
76+
else:
77+
# FT.AGGREGATE format: [count, [fields1], [fields2], ...]
78+
for row_data in raw_result[1:]:
79+
row = dict(zip(row_data[::2], row_data[1::2]))
80+
rows.append(row)
81+
82+
return QueryResult(rows=rows, count=count)
83+

0 commit comments

Comments
 (0)