Skip to content

Commit b0e37cb

Browse files
authored
feat: add exists() support via FT.AGGREGATE APPLY/FILTER (#14)
* feat: add IS NULL / IS NOT NULL support via RediSearch ismissing()

  Translate SQL IS NULL and IS NOT NULL to RediSearch ismissing(@field) and -ismissing(@field) respectively. Requires Redis 7.4+ with the INDEXMISSING attribute declared on the target field.

  Changes:
  - parser: handle sqlglot exp.Is nodes, emit IS_NULL / IS_NOT_NULL operators
  - query_builder: add build_missing_condition() for ismissing() syntax
  - translator: short-circuit IS_NULL/IS_NOT_NULL before field-type dispatch
  - translator: make DIALECT 2 the default for all FT.SEARCH and FT.AGGREGATE
  - executor: catch and re-raise ResponseError with clear version guidance when ismissing() fails (both sync and async paths)
  - translator: emit UserWarning on IS NULL/IS NOT NULL noting Redis 7.4+ and INDEXMISSING requirements

  Tests:
  - 18 integration tests against Redis 8 (TAG, TEXT, NUMERIC field types, combined conditions, edge cases, raw command verification)
  - Unit tests for parser, query_builder, and translator
  - Warning and error message verification tests

* chore: linter

* feat: add exists() support via FT.AGGREGATE APPLY/FILTER

  Translate SQL exists(field) to RediSearch exists(@field) in FT.AGGREGATE APPLY (projection) and FILTER (HAVING) clauses.

  - Parser: handle exp.Exists in SELECT → computed_fields and HAVING → filters, distinguishing from SQL EXISTS (SELECT ...)
  - Analyzer: extract field references from exists() for schema validation
  - Translator: force FT.AGGREGATE when exists() is used, generate APPLY/FILTER clauses, include exists() fields in LOAD
  - Reject exists() in WHERE with clear error (aggregate-only function)
  - No INDEXMISSING attribute required — works on any indexed field

  Tests: 8 integration, 5 parser, 7 translator (350 total pass)

* fix: handle sqlglot uppercasing exists() to EXISTS() in arithmetic expressions like exists(a) + exists(b), ensuring the referenced fields are included in LOAD args and don't cause a 'Property not loaded' runtime error.

* fix: handle Paren in HAVING clause and move exists() FILTER after GROUPBY/REDUCE

  - parser: unwrap exp.Paren in _process_having_clause so queries like HAVING (exists(email)) don't raise 'Unsupported HAVING expression'
  - translator: move exists() FILTER emission to after GROUPBY/REDUCE block so it acts as post-aggregation filtering (correct HAVING semantics) instead of pre-group document filtering

* refactor: move all local imports to module scope

  Move import re, import asyncio, import warnings, import time, and several from-imports out of function/method bodies to module-level imports. Local imports inside TYPE_CHECKING guards and try/except compatibility blocks are intentionally kept.

* fix: remove benchmark file accidentally committed to feat/exists

* fix: emit LOAD * in FT.AGGREGATE when SELECT * is used

  When FT.AGGREGATE is forced (e.g., by HAVING exists()), SELECT * was skipping the wildcard in the LOAD field collection loop, resulting in no fields being loaded and empty/partial rows. Now detects SELECT * in aggregate mode and emits LOAD * so RediSearch returns all document attributes. Adds test for SELECT * with HAVING exists() verifying LOAD * is emitted.

* fix: emit LOAD * in FT.AGGREGATE when SELECT * is used

  When FT.AGGREGATE is forced (e.g., by HAVING exists()), SELECT * was skipping the wildcard in the LOAD field collection loop, resulting in no fields being loaded and empty/partial rows. Now detects SELECT * in aggregate mode and emits LOAD * so RediSearch returns all document attributes.
1 parent 467b416 commit b0e37cb

12 files changed

Lines changed: 499 additions & 52 deletions

sql_redis/analyzer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def analyze(self, parsed: ParsedQuery) -> AnalyzedQuery:
111111
if field_name in computed.expression:
112112
referenced_fields.add(field_name)
113113

114+
# Fields from filters (HAVING exists(field))
115+
for filter_expr in parsed.filters:
116+
for field_name in schema.keys():
117+
if field_name in filter_expr:
118+
referenced_fields.add(field_name)
119+
114120
# Fields from vector search
115121
if parsed.vector_search:
116122
referenced_fields.add(parsed.vector_search.field)

sql_redis/parser.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ class ParsedQuery:
218218
) # (field, ASC|DESC)
219219
limit: int | None = None
220220
offset: int | None = None
221+
filters: list[str] = dataclasses.field(default_factory=list)
221222

