22
33from __future__ import annotations
44
5- from dataclasses import dataclass , field
5+ import dataclasses
6+ from dataclasses import dataclass
67
78import sqlglot
89from sqlglot import exp
@@ -15,6 +16,9 @@ class AggregationSpec:
1516 function : str
1617 field : str | None = None
1718 alias : str | None = None
19+ extra_args : list [str ] = dataclasses .field (
20+ default_factory = list
21+ ) # For reducers like QUANTILE
1822
1923
2024@dataclass
@@ -49,14 +53,14 @@ class ParsedQuery:
4953 """Result of parsing a SQL query."""
5054
5155 index : str = ""
52- fields : list [str ] = field (default_factory = list )
53- conditions : list [Condition ] = field (default_factory = list )
56+ fields : list [str ] = dataclasses . field (default_factory = list )
57+ conditions : list [Condition ] = dataclasses . field (default_factory = list )
5458 boolean_operator : str = "AND"
55- aggregations : list [AggregationSpec ] = field (default_factory = list )
56- computed_fields : list [ComputedField ] = field (default_factory = list )
59+ aggregations : list [AggregationSpec ] = dataclasses . field (default_factory = list )
60+ computed_fields : list [ComputedField ] = dataclasses . field (default_factory = list )
5761 vector_search : VectorSearchSpec | None = None
58- groupby_fields : list [str ] = field (default_factory = list )
59- orderby_fields : list [tuple [str , str ]] = field (
62+ groupby_fields : list [str ] = dataclasses . field (default_factory = list )
63+ orderby_fields : list [tuple [str , str ]] = dataclasses . field (
6064 default_factory = list
6165 ) # (field, ASC|DESC)
6266 limit : int | None = None
@@ -150,9 +154,28 @@ def _process_select_expression_inner(
150154 result .fields .append (expression .name )
151155 elif isinstance (expression , exp .Star ):
152156 result .fields .append ("*" )
153- elif isinstance (expression , (exp .Count , exp .Sum , exp .Avg , exp .Min , exp .Max )):
157+ elif isinstance (
158+ expression ,
159+ (
160+ exp .Count ,
161+ exp .Sum ,
162+ exp .Avg ,
163+ exp .Min ,
164+ exp .Max ,
165+ exp .Stddev ,
166+ exp .Variance ,
167+ exp .FirstValue ,
168+ exp .ArrayAgg ,
169+ ),
170+ ):
154171 # Aggregation function
172+ # Map sqlglot function names to Redis reducer names
155173 func_name = expression .key .upper ()
174+ redis_func_map = {
175+ "FIRSTVALUE" : "FIRST_VALUE" ,
176+ "ARRAYAGG" : "TOLIST" ,
177+ }
178+ func_name = redis_func_map .get (func_name , func_name )
156179 field_name = None
157180 # Get the field being aggregated (if any)
158181 if expression .this :
@@ -184,10 +207,34 @@ def _process_select_expression_inner(
184207 # - Distance: L2/Euclidean distance
185208 # - CosineDistance: cosine_distance() function
186209 self ._process_vector_distance (expression , result , alias )
210+ elif isinstance (expression , exp .Quantile ):
211+ # QUANTILE(field, quantile_value) -> REDUCE QUANTILE 2 @field quantile_value
212+ field_name = None
213+ if expression .this and isinstance (expression .this , exp .Column ):
214+ field_name = expression .this .name
215+ quantile_value = None
216+ if expression .args .get ("quantile" ):
217+ quantile_value = str (expression .args ["quantile" ].this )
218+ extra_args = [quantile_value ] if quantile_value else []
219+ result .aggregations .append (
220+ AggregationSpec (
221+ function = "QUANTILE" ,
222+ field = field_name ,
223+ alias = alias ,
224+ extra_args = extra_args ,
225+ )
226+ )
187227 elif isinstance (expression , exp .Anonymous ):
188228 # Custom function call (e.g., vector_distance) - check before exp.Func
189229 # since Anonymous is a subclass of Func
190230 func_name = expression .name .lower ()
231+ # Redis-specific reducer functions that sqlglot doesn't recognize
232+ redis_reducers = {
233+ "count_distinct" ,
234+ "count_distinctish" ,
235+ "quantile" ,
236+ "random_sample" ,
237+ }
191238 if func_name == "vector_distance" :
192239 # Extract the vector field name from first argument
193240 if expression .expressions :
@@ -198,6 +245,26 @@ def _process_select_expression_inner(
198245 field = field_name ,
199246 alias = alias or func_name ,
200247 )
248+ elif func_name in redis_reducers :
249+ # Redis-specific reducer functions
250+ field_name = None
251+ reducer_extra_args : list [str ] = []
252+ if expression .expressions :
253+ first_arg = expression .expressions [0 ]
254+ if isinstance (first_arg , exp .Column ):
255+ field_name = first_arg .name
256+ # Extract additional arguments (e.g., quantile value for QUANTILE)
257+ for arg in expression .expressions [1 :]:
258+ if isinstance (arg , exp .Literal ):
259+ reducer_extra_args .append (str (arg .this ))
260+ result .aggregations .append (
261+ AggregationSpec (
262+ function = func_name .upper (),
263+ field = field_name ,
264+ alias = alias ,
265+ extra_args = reducer_extra_args ,
266+ )
267+ )
201268 else :
202269 # Other custom functions - treat as computed field
203270 expr_str = expression .sql ()
0 commit comments