From 33f3de30c48b70ddbe2a1c5a7b8ab2afd6b18317 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 28 Jan 2026 14:30:27 -0500 Subject: [PATCH 1/3] test tweaks --- sql_redis/translator.py | 14 ++++++++++---- tests/test_executor.py | 5 +++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql_redis/translator.py b/sql_redis/translator.py index bde84ac..6461e1e 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -189,11 +189,17 @@ def _build_search( args.append("2") params["vector"] = None # Placeholder for vector bytes - # RETURN clause - if parsed.fields and parsed.fields != ["*"]: + # RETURN clause - include vector distance alias if present + return_fields = list(parsed.fields) if parsed.fields else [] + if analyzed.vector_search and analyzed.vector_search.alias: + # Add vector distance alias to return fields (like VectorQuery with return_score=True) + if analyzed.vector_search.alias not in return_fields: + return_fields.append(analyzed.vector_search.alias) + + if return_fields and return_fields != ["*"]: args.append("RETURN") - args.append(str(len(parsed.fields))) - args.extend(parsed.fields) + args.append(str(len(return_fields))) + args.extend(return_fields) # SORTBY if parsed.orderby_fields: diff --git a/tests/test_executor.py b/tests/test_executor.py index 15416bf..21b98d7 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -301,6 +301,11 @@ def test_vector_search_with_param( assert len(result.rows) <= 3 # First result should be closest to query vector assert result.rows[0]["title"] == "First" + # Verify vector distance score is returned (like VectorQuery with return_score=True) + assert "score" in result.rows[0] + # Score should be a valid distance value (0 for exact match with cosine) + score = float(result.rows[0]["score"]) + assert score >= 0 # Distance should be non-negative class TestErrorHandling: From b07fcf32c4623eba0f07867f0006af88c3a19bf4 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 28 Jan 2026 16:10:49 -0500 Subject: [PATCH 2/3] working primary reducers --- sql_redis/parser.py | 40 +++++++++++++++++++++++++++++++++- sql_redis/translator.py | 10 +++++++-- tests/test_translator.py | 47 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/sql_redis/parser.py b/sql_redis/parser.py index 926ef68..59d2a79 100644 --- a/sql_redis/parser.py +++ b/sql_redis/parser.py @@ -150,9 +150,28 @@ def _process_select_expression_inner( result.fields.append(expression.name) elif isinstance(expression, exp.Star): result.fields.append("*") - elif isinstance(expression, (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)): + elif isinstance( + expression, + ( + exp.Count, + exp.Sum, + exp.Avg, + exp.Min, + exp.Max, + exp.Stddev, + exp.Variance, + exp.FirstValue, + exp.ArrayAgg, + ), + ): # Aggregation function + # Map sqlglot function names to Redis reducer names func_name = expression.key.upper() + redis_func_map = { + "FIRSTVALUE": "FIRST_VALUE", + "ARRAYAGG": "TOLIST", + } + func_name = redis_func_map.get(func_name, func_name) field_name = None # Get the field being aggregated (if any) if expression.this: @@ -188,6 +207,13 @@ def _process_select_expression_inner( # Custom function call (e.g., vector_distance) - check before exp.Func # since Anonymous is a subclass of Func func_name = expression.name.lower() + # Redis-specific reducer functions that sqlglot doesn't recognize + redis_reducers = { + "count_distinct", + "count_distinctish", + "quantile", + "random_sample", + } if func_name == "vector_distance": # Extract the vector field name from first argument if expression.expressions: @@ -198,6 +224,18 @@ def _process_select_expression_inner( field=field_name, alias=alias or func_name, ) + elif func_name in redis_reducers: + # Redis-specific reducer functions + field_name = None + if expression.expressions: + first_arg = expression.expressions[0] + if isinstance(first_arg, exp.Column): + field_name = first_arg.name + result.aggregations.append( + AggregationSpec( + function=func_name.upper(), field=field_name, alias=alias + ) + ) else: # Other custom functions - treat as computed field expr_str = expression.sql() diff --git a/sql_redis/translator.py b/sql_redis/translator.py index 6461e1e..99a64c6 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -257,7 +257,10 @@ def _build_aggregate( for agg in analyzed.aggregations: args.append("REDUCE") args.append(agg.function.upper()) - if agg.field: + # COUNT always takes 0 arguments in Redis + if agg.function.upper() == "COUNT": + args.append("0") + elif agg.field: args.extend(["1", f"@{agg.field}"]) else: args.append("0") @@ -269,7 +272,10 @@ def _build_aggregate( for agg in analyzed.aggregations: args.append("REDUCE") args.append(agg.function.upper()) - if agg.field: + # COUNT always takes 0 arguments in Redis + if agg.function.upper() == "COUNT": + args.append("0") + elif agg.field: args.extend(["1", f"@{agg.field}"]) else: args.append("0") diff --git a/tests/test_translator.py b/tests/test_translator.py index e6178cb..2bde1bf 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -288,6 +288,53 @@ def test_computed_field(self, translator: Translator, basic_index: str): assert result.command == "FT.AGGREGATE" assert "APPLY" in result.args + def test_count_with_field_uses_zero_args( + self, translator: Translator, basic_index: str + ): + """COUNT(field) should generate REDUCE COUNT 0, not REDUCE COUNT 1 @field. + + Redis COUNT reducer always takes 0 arguments - it counts rows, not field values. + """ + result = translator.translate( + f"SELECT category, COUNT(price) AS count_price FROM {basic_index} GROUP BY category" + ) + + assert result.command == "FT.AGGREGATE" + # Find REDUCE COUNT in args and verify it's followed by "0" + args = result.args + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "COUNT" + assert args[reduce_idx + 2] == "0" # COUNT always takes 0 args + # Should NOT have @price after COUNT + assert "@price" not in args[reduce_idx + 2 : reduce_idx + 4] + + def test_count_star_uses_zero_args(self, translator: Translator, basic_index: str): + """COUNT(*) should generate REDUCE COUNT 0.""" + result = translator.translate( + f"SELECT category, COUNT(*) AS cnt FROM {basic_index} GROUP BY category" + ) + + args = result.args + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "COUNT" + assert args[reduce_idx + 2] == "0" + + def test_count_distinct_reducer(self, translator: Translator, basic_index: str): + """COUNT_DISTINCT(field) should generate REDUCE COUNT_DISTINCT 1 @field.""" + result = translator.translate( + f"SELECT category, COUNT_DISTINCT(title) AS unique_titles " + f"FROM {basic_index} GROUP BY category" + ) + + assert result.command == "FT.AGGREGATE" + args = result.args + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "COUNT_DISTINCT" + assert args[reduce_idx + 2] == "1" + assert args[reduce_idx + 3] == "@title" + assert "AS" in args + assert "unique_titles" in args + class TestTranslatorVectorSearch: """Tests for vector search translation.""" From 8561708a9d66cf87c409ffa113c5e00eb46c9579 Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 28 Jan 2026 16:42:49 -0500 Subject: [PATCH 3/3] support quantile --- sql_redis/parser.py | 45 +++++++++++++++++++++++++++++++++------- sql_redis/translator.py | 12 +++++++++-- tests/test_translator.py | 17 +++++++++++++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/sql_redis/parser.py b/sql_redis/parser.py index 59d2a79..037bcd9 100644 --- a/sql_redis/parser.py +++ b/sql_redis/parser.py @@ -2,7 +2,8 @@ from __future__ import annotations -from dataclasses import dataclass, field +import dataclasses +from dataclasses import dataclass import sqlglot from sqlglot import exp @@ -15,6 +16,9 @@ class AggregationSpec: function: str field: str | None = None alias: str | None = None + extra_args: list[str] = dataclasses.field( + default_factory=list + ) # For reducers like QUANTILE @dataclass @@ -49,14 +53,14 @@ class ParsedQuery: """Result of parsing a SQL query.""" index: str = "" - fields: list[str] = field(default_factory=list) - conditions: list[Condition] = field(default_factory=list) + fields: list[str] = dataclasses.field(default_factory=list) + conditions: list[Condition] = dataclasses.field(default_factory=list) boolean_operator: str = "AND" - aggregations: list[AggregationSpec] = field(default_factory=list) - computed_fields: list[ComputedField] = field(default_factory=list) + aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list) + computed_fields: list[ComputedField] = dataclasses.field(default_factory=list) vector_search: VectorSearchSpec | None = None - groupby_fields: list[str] = field(default_factory=list) - orderby_fields: list[tuple[str, str]] = field( + groupby_fields: list[str] = dataclasses.field(default_factory=list) + orderby_fields: list[tuple[str, str]] = dataclasses.field( default_factory=list ) # (field, ASC|DESC) limit: int | None = None @@ -203,6 +207,23 @@ def _process_select_expression_inner( # - Distance: L2/Euclidean distance # - CosineDistance: cosine_distance() function self._process_vector_distance(expression, result, alias) + elif isinstance(expression, exp.Quantile): + # QUANTILE(field, quantile_value) -> REDUCE QUANTILE 2 @field quantile_value + field_name = None + if expression.this and isinstance(expression.this, exp.Column): + field_name = expression.this.name + quantile_value = None + if expression.args.get("quantile"): + quantile_value = str(expression.args["quantile"].this) + extra_args = [quantile_value] if quantile_value else [] + result.aggregations.append( + AggregationSpec( + function="QUANTILE", + field=field_name, + alias=alias, + extra_args=extra_args, + ) + ) elif isinstance(expression, exp.Anonymous): # Custom function call (e.g., vector_distance) - check before exp.Func # since Anonymous is a subclass of Func @@ -227,13 +248,21 @@ def _process_select_expression_inner( elif func_name in redis_reducers: # Redis-specific reducer functions field_name = None + reducer_extra_args: list[str] = [] if expression.expressions: first_arg = expression.expressions[0] if isinstance(first_arg, exp.Column): field_name = first_arg.name + # Extract additional arguments (e.g., quantile value for QUANTILE) + for arg in expression.expressions[1:]: + if isinstance(arg, exp.Literal): + reducer_extra_args.append(str(arg.this)) result.aggregations.append( AggregationSpec( - function=func_name.upper(), field=field_name, alias=alias + function=func_name.upper(), + field=field_name, + alias=alias, + extra_args=reducer_extra_args, ) ) else: diff --git a/sql_redis/translator.py b/sql_redis/translator.py index 99a64c6..0fec87c 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -261,7 +261,11 @@ def _build_aggregate( if agg.function.upper() == "COUNT": args.append("0") elif agg.field: - args.extend(["1", f"@{agg.field}"]) + # Calculate nargs: 1 for field + number of extra args + nargs = 1 + len(agg.extra_args) + args.append(str(nargs)) + args.append(f"@{agg.field}") + args.extend(agg.extra_args) else: args.append("0") if agg.alias: @@ -276,7 +280,11 @@ def _build_aggregate( if agg.function.upper() == "COUNT": args.append("0") elif agg.field: - args.extend(["1", f"@{agg.field}"]) + # Calculate nargs: 1 for field + number of extra args + nargs = 1 + len(agg.extra_args) + args.append(str(nargs)) + args.append(f"@{agg.field}") + args.extend(agg.extra_args) else: args.append("0") # Always provide an alias diff --git a/tests/test_translator.py b/tests/test_translator.py index 2bde1bf..5309d76 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -335,6 +335,23 @@ def test_count_distinct_reducer(self, translator: Translator, basic_index: str): assert "AS" in args assert "unique_titles" in args + def test_quantile_reducer(self, translator: Translator, basic_index: str): + """QUANTILE(field, value) should generate REDUCE QUANTILE 2 @field value.""" + result = translator.translate( + f"SELECT category, QUANTILE(price, 0.5) AS median_price " + f"FROM {basic_index} GROUP BY category" + ) + + assert result.command == "FT.AGGREGATE" + args = result.args + reduce_idx = args.index("REDUCE") + assert args[reduce_idx + 1] == "QUANTILE" + assert args[reduce_idx + 2] == "2" # nargs = 1 (field) + 1 (quantile value) + assert args[reduce_idx + 3] == "@price" + assert args[reduce_idx + 4] == "0.5" + assert "AS" in args + assert "median_price" in args + class TestTranslatorVectorSearch: """Tests for vector search translation."""