222223

223224
class SQLParser:
@@ -260,6 +261,11 @@ def parse(self, sql: str) -> ParsedQuery:
260261
if isinstance(expr, exp.Column):
261262
result.groupby_fields.append(expr.name)
262263

264+
# Extract HAVING clause — exists() in HAVING → FILTER
265+
having = ast.find(exp.Having)
266+
if having:
267+
self._process_having_clause(having.this, result)
268+
263269
# Extract ORDER BY clause
264270
order = ast.find(exp.Order)
265271
if order:
@@ -392,6 +398,24 @@ def _process_select_expression_inner(
392398
extra_args=extra_args,
393399
)
394400
)
401+
elif isinstance(expression, exp.Exists):
402+
# exists(field) — RediSearch aggregation function
403+
# sqlglot parses exists(col) as exp.Exists(this=Column),
404+
# distinct from EXISTS (SELECT ...) which has this=Select.
405+
inner = expression.this
406+
if isinstance(inner, exp.Column):
407+
field_name = inner.name
408+
expr_str = f"exists({field_name})"
409+
field_alias = alias if alias else f"exists_{field_name}"
410+
result.computed_fields.append(
411+
ComputedField(expression=expr_str, alias=field_alias)
412+
)
413+
else:
414+
raise ValueError(
415+
"exists() in SELECT expects a column reference, "
416+
f"got {type(inner).__name__}. "
417+
"Use exists(field_name) for RediSearch field existence checks."
418+
)
395419
elif isinstance(expression, exp.Anonymous):
396420
# Custom function call (e.g., vector_distance) - check before exp.Func
397421
# since Anonymous is a subclass of Func
@@ -664,10 +688,43 @@ def _process_where_clause(
664688
"Unsupported IS expression in WHERE clause; only "
665689
"`column IS NULL` and `column IS NOT NULL` are supported."
666690
)
691+
elif isinstance(expression, exp.Exists):
692+
# Distinguish exists(column) from EXISTS (SELECT ...)
693+
inner = expression.this
694+
if isinstance(inner, exp.Column):
695+
# exists(field) — RediSearch aggregate function, not valid in WHERE
696+
raise ValueError(
697+
"exists() is a RediSearch aggregate function and cannot be "
698+
"used in WHERE clauses. Use HAVING exists(field) instead "
699+
"for post-aggregate filtering."
700+
)
701+
# EXISTS (SELECT ...) — SQL subquery, silently ignored (not supported)
667702
elif isinstance(expression, exp.Anonymous):
668703
# Custom function like MATCH(field, value)
669704
self._add_function_condition(expression, result, negated)
670705

706+
def _process_having_clause(self, expression, result: ParsedQuery) -> None:
707+
"""Process HAVING clause — routes exists() to filters."""
708+
if isinstance(expression, exp.Exists):
709+
inner = expression.this
710+
if isinstance(inner, exp.Column):
711+
result.filters.append(f"exists({inner.name})")
712+
else:
713+
raise ValueError(
714+
"exists() in HAVING expects a column reference, "
715+
f"got {type(inner).__name__}."
716+
)
717+
elif isinstance(expression, exp.Paren):
718+
self._process_having_clause(expression.this, result)
719+
elif isinstance(expression, exp.And):
720+
self._process_having_clause(expression.this, result)
721+
self._process_having_clause(expression.expression, result)
722+
else:
723+
raise ValueError(
724+
f"Unsupported HAVING expression: {type(expression).__name__}. "
725+
"Only exists(field) is supported in HAVING."
726+
)
727+
671728
def _add_condition(
672729
self, expression, operator: str, result: ParsedQuery, negated: bool
673730
) -> None:

