Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions sql_redis/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from dataclasses import dataclass, field
import dataclasses
from dataclasses import dataclass

import sqlglot
from sqlglot import exp
Expand All @@ -15,6 +16,9 @@ class AggregationSpec:
function: str
field: str | None = None
alias: str | None = None
extra_args: list[str] = dataclasses.field(

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QUANTILE takes two arguments (the field and the quantile value), whereas the other reducers take only one.

default_factory=list
) # For reducers like QUANTILE


@dataclass
Expand Down Expand Up @@ -49,14 +53,14 @@ class ParsedQuery:
"""Result of parsing a SQL query."""

index: str = ""
fields: list[str] = field(default_factory=list)
conditions: list[Condition] = field(default_factory=list)
fields: list[str] = dataclasses.field(default_factory=list)
conditions: list[Condition] = dataclasses.field(default_factory=list)
boolean_operator: str = "AND"
aggregations: list[AggregationSpec] = field(default_factory=list)
computed_fields: list[ComputedField] = field(default_factory=list)
aggregations: list[AggregationSpec] = dataclasses.field(default_factory=list)
computed_fields: list[ComputedField] = dataclasses.field(default_factory=list)
vector_search: VectorSearchSpec | None = None
groupby_fields: list[str] = field(default_factory=list)
orderby_fields: list[tuple[str, str]] = field(
groupby_fields: list[str] = dataclasses.field(default_factory=list)
orderby_fields: list[tuple[str, str]] = dataclasses.field(
default_factory=list
) # (field, ASC|DESC)
limit: int | None = None
Expand Down Expand Up @@ -150,9 +154,28 @@ def _process_select_expression_inner(
result.fields.append(expression.name)
elif isinstance(expression, exp.Star):
result.fields.append("*")
elif isinstance(expression, (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)):
elif isinstance(
expression,
(
exp.Count,
exp.Sum,
exp.Avg,
exp.Min,
exp.Max,
exp.Stddev,
exp.Variance,
exp.FirstValue,
exp.ArrayAgg,
),
):
# Aggregation function
# Map sqlglot function names to Redis reducer names
func_name = expression.key.upper()
redis_func_map = {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add these mappings to enable all of the reducers available in Redis.

"FIRSTVALUE": "FIRST_VALUE",
"ARRAYAGG": "TOLIST",
}
func_name = redis_func_map.get(func_name, func_name)
field_name = None
# Get the field being aggregated (if any)
if expression.this:
Expand Down Expand Up @@ -184,10 +207,34 @@ def _process_select_expression_inner(
# - Distance: L2/Euclidean distance
# - CosineDistance: cosine_distance() function
self._process_vector_distance(expression, result, alias)
elif isinstance(expression, exp.Quantile):
# QUANTILE(field, quantile_value) -> REDUCE QUANTILE 2 @field quantile_value
field_name = None
if expression.this and isinstance(expression.this, exp.Column):
field_name = expression.this.name
quantile_value = None
if expression.args.get("quantile"):
quantile_value = str(expression.args["quantile"].this)
extra_args = [quantile_value] if quantile_value else []
result.aggregations.append(
AggregationSpec(
function="QUANTILE",
field=field_name,
alias=alias,
extra_args=extra_args,
)
)
elif isinstance(expression, exp.Anonymous):
# Custom function call (e.g., vector_distance) - check before exp.Func
# since Anonymous is a subclass of Func
func_name = expression.name.lower()
# Redis-specific reducer functions that sqlglot doesn't recognize
redis_reducers = {
"count_distinct",
"count_distinctish",
"quantile",
"random_sample",
}
if func_name == "vector_distance":
# Extract the vector field name from first argument
if expression.expressions:
Expand All @@ -198,6 +245,26 @@ def _process_select_expression_inner(
field=field_name,
alias=alias or func_name,
)
elif func_name in redis_reducers:
# Redis-specific reducer functions
field_name = None
reducer_extra_args: list[str] = []
if expression.expressions:
first_arg = expression.expressions[0]
if isinstance(first_arg, exp.Column):
field_name = first_arg.name
# Extract additional arguments (e.g., quantile value for QUANTILE)
for arg in expression.expressions[1:]:
if isinstance(arg, exp.Literal):
reducer_extra_args.append(str(arg.this))
result.aggregations.append(
AggregationSpec(
function=func_name.upper(),
field=field_name,
alias=alias,
extra_args=reducer_extra_args,
)
)
else:
# Other custom functions - treat as computed field
expr_str = expression.sql()
Expand Down
36 changes: 28 additions & 8 deletions sql_redis/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,17 @@ def _build_search(
args.append("2")
params["vector"] = None # Placeholder for vector bytes

# RETURN clause
if parsed.fields and parsed.fields != ["*"]:
# RETURN clause - include vector distance alias if present
return_fields = list(parsed.fields) if parsed.fields else []
if analyzed.vector_search and analyzed.vector_search.alias:
# Add vector distance alias to return fields (like VectorQuery with return_score=True)
if analyzed.vector_search.alias not in return_fields:
return_fields.append(analyzed.vector_search.alias)

if return_fields and return_fields != ["*"]:
args.append("RETURN")
args.append(str(len(parsed.fields)))
args.extend(parsed.fields)
args.append(str(len(return_fields)))
args.extend(return_fields)

# SORTBY
if parsed.orderby_fields:
Expand Down Expand Up @@ -251,8 +257,15 @@ def _build_aggregate(
for agg in analyzed.aggregations:
args.append("REDUCE")
args.append(agg.function.upper())
if agg.field:
args.extend(["1", f"@{agg.field}"])
# COUNT always takes 0 arguments in Redis
if agg.function.upper() == "COUNT":

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The COUNT reducer had a bug: unlike the other reducers, it takes 0 arguments.

args.append("0")
elif agg.field:
# Calculate nargs: 1 for field + number of extra args
nargs = 1 + len(agg.extra_args)
args.append(str(nargs))
args.append(f"@{agg.field}")
args.extend(agg.extra_args)
else:
args.append("0")
if agg.alias:
Expand All @@ -263,8 +276,15 @@ def _build_aggregate(
for agg in analyzed.aggregations:
args.append("REDUCE")
args.append(agg.function.upper())
if agg.field:
args.extend(["1", f"@{agg.field}"])
# COUNT always takes 0 arguments in Redis
if agg.function.upper() == "COUNT":
args.append("0")
elif agg.field:
# Calculate nargs: 1 for field + number of extra args
nargs = 1 + len(agg.extra_args)
args.append(str(nargs))
args.append(f"@{agg.field}")
args.extend(agg.extra_args)
else:
args.append("0")
# Always provide an alias
Expand Down
5 changes: 5 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def test_vector_search_with_param(
assert len(result.rows) <= 3
# First result should be closest to query vector
assert result.rows[0]["title"] == "First"
# Verify vector distance score is returned (like VectorQuery with return_score=True)
assert "score" in result.rows[0]
# Score should be a valid distance value (0 for exact match with cosine)
score = float(result.rows[0]["score"])
assert score >= 0 # Distance should be non-negative


class TestErrorHandling:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,70 @@ def test_computed_field(self, translator: Translator, basic_index: str):
assert result.command == "FT.AGGREGATE"
assert "APPLY" in result.args

def test_count_with_field_uses_zero_args(
    self, translator: Translator, basic_index: str
):
    """Check that COUNT(price) is emitted as ``REDUCE COUNT 0``.

    The translated FT.AGGREGATE arguments must carry "0" as the argument
    count right after COUNT, and "@price" must not show up in the two
    tokens that follow the reducer name.
    """
    sql = f"SELECT category, COUNT(price) AS count_price FROM {basic_index} GROUP BY category"
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    # Locate the first REDUCE token; the reducer name and its nargs follow it.
    tokens = translated.args
    pos = tokens.index("REDUCE")
    reducer_name = tokens[pos + 1]
    assert reducer_name == "COUNT"
    nargs = tokens[pos + 2]
    assert nargs == "0"  # no field reference is passed to COUNT

    # The field reference must be absent from the nargs slot and the next one.
    following = tokens[pos + 2 : pos + 4]
    assert "@price" not in following

def test_count_star_uses_zero_args(
    self, translator: Translator, basic_index: str
):
    """Check that COUNT(*) is emitted as ``REDUCE COUNT 0``."""
    sql = f"SELECT category, COUNT(*) AS cnt FROM {basic_index} GROUP BY category"
    translated = translator.translate(sql)

    # The reducer name and its argument count directly follow REDUCE.
    tokens = translated.args
    pos = tokens.index("REDUCE")
    reducer_name = tokens[pos + 1]
    assert reducer_name == "COUNT"
    nargs = tokens[pos + 2]
    assert nargs == "0"

def test_count_distinct_reducer(
    self, translator: Translator, basic_index: str
):
    """Check that COUNT_DISTINCT(title) becomes ``REDUCE COUNT_DISTINCT 1 @title``.

    Also verifies that the alias is carried through via an "AS" token.
    """
    sql = (
        f"SELECT category, COUNT_DISTINCT(title) AS unique_titles "
        f"FROM {basic_index} GROUP BY category"
    )
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    # Expected layout after REDUCE: <reducer> <nargs> <@field>
    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "COUNT_DISTINCT"
    assert tokens[pos + 2] == "1"
    assert tokens[pos + 3] == "@title"

    # Alias must be present somewhere in the argument list.
    assert "AS" in tokens
    assert "unique_titles" in tokens

def test_quantile_reducer(
    self, translator: Translator, basic_index: str
):
    """Check that QUANTILE(price, 0.5) becomes ``REDUCE QUANTILE 2 @price 0.5``.

    Also verifies that the alias is carried through via an "AS" token.
    """
    sql = (
        f"SELECT category, QUANTILE(price, 0.5) AS median_price "
        f"FROM {basic_index} GROUP BY category"
    )
    translated = translator.translate(sql)

    assert translated.command == "FT.AGGREGATE"

    # Expected layout after REDUCE: <reducer> <nargs> <@field> <quantile>
    tokens = translated.args
    pos = tokens.index("REDUCE")
    assert tokens[pos + 1] == "QUANTILE"
    assert tokens[pos + 2] == "2"  # one field plus one quantile value
    assert tokens[pos + 3] == "@price"
    assert tokens[pos + 4] == "0.5"

    # Alias must be present somewhere in the argument list.
    assert "AS" in tokens
    assert "median_price" in tokens


class TestTranslatorVectorSearch:
"""Tests for vector search translation."""
Expand Down
Loading