Skip to content

Commit 2d7685b

Browse files
committed
feat: advanced text search — fuzzy LD 2/3, suffix/infix, OR, proximity, BM25 scoring
- Parser: add LIKE handler, Boolean extraction for inorder param - QueryBuilder: suffix/infix patterns, fuzzy LD 2-3, slop/inorder attributes, scoring - Translator: score_alias on TranslatedQuery, WITHSCORES/SCORER args - Executor: stride-3 response parsing for WITHSCORES (sync + async) - Tests: 137 new tests (82 unit QB, 37 unit translator, 18 integration) - All 402 tests pass, mypy clean
1 parent a7820d0 commit 2d7685b

7 files changed

Lines changed: 604 additions & 29 deletions

File tree

sql_redis/executor.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,25 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
166166
rows = []
167167

168168
if translated.command == "FT.SEARCH":
169-
# FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
170-
# Skip document keys (odd indices), take field lists (even indices after count)
171-
for i in range(2, len(raw_result), 2):
172-
row_data = raw_result[i]
173-
row = dict(zip(row_data[::2], row_data[1::2]))
174-
rows.append(row)
169+
# Check if WITHSCORES was requested — changes response format
170+
with_scores = "WITHSCORES" in translated.args
171+
score_alias = translated.score_alias
172+
173+
if with_scores:
174+
# WITHSCORES format: [count, key1, score1, [fields1], key2, score2, [fields2], ...]
175+
# Stride of 3: key, score, field_list
176+
for i in range(1, len(raw_result) - 2, 3):
177+
score = raw_result[i + 1]
178+
row_data = raw_result[i + 2]
179+
row = dict(zip(row_data[::2], row_data[1::2]))
180+
row[score_alias or "__score"] = score
181+
rows.append(row)
182+
else:
183+
# Standard format: [count, key1, [fields1], key2, [fields2], ...]
184+
for i in range(2, len(raw_result), 2):
185+
row_data = raw_result[i]
186+
row = dict(zip(row_data[::2], row_data[1::2]))
187+
rows.append(row)
175188
else:
176189
# FT.AGGREGATE format: [count, [fields1], [fields2], ...]
177190
for row_data in raw_result[1:]:
@@ -252,11 +265,23 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
252265
rows = []
253266

254267
if translated.command == "FT.SEARCH":
255-
# FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
256-
for i in range(2, len(raw_result), 2):
257-
row_data = raw_result[i]
258-
row = dict(zip(row_data[::2], row_data[1::2]))
259-
rows.append(row)
268+
with_scores = "WITHSCORES" in translated.args
269+
score_alias = translated.score_alias
270+
271+
if with_scores:
272+
# WITHSCORES format: [count, key1, score1, [fields1], ...]
273+
for i in range(1, len(raw_result) - 2, 3):
274+
score = raw_result[i + 1]
275+
row_data = raw_result[i + 2]
276+
row = dict(zip(row_data[::2], row_data[1::2]))
277+
row[score_alias or "__score"] = score
278+
rows.append(row)
279+
else:
280+
# Standard format: [count, key1, [fields1], key2, [fields2], ...]
281+
for i in range(2, len(raw_result), 2):
282+
row_data = raw_result[i]
283+
row = dict(zip(row_data[::2], row_data[1::2]))
284+
rows.append(row)
260285
else:
261286
# FT.AGGREGATE format: [count, [fields1], [fields2], ...]
262287
for row_data in raw_result[1:]:

sql_redis/parser.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ class Condition:
164164
operator: str
165165
value: object
166166
negated: bool = False
167+
fuzzy_level: int | None = None # Levenshtein distance for FUZZY (1, 2, or 3)
168+
slop: int | None = None # Max distance between terms for proximity search
169+
inorder: bool = False # Require terms in order (used with slop)
167170

168171

169172
@dataclass
@@ -196,6 +199,17 @@ class GeoDistanceSelect:
196199
unit: str = "m" # m, km, mi, ft (default: meters)
197200

198201

202+
@dataclass
203+
class ScoringSpec:
204+
"""Specification for relevance scoring.
205+
206+
Triggers WITHSCORES and optional SCORER on FT.SEARCH.
207+
"""
208+
209+
alias: str = "score" # Column alias for the score
210+
scorer: str = "BM25" # Scorer algorithm (BM25, TFIDF, DISMAX, etc.)
211+
212+
199213
@dataclass
200214
class ParsedQuery:
201215
"""Result of parsing a SQL query."""
@@ -219,6 +233,9 @@ class ParsedQuery:
219233
limit: int | None = None
220234
offset: int | None = None
221235
filters: list[str] = dataclasses.field(default_factory=list)
236+
scoring: ScoringSpec | None = None # Relevance scoring config
237+
verbatim: bool = False # If True, add VERBATIM to FT.SEARCH
238+
nostopwords: bool = False # If True, add NOSTOPWORDS to FT.SEARCH
222239

223240