sql_redis/schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from typing import TYPE_CHECKING, Callable
67

78
import redis
@@ -172,8 +173,6 @@ async def load_all(self) -> None:
172173
173174
Uses asyncio.gather() to load all index schemas concurrently.
174175
"""
175-
import asyncio
176-
177176
self._schemas.clear()
178177
indexes = await self._client.execute_command("FT._LIST")
179178
# Decode bytes to strings

sql_redis/translator.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import re
56
import warnings
67
from dataclasses import dataclass, field
78

@@ -120,6 +121,7 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery:
120121
or geo_requires_aggregate # geo_distance with >, >=, BETWEEN
121122
or len(analyzed.date_functions) > 0
122123
or has_date_func_conditions
124+
or len(parsed.filters) > 0 # exists() in HAVING → FILTER
123125
)
124126

125127
# Build query string from conditions
@@ -333,33 +335,44 @@ def _build_aggregate(
333335
geo_filter_conditions = list(parsed.geo_conditions)
334336

335337
# LOAD fields if needed
336-
load_fields = set()
337-
for agg in analyzed.aggregations:
338-
if agg.field:
339-
load_fields.add(agg.field)
340-
for field_name in analyzed.groupby_fields:
341-
load_fields.add(field_name)
342-
# Load geo fields used in geo_distance() SELECT expressions
343-
for geo_select in parsed.geo_distance_selects:
344-
load_fields.add(geo_select.field)
345-
# Load geo fields used in geo_distance() WHERE with >, >=, BETWEEN
346-
for geo_cond in geo_filter_conditions:
347-
load_fields.add(geo_cond.field)
348-
# Load source fields for date functions in SELECT
349-
for date_func in analyzed.date_functions:
350-
load_fields.add(date_func.field)
351-
# Load source fields for date function conditions in WHERE
352-
for condition in parsed.conditions:
353-
if self._is_date_function_condition(condition):
354-
load_fields.add(condition.field)
355-
# Load explicit SELECT fields for FT.AGGREGATE
356-
for field_name in parsed.fields:
357-
if field_name != "*":
338+
# SELECT * in aggregate mode → LOAD * (all document attributes)
339+
load_all = "*" in (parsed.fields or [])
340+
341+
load_fields: set[str] = set()
342+
if not load_all:
343+
for agg in analyzed.aggregations:
344+
if agg.field:
345+
load_fields.add(agg.field)
346+
for field_name in analyzed.groupby_fields:
347+
load_fields.add(field_name)
348+
# Load geo fields used in geo_distance() SELECT expressions
349+
for geo_select in parsed.geo_distance_selects:
350+
load_fields.add(geo_select.field)
351+
# Load geo fields used in geo_distance() WHERE with >, >=, BETWEEN
352+
for geo_cond in geo_filter_conditions:
353+
load_fields.add(geo_cond.field)
354+
# Load source fields for date functions in SELECT
355+
for date_func in analyzed.date_functions:
356+
load_fields.add(date_func.field)
357+
# Load source fields for date function conditions in WHERE
358+
for condition in parsed.conditions:
359+
if self._is_date_function_condition(condition):
360+
load_fields.add(condition.field)
361+
# Load explicit SELECT fields for FT.AGGREGATE
362+
for field_name in parsed.fields:
358363
# Skip computed fields (they have aliases from geo_distance)
359364
if field_name not in [gs.alias for gs in parsed.geo_distance_selects]:
360365
load_fields.add(field_name)
361-
362-
if load_fields:
366+
# Load fields referenced in exists() filters (HAVING)
367+
for filter_expr in parsed.filters:
368+
self._extract_exists_fields(filter_expr, load_fields)
369+
# Load fields referenced in exists() computed fields (SELECT)
370+
for computed in analyzed.computed_fields:
371+
self._extract_exists_fields(computed.expression, load_fields)
372+
373+
if load_all:
374+
args.extend(["LOAD", "*"])
375+
elif load_fields:
363376
args.append("LOAD")
364377
args.append(str(len(load_fields)))
365378
# Redis expects property names prefixed with '@' in LOAD
@@ -498,6 +511,13 @@ def _build_aggregate(
498511
alias = agg.alias or agg.function.lower()
499512
args.extend(["AS", alias])
500513

514+
# FILTER for exists() from HAVING clause (post-aggregation)
515+
for filter_expr in parsed.filters:
516+
prefixed = self._prefix_fields_in_expression(
517+
filter_expr, analyzed.field_types
518+
)
519+
args.extend(["FILTER", prefixed])
520+
501521
# SORTBY
502522
if parsed.orderby_fields:
503523
args.append("SORTBY")
@@ -593,12 +613,16 @@ def _convert_to_meters(self, value: float, unit: str) -> float:
593613
)
594614
return value * conversions[normalized_unit]
595615

616+
@staticmethod
617+
def _extract_exists_fields(expression: str, load_fields: set[str]) -> None:
618+
"""Extract field names from exists() calls and add to load_fields."""
619+
for match in re.finditer(r"exists\((\w+)\)", expression, re.IGNORECASE):
620+
load_fields.add(match.group(1))
621+
596622
def _prefix_fields_in_expression(
597623
self, expression: str, schema: dict[str, str]
598624
) -> str:
599625
"""Prefix field names with @ in an expression for Redis APPLY."""
600-
import re
601-
602626
result = expression
603627
for field_name in schema:
604628
# Match field name as a whole word, not already prefixed with @

tests/test_date_fields.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Tests for DATE/DATETIME literal parsing and conversion."""
22

