Skip to content

Commit c36d851

Browse files
authored
Merge pull request #2 from redis-developer/feat/redisvl-integration-tests
Feat/redisvl integration tests
2 parents f002ce9 + 8561708 commit c36d851

4 files changed

Lines changed: 172 additions & 16 deletions

File tree

sql_redis/parser.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from __future__ import annotations
44

5-
from dataclasses import dataclass, field
5+
import dataclasses
6+
from dataclasses import dataclass
67

78
import sqlglot
89
from sqlglot import exp
@@ -15,6 +16,9 @@ class AggregationSpec:
1516
function: str
1617
field: str | None = None
1718
alias: str | None = None
19+
extra_args: list[str] = dataclasses.field(
20+
default_factory=list
21+
) # For reducers like QUANTILE
1822

1923

2024
@dataclass
@@ -49,14 +53,14 @@ class ParsedQuery:
4953
"""Result of parsing a SQL query."""
5054

5155
index: str = ""
52-
fields: list[str] = field(default_factory=list)
53-
conditions: list[Condition] = field(default_factory=list)
56+
fields: list[str] = dataclasses.field(default_factory=list)
57+
conditions: list[Condition] = dataclasses.field(default_factory=list)
5458
boolean_operator: str = "AND"
55-
aggregations: list[AggregationSpec] = field(default_factory=list)
56-
computed_fields: list[ComputedField] = field(default_factory=list)
59+
aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list)
60+
computed_fields: list[ComputedField] = dataclasses.field(default_factory=list)
5761
vector_search: VectorSearchSpec | None = None
58-
groupby_fields: list[str] = field(default_factory=list)
59-
orderby_fields: list[tuple[str, str]] = field(
62+
groupby_fields: list[str] = dataclasses.field(default_factory=list)
63+
orderby_fields: list[tuple[str, str]] = dataclasses.field(
6064
default_factory=list
6165
) # (field, ASC|DESC)
6266
limit: int | None = None
@@ -150,9 +154,28 @@ def _process_select_expression_inner(
150154
result.fields.append(expression.name)
151155
elif isinstance(expression, exp.Star):
152156
result.fields.append("*")
153-
elif isinstance(expression, (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)):
157+
elif isinstance(
158+
expression,
159+
(
160+
exp.Count,
161+
exp.Sum,
162+
exp.Avg,
163+
exp.Min,
164+
exp.Max,
165+
exp.Stddev,
166+
exp.Variance,
167+
exp.FirstValue,
168+
exp.ArrayAgg,
169+
),
170+
):
154171
# Aggregation function
172+
# Map sqlglot function names to Redis reducer names
155173
func_name = expression.key.upper()
174+
redis_func_map = {
175+
"FIRSTVALUE": "FIRST_VALUE",
176+
"ARRAYAGG": "TOLIST",
177+
}
178+
func_name = redis_func_map.get(func_name, func_name)
156179
field_name = None
157180
# Get the field being aggregated (if any)
158181
if expression.this:
@@ -184,10 +207,34 @@ def _process_select_expression_inner(
184207
# - Distance: L2/Euclidean distance
185208
# - CosineDistance: cosine_distance() function
186209
self._process_vector_distance(expression, result, alias)
210+
elif isinstance(expression, exp.Quantile):
211+
# QUANTILE(field, quantile_value) -> REDUCE QUANTILE 2 @field quantile_value
212+
field_name = None
213+
if expression.this and isinstance(expression.this, exp.Column):
214+
field_name = expression.this.name
215+
quantile_value = None
216+
if expression.args.get("quantile"):
217+
quantile_value = str(expression.args["quantile"].this)
218+
extra_args = [quantile_value] if quantile_value else []
219+
result.aggregations.append(
220+
AggregationSpec(
221+
function="QUANTILE",
222+
field=field_name,
223+
alias=alias,
224+
extra_args=extra_args,
225+
)
226+
)
187227
elif isinstance(expression, exp.Anonymous):
188228
# Custom function call (e.g., vector_distance) - check before exp.Func
189229
# since Anonymous is a subclass of Func
190230
func_name = expression.name.lower()
231+
# Redis-specific reducer functions that sqlglot doesn't recognize
232+
redis_reducers = {
233+
"count_distinct",
234+
"count_distinctish",
235+
"quantile",
236+
"random_sample",
237+
}
191238
if func_name == "vector_distance":
192239
# Extract the vector field name from first argument
193240
if expression.expressions:
@@ -198,6 +245,26 @@ def _process_select_expression_inner(
198245
field=field_name,
199246
alias=alias or func_name,
200247
)
248+
elif func_name in redis_reducers:
249+
# Redis-specific reducer functions
250+
field_name = None
251+
reducer_extra_args: list[str] = []
252+
if expression.expressions:
253+
first_arg = expression.expressions[0]
254+
if isinstance(first_arg, exp.Column):
255+
field_name = first_arg.name
256+
# Extract additional arguments (e.g., quantile value for QUANTILE)
257+
for arg in expression.expressions[1:]:
258+
if isinstance(arg, exp.Literal):
259+
reducer_extra_args.append(str(arg.this))
260+
result.aggregations.append(
261+
AggregationSpec(
262+
function=func_name.upper(),
263+
field=field_name,
264+
alias=alias,
265+
extra_args=reducer_extra_args,
266+
)
267+
)
201268
else:
202269
# Other custom functions - treat as computed field
203270
expr_str = expression.sql()

sql_redis/translator.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,17 @@ def _build_search(
189189
args.append("2")
190190
params["vector"] = None # Placeholder for vector bytes
191191

192-
# RETURN clause
193-
if parsed.fields and parsed.fields != ["*"]:
192+
# RETURN clause - include vector distance alias if present
193+
return_fields = list(parsed.fields) if parsed.fields else []
194+
if analyzed.vector_search and analyzed.vector_search.alias:
195+
# Add vector distance alias to return fields (like VectorQuery with return_score=True)
196+
if analyzed.vector_search.alias not in return_fields:
197+
return_fields.append(analyzed.vector_search.alias)
198+
199+
if return_fields and return_fields != ["*"]:
194200
args.append("RETURN")
195-
args.append(str(len(parsed.fields)))
196-
args.extend(parsed.fields)
201+
args.append(str(len(return_fields)))
202+
args.extend(return_fields)
197203

198204
# SORTBY
199205
if parsed.orderby_fields:
@@ -251,8 +257,15 @@ def _build_aggregate(
251257
for agg in analyzed.aggregations:
252258
args.append("REDUCE")
253259
args.append(agg.function.upper())
254-
if agg.field:
255-
args.extend(["1", f"@{agg.field}"])
260+
# COUNT always takes 0 arguments in Redis
261+
if agg.function.upper() == "COUNT":
262+
args.append("0")
263+
elif agg.field:
264+
# Calculate nargs: 1 for field + number of extra args
265+
nargs = 1 + len(agg.extra_args)
266+
args.append(str(nargs))
267+
args.append(f"@{agg.field}")
268+
args.extend(agg.extra_args)
256269
else:
257270
args.append("0")
258271
if agg.alias:
@@ -263,8 +276,15 @@ def _build_aggregate(
263276
for agg in analyzed.aggregations:
264277
args.append("REDUCE")
265278
args.append(agg.function.upper())
266-
if agg.field:
267-
args.extend(["1", f"@{agg.field}"])
279+
# COUNT always takes 0 arguments in Redis
280+
if agg.function.upper() == "COUNT":
281+
args.append("0")
282+
elif agg.field:
283+
# Calculate nargs: 1 for field + number of extra args
284+
nargs = 1 + len(agg.extra_args)
285+
args.append(str(nargs))
286+
args.append(f"@{agg.field}")
287+
args.extend(agg.extra_args)
268288
else:
269289
args.append("0")
270290
# Always provide an alias

tests/test_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ def test_vector_search_with_param(
301301
assert len(result.rows) <= 3
302302
# First result should be closest to query vector
303303
assert result.rows[0]["title"] == "First"
304+
# Verify vector distance score is returned (like VectorQuery with return_score=True)
305+
assert "score" in result.rows[0]
306+
# Score should be a valid distance value (0 for exact match with cosine)
307+
score = float(result.rows[0]["score"])
308+
assert score >= 0 # Distance should be non-negative
304309

305310

306311
class TestErrorHandling:

tests/test_translator.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,70 @@ def test_computed_field(self, translator: Translator, basic_index: str):
288288
assert result.command == "FT.AGGREGATE"
289289
assert "APPLY" in result.args
290290

291+
def test_count_with_field_uses_zero_args(
    self, translator: Translator, basic_index: str
):
    """Check that COUNT(field) is emitted as ``REDUCE COUNT 0``.

    The translator must not forward the column to the reducer: the Redis
    COUNT reducer takes no arguments and counts rows, not field values.
    """
    sql = f"SELECT category, COUNT(price) AS count_price FROM {basic_index} GROUP BY category"
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    # Locate the first REDUCE token; the reducer name and its nargs follow it.
    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "COUNT"
    assert tokens[pos + 2] == "0"  # COUNT always takes 0 args

    # The two slots after the reducer name must not carry the field reference.
    after_reducer = tokens[pos + 2 : pos + 4]
    assert "@price" not in after_reducer
310+
311+
def test_count_star_uses_zero_args(self, translator: Translator, basic_index: str):
    """Check that COUNT(*) is emitted as ``REDUCE COUNT 0``."""
    sql = f"SELECT category, COUNT(*) AS cnt FROM {basic_index} GROUP BY category"
    translated = translator.translate(sql)

    # The reducer name and its argument count directly follow the REDUCE token.
    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "COUNT"
    assert tokens[pos + 2] == "0"
321+
322+
def test_count_distinct_reducer(self, translator: Translator, basic_index: str):
    """Check that COUNT_DISTINCT(field) becomes ``REDUCE COUNT_DISTINCT 1 @field``."""
    sql = (
        f"SELECT category, COUNT_DISTINCT(title) AS unique_titles "
        f"FROM {basic_index} GROUP BY category"
    )
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    # Expected layout after REDUCE: reducer name, nargs, then the field reference.
    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "COUNT_DISTINCT"
    assert tokens[pos + 2] == "1"
    assert tokens[pos + 3] == "@title"

    # The alias only has to be present; its exact position is not pinned here.
    assert "AS" in tokens
    assert "unique_titles" in tokens
337+
338+
def test_quantile_reducer(self, translator: Translator, basic_index: str):
    """Check that QUANTILE(field, value) becomes ``REDUCE QUANTILE 2 @field value``."""
    sql = (
        f"SELECT category, QUANTILE(price, 0.5) AS median_price "
        f"FROM {basic_index} GROUP BY category"
    )
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "QUANTILE"
    # nargs covers the field reference plus the quantile value.
    assert tokens[pos + 2] == "2"
    assert tokens[pos + 3] == "@price"
    assert tokens[pos + 4] == "0.5"

    # The alias only has to be present; its exact position is not pinned here.
    assert "AS" in tokens
    assert "median_price" in tokens
354+
291355

292356
class TestTranslatorVectorSearch:
293357
"""Tests for vector search translation."""

0 commit comments

Comments
 (0)