|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import re |
5 | 6 | import warnings |
6 | 7 | from dataclasses import dataclass, field |
7 | 8 |
|
@@ -120,6 +121,7 @@ def _build_command(self, analyzed: AnalyzedQuery) -> TranslatedQuery: |
120 | 121 | or geo_requires_aggregate # geo_distance with >, >=, BETWEEN |
121 | 122 | or len(analyzed.date_functions) > 0 |
122 | 123 | or has_date_func_conditions |
| 124 | + or len(parsed.filters) > 0 # exists() in HAVING → FILTER |
123 | 125 | ) |
124 | 126 |
|
125 | 127 | # Build query string from conditions |
@@ -333,33 +335,44 @@ def _build_aggregate( |
333 | 335 | geo_filter_conditions = list(parsed.geo_conditions) |
334 | 336 |
|
335 | 337 | # LOAD fields if needed |
336 | | - load_fields = set() |
337 | | - for agg in analyzed.aggregations: |
338 | | - if agg.field: |
339 | | - load_fields.add(agg.field) |
340 | | - for field_name in analyzed.groupby_fields: |
341 | | - load_fields.add(field_name) |
342 | | - # Load geo fields used in geo_distance() SELECT expressions |
343 | | - for geo_select in parsed.geo_distance_selects: |
344 | | - load_fields.add(geo_select.field) |
345 | | - # Load geo fields used in geo_distance() WHERE with >, >=, BETWEEN |
346 | | - for geo_cond in geo_filter_conditions: |
347 | | - load_fields.add(geo_cond.field) |
348 | | - # Load source fields for date functions in SELECT |
349 | | - for date_func in analyzed.date_functions: |
350 | | - load_fields.add(date_func.field) |
351 | | - # Load source fields for date function conditions in WHERE |
352 | | - for condition in parsed.conditions: |
353 | | - if self._is_date_function_condition(condition): |
354 | | - load_fields.add(condition.field) |
355 | | - # Load explicit SELECT fields for FT.AGGREGATE |
356 | | - for field_name in parsed.fields: |
357 | | - if field_name != "*": |
| 338 | + # SELECT * in aggregate mode → LOAD * (all document attributes) |
| 339 | + load_all = "*" in (parsed.fields or []) |
| 340 | + |
| 341 | + load_fields: set[str] = set() |
| 342 | + if not load_all: |
| 343 | + for agg in analyzed.aggregations: |
| 344 | + if agg.field: |
| 345 | + load_fields.add(agg.field) |
| 346 | + for field_name in analyzed.groupby_fields: |
| 347 | + load_fields.add(field_name) |
| 348 | + # Load geo fields used in geo_distance() SELECT expressions |
| 349 | + for geo_select in parsed.geo_distance_selects: |
| 350 | + load_fields.add(geo_select.field) |
| 351 | + # Load geo fields used in geo_distance() WHERE with >, >=, BETWEEN |
| 352 | + for geo_cond in geo_filter_conditions: |
| 353 | + load_fields.add(geo_cond.field) |
| 354 | + # Load source fields for date functions in SELECT |
| 355 | + for date_func in analyzed.date_functions: |
| 356 | + load_fields.add(date_func.field) |
| 357 | + # Load source fields for date function conditions in WHERE |
| 358 | + for condition in parsed.conditions: |
| 359 | + if self._is_date_function_condition(condition): |
| 360 | + load_fields.add(condition.field) |
| 361 | + # Load explicit SELECT fields for FT.AGGREGATE |
| 362 | + for field_name in parsed.fields: |
358 | 363 | # Skip computed fields (they have aliases from geo_distance) |
359 | 364 | if field_name not in [gs.alias for gs in parsed.geo_distance_selects]: |
360 | 365 | load_fields.add(field_name) |
361 | | - |
362 | | - if load_fields: |
| 366 | + # Load fields referenced in exists() filters (HAVING) |
| 367 | + for filter_expr in parsed.filters: |
| 368 | + self._extract_exists_fields(filter_expr, load_fields) |
| 369 | + # Load fields referenced in exists() computed fields (SELECT) |
| 370 | + for computed in analyzed.computed_fields: |
| 371 | + self._extract_exists_fields(computed.expression, load_fields) |
| 372 | + |
| 373 | + if load_all: |
| 374 | + args.extend(["LOAD", "*"]) |
| 375 | + elif load_fields: |
363 | 376 | args.append("LOAD") |
364 | 377 | args.append(str(len(load_fields))) |
365 | 378 | # Redis expects property names prefixed with '@' in LOAD |
@@ -498,6 +511,13 @@ def _build_aggregate( |
498 | 511 | alias = agg.alias or agg.function.lower() |
499 | 512 | args.extend(["AS", alias]) |
500 | 513 |
|
| 514 | + # FILTER for exists() from HAVING clause (post-aggregation) |
| 515 | + for filter_expr in parsed.filters: |
| 516 | + prefixed = self._prefix_fields_in_expression( |
| 517 | + filter_expr, analyzed.field_types |
| 518 | + ) |
| 519 | + args.extend(["FILTER", prefixed]) |
| 520 | + |
501 | 521 | # SORTBY |
502 | 522 | if parsed.orderby_fields: |
503 | 523 | args.append("SORTBY") |
@@ -593,12 +613,16 @@ def _convert_to_meters(self, value: float, unit: str) -> float: |
593 | 613 | ) |
594 | 614 | return value * conversions[normalized_unit] |
595 | 615 |
|
| 616 | + @staticmethod |
| 617 | + def _extract_exists_fields(expression: str, load_fields: set[str]) -> None: |
| 618 | + """Extract field names from exists() calls and add to load_fields.""" |
| 619 | + for match in re.finditer(r"exists\((\w+)\)", expression, re.IGNORECASE): |
| 620 | + load_fields.add(match.group(1)) |
| 621 | + |
596 | 622 | def _prefix_fields_in_expression( |
597 | 623 | self, expression: str, schema: dict[str, str] |
598 | 624 | ) -> str: |
599 | 625 | """Prefix field names with @ in an expression for Redis APPLY.""" |
600 | | - import re |
601 | | - |
602 | 626 | result = expression |
603 | 627 | for field_name in schema: |
604 | 628 | # Match field name as a whole word, not already prefixed with @ |
|
0 commit comments