33
import pytest
4+
import redis as redis_lib
45

56
from sql_redis.parser import SQLParser
7+
from sql_redis.schema import SchemaRegistry
68
from sql_redis.translator import Translator
79

810

@@ -126,8 +128,6 @@ class TestDateTranslation:
126128
@pytest.fixture
127129
def date_index(self, redis_client):
128130
"""Create an index with NUMERIC field for dates."""
129-
import redis as redis_lib
130-
131131
index_name = "test_dates"
132132
try:
133133
redis_client.execute_command("FT.DROPINDEX", index_name, "DD")
@@ -153,8 +153,6 @@ def date_index(self, redis_client):
153153
@pytest.fixture
154154
def date_translator(self, redis_client, date_index):
155155
"""Create a translator with the date index loaded."""
156-
from sql_redis.schema import SchemaRegistry
157-
158156
registry = SchemaRegistry(redis_client)
159157
registry.load_all()
160158
return Translator(registry)

tests/test_date_functions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Tests for DATE function parsing and translation (Phase 2 & 3)."""
22

33
import pytest
4+
import redis as redis_lib
45

56
from sql_redis.parser import SQLParser
7+
from sql_redis.schema import SchemaRegistry
8+
from sql_redis.translator import Translator
69

710

811
class TestDateFunctionParsing:
@@ -108,8 +111,6 @@ class TestDateFunctionTranslation:
108111
@pytest.fixture
109112
def date_index(self, redis_client):
110113
"""Create an index with NUMERIC field for dates."""
111-
import redis as redis_lib
112-
113114
index_name = "test_date_funcs"
114115
try:
115116
redis_client.execute_command("FT.DROPINDEX", index_name, "DD")
@@ -135,9 +136,6 @@ def date_index(self, redis_client):
135136
@pytest.fixture
136137
def date_translator(self, redis_client, date_index):
137138
"""Create a translator with the date index loaded."""
138-
from sql_redis.schema import SchemaRegistry
139-
from sql_redis.translator import Translator
140-
141139
registry = SchemaRegistry(redis_client)
142140
registry.load_all()
143141
return Translator(registry)

0 commit comments

Comments
 (0)