Skip to content

Commit 1f8e261

Browse files
committed
Address review comments: strict parser validation, stable score alias resolution, OR docs fix
- Validate fulltext/fuzzy/score args are literals (allow Placeholders)
- Collect all field names across result set before resolving score alias
- Clarify OR operator case-sensitivity in README
1 parent a9eea83 commit 1f8e261

3 files changed

Lines changed: 62 additions & 25 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ SELECT * FROM products WHERE fulltext(title, 'laptop') OR fulltext(description,
218218
- `=` on TEXT fields performs **exact phrase** matching (preserves stopwords)
219219
- `fulltext()` performs **tokenized** search (stopwords are filtered with a warning)
220220
- `fuzzy()` and `fulltext()` only work on TEXT fields — using them on TAG or NUMERIC raises `ValueError`
221-
- OR is case-insensitive: `'laptop OR tablet'`, `'laptop or tablet'`, and `'laptop Or tablet'` all work
221+
- OR must be **uppercase**: `'laptop OR tablet'` triggers union; lowercase `'laptop or tablet'` is treated as a regular three-word AND search
222222
- Special characters (`@`, `|`, `-`, `*`, `+`, etc.) in search terms are automatically escaped
223223

224224
### IS NULL / IS NOT NULL (ismissing)

sql_redis/executor.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,14 @@ def _resolve_score_alias(
122122
first_row_fields: set[str] | None = None,
123123
) -> str:
124124
"""Determine a stable score column name that won't collide with
125-
document fields. The alias is decided once before iterating rows
126-
so every row uses the same column name.
125+
document fields. The alias is resolved once and reused for every
126+
row so all rows share the same column name.
127127
128128
When a RETURN clause is present, the returned field names are used
129129
for collision detection. When RETURN is absent (SELECT *), the
130-
caller should pass ``first_row_fields`` — the field names from the
131-
first result row — so we can detect collisions even when all
132-
document attributes are returned."""
130+
caller should pass ``first_row_fields`` — the union of all field
131+
names across all result rows — so we can detect collisions even
132+
when different documents have different field sets."""
133133
alias = score_alias or "__score"
134134
# Extract RETURN field names from args to detect collision
135135
try:
@@ -228,19 +228,23 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
228228
elif with_scores:
229229
# WITHSCORES format: [count, key1, score1, [fields1], key2, score2, [fields2], ...]
230230
# Stride of 3: key, score, field_list
231-
# Resolve alias once from the first row so every row uses the
232-
# same column name (consistent output schema).
233-
resolved_alias: str | None = None
231+
# First pass: collect all field names across all rows so the
232+
# alias avoids collisions with any document field, not just
233+
# the first row's fields.
234+
all_field_names: set[str] = set()
235+
parsed_rows: list[tuple[dict, Any]] = []
234236
for i in range(1, len(raw_result) - 2, 3):
235237
score = raw_result[i + 1]
236238
row_data = raw_result[i + 2]
237239
row = dict(zip(row_data[::2], row_data[1::2]))
238-
if resolved_alias is None:
239-
resolved_alias = self._resolve_score_alias(
240-
translated.score_alias,
241-
translated.args,
242-
first_row_fields=set(row.keys()),
243-
)
240+
all_field_names.update(row.keys())
241+
parsed_rows.append((row, score))
242+
resolved_alias = self._resolve_score_alias(
243+
translated.score_alias,
244+
translated.args,
245+
first_row_fields=all_field_names,
246+
)
247+
for row, score in parsed_rows:
244248
row[resolved_alias] = score
245249
rows.append(row)
246250
else:
@@ -345,19 +349,22 @@ async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
345349
rows.append(row)
346350
elif with_scores:
347351
# WITHSCORES format: [count, key1, score1, [fields1], ...]
348-
# Resolve alias once from the first row so every row uses the
349-
# same column name (consistent output schema).
350-
resolved_alias: str | None = None
352+
# First pass: collect all field names across all rows so the
353+
# alias avoids collisions with any document field.
354+
all_field_names: set[str] = set()
355+
parsed_rows: list[tuple[dict, Any]] = []
351356
for i in range(1, len(raw_result) - 2, 3):
352357
score = raw_result[i + 1]
353358
row_data = raw_result[i + 2]
354359
row = dict(zip(row_data[::2], row_data[1::2]))
355-
if resolved_alias is None:
356-
resolved_alias = self._resolve_score_alias(
357-
translated.score_alias,
358-
translated.args,
359-
first_row_fields=set(row.keys()),
360-
)
360+
all_field_names.update(row.keys())
361+
parsed_rows.append((row, score))
362+
resolved_alias = self._resolve_score_alias(
363+
translated.score_alias,
364+
translated.args,
365+
first_row_fields=all_field_names,
366+
)
367+
for row, score in parsed_rows:
361368
row[resolved_alias] = score
362369
rows.append(row)
363370
else:

sql_redis/parser.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,12 @@ def _process_select_expression_inner(
471471
"score() argument must be a literal scorer name "
472472
f"(e.g., 'BM25', 'TFIDF'), got {expression.expressions[0]}."
473473
)
474-
scorer = str(scorer_val)
474+
if not isinstance(scorer_val, str):
475+
raise ValueError(
476+
"score() argument must be a string scorer name "
477+
f"(e.g., 'BM25', 'TFIDF'), got {scorer_val!r}."
478+
)
479+
scorer = scorer_val
475480
if result.scoring is not None:
476481
raise ValueError(
477482
"Only one score() expression is allowed per query."
@@ -1003,11 +1008,21 @@ def _add_function_condition(
10031008
if func_name == "FULLTEXT" and len(args) >= 2:
10041009
field_name = args[0].name if isinstance(args[0], exp.Column) else None
10051010
value = self._extract_literal_value(args[1])
1011+
if value is None and not isinstance(args[1], exp.Placeholder):
1012+
raise ValueError(
1013+
"fulltext() second argument must be a literal string, "
1014+
f"got {args[1]}. Usage: fulltext(field, 'search terms')"
1015+
)
10061016

10071017
# Optional 3rd arg: slop (non-negative int)
10081018
slop = None
10091019
if len(args) >= 3:
10101020
slop_val = self._extract_literal_value(args[2])
1021+
if slop_val is None and not isinstance(args[2], exp.Placeholder):
1022+
raise ValueError(
1023+
"fulltext() slop argument must be a literal integer, "
1024+
f"got {args[2]}."
1025+
)
10111026
if slop_val is not None:
10121027
# Reject booleans and non-integer floats — only real
10131028
# integers are valid for slop.
@@ -1029,6 +1044,11 @@ def _add_function_condition(
10291044
inorder = False
10301045
if len(args) >= 4:
10311046
inorder_val = self._extract_literal_value(args[3])
1047+
if inorder_val is None and not isinstance(args[3], exp.Placeholder):
1048+
raise ValueError(
1049+
"fulltext() inorder argument must be a literal boolean "
1050+
f"(true/false or 1/0), got {args[3]}."
1051+
)
10321052
if inorder_val is not None:
10331053
if isinstance(inorder_val, bool):
10341054
inorder = inorder_val
@@ -1059,11 +1079,21 @@ def _add_function_condition(
10591079
elif func_name == "FUZZY" and len(args) >= 2:
10601080
field_name = args[0].name if isinstance(args[0], exp.Column) else None
10611081
value = self._extract_literal_value(args[1])
1082+
if value is None and not isinstance(args[1], exp.Placeholder):
1083+
raise ValueError(
1084+
"fuzzy() second argument must be a literal string, "
1085+
f"got {args[1]}. Usage: fuzzy(field, 'search term')"
1086+
)
10621087

10631088
# Optional 3rd arg: fuzzy level (1, 2, or 3)
10641089
fuzzy_level = None
10651090
if len(args) >= 3:
10661091
level_val = self._extract_literal_value(args[2])
1092+
if level_val is None and not isinstance(args[2], exp.Placeholder):
1093+
raise ValueError(
1094+
"fuzzy() level argument must be a literal integer, "
1095+
f"got {args[2]}."
1096+
)
10671097
if level_val is not None:
10681098
if isinstance(level_val, bool):
10691099
raise ValueError(

0 commit comments

Comments
 (0)