Skip to content

Commit 1621b15

Browse files
committed
Fix Copilot review issues for date/geo support
- Fix geo_distance parsing: use func_name_lower for case-insensitive match - Reject date function predicates combined with OR operator - Add date function source fields to LOAD clause - Validate DATE_FORMAT requires exactly 2 args with literal format string - Defer date literal conversion to translator (preserve strings in parser) - Quote string values in FILTER expressions for date functions - Update tests to reflect deferred date conversion behavior
1 parent 663cc82 commit 1621b15

3 files changed

Lines changed: 178 additions & 91 deletions

File tree

sql_redis/parser.py

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,56 @@
1919
)
2020

2121

22+
def parse_date_to_timestamp(value: str) -> int | None:
23+
"""Parse an ISO 8601 date/datetime string to Unix timestamp.
24+
25+
Supports:
26+
- Date: '2024-01-01' (interpreted as midnight UTC)
27+
- Datetime: '2024-01-01T12:00:00' or '2024-01-01 12:00:00'
28+
- Datetime with timezone: '2024-01-01T12:00:00Z', '2024-01-01T12:00:00+00:00'
29+
30+
Args:
31+
value: The string value to parse.
32+
33+
Returns:
34+
Unix timestamp as integer, or None if not a valid date string.
35+
"""
36+
# Check if it matches date pattern
37+
if DATE_PATTERN.match(value):
38+
try:
39+
dt = datetime.strptime(value, "%Y-%m-%d")
40+
# Treat as UTC midnight
41+
dt = dt.replace(tzinfo=timezone.utc)
42+
return int(dt.timestamp())
43+
except ValueError:
44+
return None
45+
46+
# Check if it matches datetime pattern
47+
if DATETIME_PATTERN.match(value):
48+
# Normalize: replace space with T for parsing
49+
normalized = value.replace(" ", "T")
50+
51+
# Normalize 'Z' (UTC designator) to '+00:00' for fromisoformat
52+
if normalized.endswith("Z"):
53+
normalized = normalized[:-1] + "+00:00"
54+
55+
# Normalize timezone offsets without colon (+0000 -> +00:00)
56+
# This ensures compatibility with datetime.fromisoformat
57+
normalized = re.sub(r"([+-]\d{2})(\d{2})$", r"\1:\2", normalized)
58+
59+
try:
60+
# Use fromisoformat for robust parsing (handles fractional seconds)
61+
dt = datetime.fromisoformat(normalized)
62+
# If no timezone info, treat as UTC
63+
if dt.tzinfo is None:
64+
dt = dt.replace(tzinfo=timezone.utc)
65+
return int(dt.timestamp())
66+
except ValueError:
67+
return None
68+
69+
return None
70+
71+
2272
@dataclass
2373
class AggregationSpec:
2474
"""Specification for an aggregation function."""
@@ -364,7 +414,7 @@ def _process_select_expression_inner(
364414
field=field_name,
365415
alias=alias or func_name_lower,
366416
)
367-
elif func_name == "geo_distance":
417+
elif func_name_lower == "geo_distance":
368418
# geo_distance(field, POINT(lon, lat), unit) in SELECT
369419
self._process_geo_distance_select(expression, result, alias)
370420
elif func_name_lower in redis_reducers:
@@ -493,16 +543,25 @@ def _process_date_function(
493543
field_name = None
494544
format_string = None
495545

496-
if expression.expressions:
497-
first_arg = expression.expressions[0]
546+
args = expression.expressions or []
547+
548+
if func_name == "DATE_FORMAT":
549+
# DATE_FORMAT requires exactly 2 arguments: field, format_string
550+
if len(args) != 2:
551+
raise ValueError(
552+
"DATE_FORMAT requires exactly 2 arguments: field, format_string"
553+
)
554+
first_arg, second_arg = args
555+
if isinstance(first_arg, exp.Column):
556+
field_name = first_arg.name
557+
# Format argument must be a literal string
558+
if not isinstance(second_arg, exp.Literal) or not second_arg.is_string:
559+
raise ValueError("DATE_FORMAT format argument must be a literal string")
560+
format_string = second_arg.this
561+
elif args:
562+
first_arg = args[0]
498563
if isinstance(first_arg, exp.Column):
499564
field_name = first_arg.name
500-
501-
# For DATE_FORMAT, extract the format string as second argument
502-
if func_name == "DATE_FORMAT" and len(expression.expressions) >= 2:
503-
second_arg = expression.expressions[1]
504-
if isinstance(second_arg, exp.Literal):
505-
format_string = second_arg.this
506565

507566
if field_name:
508567
# Generate default alias if not provided
@@ -822,12 +881,15 @@ def _add_function_condition(
822881
)
823882
)
824883