224241
class SQLParser:
@@ -441,6 +458,17 @@ def _process_select_expression_inner(
441458
elif func_name_lower == "geo_distance":
442459
# geo_distance(field, POINT(lon, lat), unit) in SELECT
443460
self._process_geo_distance_select(expression, result, alias)
461+
elif func_name_lower == "score":
462+
# score() or score('BM25') — triggers WITHSCORES + SCORER
463+
scorer = "BM25"
464+
if expression.expressions:
465+
scorer_val = self._extract_literal_value(expression.expressions[0])
466+
if scorer_val is not None:
467+
scorer = str(scorer_val).upper()
468+
result.scoring = ScoringSpec(
469+
alias=alias or "score",
470+
scorer=scorer,
471+
)
444472
elif func_name_lower in redis_reducers:
445473
# Redis-specific reducer functions
446474
field_name = None
@@ -656,6 +684,9 @@ def _process_where_clause(
656684
self._add_between_condition(expression, result, negated)
657685
elif isinstance(expression, exp.In):
658686
self._add_in_condition(expression, result, negated)
687+
elif isinstance(expression, exp.Like):
688+
# LIKE 'pattern%' / '%pattern' / '%pattern%'
689+
self._add_condition(expression, "LIKE", result, negated)
659690
elif isinstance(expression, exp.And):
660691
result.boolean_operator = "AND"
661692
self._process_where_clause(expression.this, result, negated)
@@ -938,25 +969,59 @@ def _add_in_condition(self, expression, result: ParsedQuery, negated: bool) -> N
938969
def _add_function_condition(
939970
self, expression, result: ParsedQuery, negated: bool
940971
) -> None:
941-
"""Add a condition from a function call like fulltext(field, value)."""
972+
"""Add a condition from a function call like fulltext(field, value) or fuzzy(field, value, level)."""
942973
func_name = expression.name.upper()
943-
if func_name == "FULLTEXT" and len(expression.expressions) >= 2:
944-
first_arg = expression.expressions[0]
945-
second_arg = expression.expressions[1]
974+
args = expression.expressions
975+
976+
if func_name == "FULLTEXT" and len(args) >= 2:
977+
field_name = args[0].name if isinstance(args[0], exp.Column) else None
978+
value = self._extract_literal_value(args[1])
979+
980+
# Optional 3rd arg: slop (int)
981+
slop = None
982+
if len(args) >= 3:
983+
slop_val = self._extract_literal_value(args[2])
984+
if slop_val is not None:
985+
slop = int(slop_val)
986+
987+
# Optional 4th arg: inorder (boolean-like: 1/0 or true/false)
988+
inorder = False
989+
if len(args) >= 4:
990+
inorder_val = self._extract_literal_value(args[3])
991+
if inorder_val is not None:
992+
inorder = str(inorder_val).lower() in ("1", "true")
946993

947-
field_name = None
948-
if isinstance(first_arg, exp.Column):
949-
field_name = first_arg.name
994+
if field_name is not None:
995+
result.conditions.append(
996+
Condition(
997+
field=field_name,
998+
operator="FULLTEXT",
999+
value=value,
1000+
negated=negated,
1001+
slop=slop,
1002+
inorder=inorder,
1003+
)
1004+
)
9501005

951-
value = self._extract_literal_value(second_arg)
1006+
elif func_name == "FUZZY" and len(args) >= 2:
1007+
field_name = args[0].name if isinstance(args[0], exp.Column) else None
1008+
value = self._extract_literal_value(args[1])
1009+
1010+
# Optional 3rd arg: fuzzy level (1, 2, or 3)
1011+
fuzzy_level = None
1012+
if len(args) >= 3:
1013+
level_val = self._extract_literal_value(args[2])
1014+
if level_val is not None:
1015+
fuzzy_level = int(level_val)
9521016

9531017
if field_name is not None:
9541018
result.conditions.append(
9551019
Condition(
9561020
field=field_name,
957-
operator="FULLTEXT",
1021+
operator="FUZZY",
9581022
value=value,
9591023
negated=negated,
1024+
fuzzy_level=fuzzy_level,
9601025
)
9611026
)
9621027

@@ -983,6 +1048,9 @@ def _extract_literal_value(self, expression, convert_dates: bool = False):
9831048
if timestamp is not None:
9841049
return timestamp
9851050
return value
1051+
elif isinstance(expression, exp.Boolean):
1052+
# Handle TRUE/FALSE keywords parsed by sqlglot
1053+
return expression.this
9861054
elif isinstance(expression, exp.Neg):
9871055
# Handle negative numbers: Neg(Literal(122.4)) -> -122.4
9881056
inner_value = self._extract_literal_value(expression.this)

sql_redis/query_builder.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def build_text_condition(
6868
operator: str,
6969
value: str,
7070
negated: bool = False,
71+
*,
72+
fuzzy_level: int | None = None,
73+
slop: int | None = None,
74+
inorder: bool = False,
7175
) -> str:
7276
"""Build query syntax for TEXT field conditions.
7377
@@ -76,6 +80,9 @@ def build_text_condition(
7680
operator: One of =, !=, FULLTEXT, LIKE, FUZZY.
7781
value: The search term or pattern.
7882
negated: If True, prefix with - for negation.
83+
fuzzy_level: Levenshtein distance for FUZZY (1, 2, or 3). Default 1.
84+
slop: Maximum distance between terms for proximity search.
85+
inorder: If True with slop, require terms in order.
7986
8087
Returns:
8188
RediSearch query syntax like @field:"exact phrase" or @field:(term1 term2).
@@ -91,19 +98,24 @@ def build_text_condition(
9198

9299
# Handle different operators
93100
if operator == "LIKE":
94-
# Convert SQL LIKE pattern (%) to RediSearch prefix (*)
101+
# Convert SQL LIKE pattern (%) to RediSearch prefix/suffix/infix (*)
95102
search_value = value.replace("%", "*")
96103
elif operator == "FUZZY":
97-
# Wrap with % for fuzzy matching
98-
search_value = f"%{value}%"
104+
# Wrap with % signs — count determined by fuzzy_level
105+
level = fuzzy_level if fuzzy_level is not None else 1
106+
if level not in (1, 2, 3):
107+
raise ValueError(
108+
f"Fuzzy level must be 1, 2, or 3 (got {level}). "
109+
"RediSearch supports a maximum Levenshtein distance of 3."
110+
)
111+
pct = "%" * level
112+
search_value = f"{pct}{value}{pct}"
99113
elif operator in ("=", "!="):
100114
# Exact phrase match — always wrap in quotes, preserve stopwords.
101-
# This ensures "bank of america" stays as-is rather than
102-
# being tokenized or having stopwords stripped.
103115
escaped = self._escape_text_value(value)
104116
search_value = f'"{escaped}"'
105-
elif " " in value:
106-
# MATCH with multi-word: tokenized search with stopword filtering
117+
elif " " in value and " OR " not in value:
118+
# FULLTEXT/MATCH with multi-word: tokenized search with stopword filtering
107119
words = value.split()
108120
removed_stopwords = [
109121
w for w in words if w.lower() in REDIS_DEFAULT_STOPWORDS
@@ -122,13 +134,25 @@ def build_text_condition(
122134
stacklevel=2,
123135
)
124136

125-
# Use filtered words in parentheses (AND semantics), or original if all were stopwords
126137
terms = " ".join(filtered_words) if filtered_words else value
127138
search_value = f"({terms})"
139+
elif " OR " in value:
140+
# OR union within text field: split on ' OR ' and join with |
141+
or_terms = [t.strip() for t in value.split(" OR ")]
142+
search_value = f"({'|'.join(or_terms)})"
128143
else:
129144
search_value = value
130145

131-
return f"{prefix}@{field}:{search_value}"
146+
base = f"{prefix}@{field}:{search_value}"
147+
148+
# Append query attributes (slop, inorder) if specified
149+
if slop is not None:
150+
attrs = f"$slop: {slop};"
151+
if inorder:
152+
attrs += " $inorder: true;"
153+
base = f"{base} => {{ {attrs} }}"
154+
155+
return base
132156

133157
def _escape_tag_value(self, value: str) -> str:
134158
"""Escape special characters in TAG values."""

sql_redis/translator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class TranslatedQuery:
2727
query_string: str
2828
args: list[str] = field(default_factory=list)
2929
params: dict[str, object] = field(default_factory=dict) # Named parameters
30+
score_alias: str | None = None # Alias for score column when WITHSCORES is used
3031

3132
def to_command_list(self) -> list[str]:
3233
"""Return as a list suitable for redis.execute_command()."""
@@ -197,6 +198,9 @@ def _build_condition(self, condition: Condition, field_type: str | None) -> str:
197198
operator,
198199
str(condition.value),
199200
is_negated,
201+
fuzzy_level=condition.fuzzy_level,
202+
slop=condition.slop,
203+
inorder=condition.inorder,
200204
)
201205
elif field_type == "TAG":
202206
# Keep list value for IN clauses, convert scalar to string
@@ -301,6 +305,18 @@ def _build_search(
301305
offset = parsed.offset or 0
302306
args.extend(["LIMIT", str(offset), str(parsed.limit)])
303307

308+
# Scoring — WITHSCORES and SCORER
309+
if parsed.scoring is not None:
310+
args.append("WITHSCORES")
311+
if parsed.scoring.scorer:
312+
args.extend(["SCORER", parsed.scoring.scorer])
313+
314+
# Verbatim / nostopwords flags
315+
if parsed.verbatim:
316+
args.append("VERBATIM")
317+
if parsed.nostopwords:
318+
args.append("NOSTOPWORDS")
319+
304320
# DIALECT 2 — unconditionally appended as the last arguments
305321
args.extend(["DIALECT", "2"])
306322

@@ -310,6 +326,7 @@ def _build_search(
310326
query_string=query_string,
311327
args=args,
312328
params=params,
329+
score_alias=(parsed.scoring.alias if parsed.scoring is not None else None),
313330
)
314331

315332
def _build_geo_filter_args(self, geo_cond: GeoDistanceCondition) -> list[str]:

0 commit comments

Comments
 (0)