825-
def _extract_literal_value(self, expression, convert_dates: bool = True):
884+
def _extract_literal_value(self, expression, convert_dates: bool = False):
826885
"""Extract a Python value from a sqlglot Literal or Neg expression.
827886
828887
Args:
829888
expression: The sqlglot expression to extract from.
830889
convert_dates: If True, convert ISO 8601 date strings to Unix timestamps.
890+
Default is False to avoid changing semantics for TEXT/TAG
891+
fields. Date conversion should be handled by the translator
892+
when the field type is known to be NUMERIC.
831893
832894
Returns:
833895
The extracted value, or None if not a literal.
@@ -872,48 +934,6 @@ def _validate_geo_unit(self, unit_val: object) -> str:
872934
def _parse_date_to_timestamp(self, value: str) -> int | None:
873935
"""Parse an ISO 8601 date/datetime string to Unix timestamp.
874936
875-
Supports:
876-
- Date: '2024-01-01' (interpreted as midnight UTC)
877-
- Datetime: '2024-01-01T12:00:00' or '2024-01-01 12:00:00'
878-
- Datetime with timezone: '2024-01-01T12:00:00Z', '2024-01-01T12:00:00+00:00'
879-
880-
Args:
881-
value: The string value to parse.
882-
883-
Returns:
884-
Unix timestamp as integer, or None if not a valid date string.
937+
Delegates to module-level parse_date_to_timestamp function.
885938
"""
886-
# Check if it matches date pattern
887-
if DATE_PATTERN.match(value):
888-
try:
889-
dt = datetime.strptime(value, "%Y-%m-%d")
890-
# Treat as UTC midnight
891-
dt = dt.replace(tzinfo=timezone.utc)
892-
return int(dt.timestamp())
893-
except ValueError:
894-
return None
895-
896-
# Check if it matches datetime pattern
897-
if DATETIME_PATTERN.match(value):
898-
# Normalize: replace space with T for parsing
899-
normalized = value.replace(" ", "T")
900-
901-
# Normalize 'Z' (UTC designator) to '+00:00' for fromisoformat
902-
if normalized.endswith("Z"):
903-
normalized = normalized[:-1] + "+00:00"
904-
905-
# Normalize timezone offsets without colon (+0000 -> +00:00)
906-
# This ensures compatibility with datetime.fromisoformat
907-
normalized = re.sub(r"([+-]\d{2})(\d{2})$", r"\1:\2", normalized)
908-
909-
try:
910-
# Use fromisoformat for robust parsing (handles fractional seconds)
911-
dt = datetime.fromisoformat(normalized)
912-
# If no timezone info, treat as UTC
913-
if dt.tzinfo is None:
914-
dt = dt.replace(tzinfo=timezone.utc)
915-
return int(dt.timestamp())
916-
except ValueError:
917-
return None
918-
919-
return None
939+
return parse_date_to_timestamp(value)

sql_redis/translator.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Condition,
1111
GeoDistanceCondition,
1212
SQLParser,
13+
parse_date_to_timestamp,
1314
)
1415
from sql_redis.query_builder import QueryBuilder
1516
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
@@ -99,6 +100,16 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery:
99100
self._is_date_function_condition(c) for c in parsed.conditions
100101
)
101102

103+
# Validate: date function predicates cannot be combined with OR
104+
# Date filters are applied via FILTER clauses (ANDed with query).
105+
# Combining with OR would change semantics.
106+
if has_date_func_conditions and parsed.boolean_operator == "OR":
107+
raise ValueError(
108+
"Date function predicates cannot be combined with OR; "
109+
"they are applied as top-level filters and would change query "
110+
"semantics. Rewrite the query to avoid OR with date functions."
111+
)
112+
102113
# Determine if we need FT.AGGREGATE
103114
use_aggregate = (
104115
len(analyzed.aggregations) > 0
@@ -188,11 +199,16 @@ def _build_condition(self, condition: Condition, field_type: str | None) -> str:
188199
# Cast value to expected type for numeric conditions
189200
numeric_value: int | float | tuple[int | float, int | float]
190201
if isinstance(condition.value, tuple):
191-
numeric_value = condition.value # type: ignore[assignment]
202+
# Handle tuple values (e.g., BETWEEN) - try date conversion for each
203+
low, high = condition.value
204+
low_val = self._convert_to_numeric(low)
205+
high_val = self._convert_to_numeric(high)
206+
numeric_value = (low_val, high_val)
192207
elif isinstance(condition.value, (int, float)):
193208
numeric_value = condition.value
194209
else:
195-
numeric_value = float(condition.value) # type: ignore[arg-type]
210+
# Try date string conversion for NUMERIC fields
211+
numeric_value = self._convert_to_numeric(condition.value)
196212
return self._query_builder.build_numeric_condition(
197213
condition.field,
198214
operator,
@@ -207,6 +223,29 @@ def _build_condition(self, condition: Condition, field_type: str | None) -> str:
207223
condition.negated,
208224
)
209225

226+
def _convert_to_numeric(self, value: object) -> int | float:
227+
"""Convert a value to numeric, trying date string conversion if needed.
228+
229+
Args:
230+
value: The value to convert. Can be int, float, or string (possibly a date).
231+
232+
Returns:
233+
Numeric value (int or float).
234+
235+
Raises:
236+
ValueError: If conversion fails.
237+
"""
238+
if isinstance(value, (int, float)):
239+
return value
240+
if isinstance(value, str):
241+
# Try date string to timestamp conversion first
242+
timestamp = parse_date_to_timestamp(value)
243+
if timestamp is not None:
244+
return timestamp
245+
# Fall back to float conversion
246+
return float(value)
247+
return float(value) # type: ignore[arg-type]
248+
210249
def _build_search(
211250
self, analyzed: AnalyzedQuery, query_string: str
212251
) -> TranslatedQuery:
@@ -292,6 +331,13 @@ def _build_aggregate(
292331
# Load geo fields used in geo_distance() WHERE with >, >=, BETWEEN
293332
for geo_cond in geo_filter_conditions:
294333
load_fields.add(geo_cond.field)
334+
# Load source fields for date functions in SELECT
335+
for date_func in analyzed.date_functions:
336+
load_fields.add(date_func.field)
337+
# Load source fields for date function conditions in WHERE
338+
for condition in parsed.conditions:
339+
if self._is_date_function_condition(condition):
340+
load_fields.add(condition.field)
295341
# Load explicit SELECT fields for FT.AGGREGATE
296342
for field_name in parsed.fields:
297343
if field_name != "*":
@@ -576,4 +622,23 @@ def _build_date_function_filter(self, condition) -> str:
576622
op_map = {"=": "==", "!=": "!=", ">": ">", ">=": ">=", "<": "<", "<=": "<="}
577623
redis_op = op_map.get(op, "==")
578624

579-
return f"@{alias} {redis_op} {condition.value}"
625+
# Normalize value for FILTER expression (quote strings, pass numbers as-is)
626+
normalized_value = self._normalize_filter_value(condition.value)
627+
628+
return f"@{alias} {redis_op} {normalized_value}"
629+
630+
def _normalize_filter_value(self, value: object) -> str:
631+
"""Normalize a value for use in FILTER expressions.
632+
633+
Redis FILTER expressions require string values to be quoted.
634+
635+
Args:
636+
value: The value to normalize.
637+
638+
Returns:
639+
String representation suitable for FILTER expression.
640+
"""
641+
if isinstance(value, (int, float)):
642+
return str(value)
643+
# Quote string values for FILTER
644+
return f'"{value}"'

0 commit comments

Comments
 (0)