22#
33# SPDX-License-Identifier: Apache-2.0
44
5- import re
65from datetime import datetime
76from itertools import chain
87from typing import Any , Literal
98
109from haystack .errors import FilterError
11- from psycopg .sql import SQL , Composed
10+ from psycopg .sql import SQL , Composable , Composed , Identifier
11+ from psycopg .sql import Literal as SQLLiteral
1212from psycopg .types .json import Jsonb
1313
1414# we need this mapping to cast meta values to the correct type,
2121}
2222
2323NO_VALUE = "no_value"
24- SAFE_META_FIELD_RE = re .compile (r"^[a-zA-Z0-9_-]+$" )
2524
2625
2726def _validate_filters (filters : dict [str , Any ] | None = None ) -> None :
@@ -48,13 +47,13 @@ def _convert_filters_to_where_clause_and_params(
4847 else :
4948 query , values = _parse_logical_condition (filters )
5049
51- where_clause = SQL (f" { operator } " ) + SQL ( query )
50+ where_clause = SQL (f" { operator } " ) + query
5251 params = tuple (value for value in values if value != NO_VALUE )
5352
5453 return where_clause , params
5554
5655
57- def _parse_logical_condition (condition : dict [str , Any ]) -> tuple [str , list [Any ]]:
56+ def _parse_logical_condition (condition : dict [str , Any ]) -> tuple [Composed , list [Any ]]:
5857 if "operator" not in condition :
5958 msg = f"'operator' key missing in { condition } "
6059 raise FilterError (msg )
@@ -84,17 +83,14 @@ def _parse_logical_condition(condition: dict[str, Any]) -> tuple[str, list[Any]]
8483 values = list (chain .from_iterable (values ))
8584
8685 if operator == "AND" :
87- sql_query = f"({ ' AND ' .join (query_parts )} )"
88- elif operator == "OR" :
89- sql_query = f"({ ' OR ' .join (query_parts )} )"
90- else :
91- msg = f"Unknown logical operator '{ operator } '"
92- raise FilterError (msg )
86+ sql_query = SQL ("(" ) + SQL (" AND " ).join (query_parts ) + SQL (")" )
87+ else : # operator == "OR"
88+ sql_query = SQL ("(" ) + SQL (" OR " ).join (query_parts ) + SQL (")" )
9389
9490 return sql_query , values
9591
9692
97- def _parse_comparison_condition (condition : dict [str , Any ]) -> tuple [str , list [Any ]]:
93+ def _parse_comparison_condition (condition : dict [str , Any ]) -> tuple [Composed , list [Any ]]:
9894 field : str = condition ["field" ]
9995 if "operator" not in condition :
10096 msg = f"'operator' key missing in { condition } "
@@ -110,34 +106,37 @@ def _parse_comparison_condition(condition: dict[str, Any]) -> tuple[str, list[An
110106 value : Any = condition ["value" ]
111107
112108 if field .startswith ("meta." ):
113- field = _treat_meta_field (field , value )
109+ sql_field : Composable = _treat_meta_field (field , value )
110+ else :
111+ sql_field = Identifier (field )
114112
115- field , value = COMPARISON_OPERATORS [operator ](field , value )
116- return field , [value ]
113+ sql_expr , value = COMPARISON_OPERATORS [operator ](sql_field , value )
114+ return sql_expr , [value ]
117115
118116
119- def _treat_meta_field (field : str , value : Any ) -> str :
117+ def _treat_meta_field (field : str , value : Any ) -> Composed :
120118 """
121- Internal method that modifies the field str
122- to make the meta JSONB field queryable.
119+ Internal method that returns a psycopg Composed object
120+ to make the meta JSONB field queryable safely.
121+
122+ Uses psycopg.sql.Literal to embed the field name, preventing SQL injection
123+ via metadata field names without requiring regex validation.
124+
125+ Use the ->> operator to access keys in the meta JSONB field.
123126
124127 Examples:
125128 >>> _treat_meta_field(field="meta.number", value=9)
126- "( meta->>'number')::integer"
129+ Composed([SQL('( meta->>'), Literal(' number'), SQL(') ::integer')])
127130
128131 >>> _treat_meta_field(field="meta.name", value="my_name")
129- "meta->>'name'"
130- """
132+ Composed([SQL('meta->>'), Literal('name')])
131133
132- # use the ->> operator to access keys in the meta JSONB field
134+ """
133135 field_name = field .split ("." , 1 )[- 1 ]
134- if not SAFE_META_FIELD_RE .match (field_name ):
135- msg = (
136- f"Invalid metadata field name '{ field_name } '. "
137- "Only alphanumeric characters, dashes, and underscores are allowed."
138- )
139- raise FilterError (msg )
140- field = f"meta->>'{ field_name } '"
136+
137+ # Use SQLLiteral to safely embed the field name as a SQL string literal,
138+ # preventing SQL injection via metadata field names.
139+ composed : Composed = SQL ("meta->>" ) + SQLLiteral (field_name )
141140
142141 # meta fields are stored as strings in the JSONB field,
143142 # so we need to cast them to the correct type
@@ -146,25 +145,24 @@ def _treat_meta_field(field: str, value: Any) -> str:
146145 type_value = PYTHON_TYPES_TO_PG_TYPES .get (type (value [0 ]))
147146
148147 if type_value :
149- field = f"( { field } ) ::{ type_value } "
148+ composed = SQL ( "(" ) + composed + SQL ( f") ::{ type_value } ")
150149
151- return field
150+ return composed
152151
153152
def _equal(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build the SQL fragment for the equality comparison.

    :param field: SQL expression for the column or meta key being compared.
    :param value: Value to compare against; `None` yields an `IS NULL` test.
    :returns: The composed SQL fragment and the value for its placeholder.
    """
    comparing_to_null = value is None
    template = SQL("{} IS NULL") if comparing_to_null else SQL("{} = %s")
    # The IS NULL form has no %s placeholder, so NO_VALUE is returned instead of
    # a parameter; _convert_filters_to_where_clause_and_params filters it out.
    bound_value = NO_VALUE if comparing_to_null else value
    return template.format(field), bound_value
159157
160158
def _not_equal(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build the SQL fragment for the inequality comparison.

    :param field: SQL expression for the column or meta key being compared.
    :param value: Value bound to the `%s` placeholder.
    :returns: The composed SQL fragment and the value for its placeholder.
    """
    # IS DISTINCT FROM is used instead of != because != does not handle
    # NULL values correctly.
    clause = SQL("{} IS DISTINCT FROM %s").format(field)
    return clause, value
165163
166164
167- def _greater_than (field : str , value : Any ) -> tuple [str , Any ]:
165+ def _greater_than (field : Composable , value : Any ) -> tuple [Composed , Any ]:
168166 if isinstance (value , str ):
169167 try :
170168 datetime .fromisoformat (value )
@@ -177,11 +175,10 @@ def _greater_than(field: str, value: Any) -> tuple[str, Any]:
177175 if type (value ) in [list , Jsonb ]:
178176 msg = f"Filter value can't be of type { type (value )} using operators '>', '>=', '<', '<='"
179177 raise FilterError (msg )
178+ return SQL ("{} > %s" ).format (field ), value
180179
181- return f"{ field } > %s" , value
182180
183-
184- def _greater_than_equal (field : str , value : Any ) -> tuple [str , Any ]:
181+ def _greater_than_equal (field : Composable , value : Any ) -> tuple [Composed , Any ]:
185182 if isinstance (value , str ):
186183 try :
187184 datetime .fromisoformat (value )
@@ -194,11 +191,10 @@ def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]:
194191 if type (value ) in [list , Jsonb ]:
195192 msg = f"Filter value can't be of type { type (value )} using operators '>', '>=', '<', '<='"
196193 raise FilterError (msg )
197-
198- return f"{ field } >= %s" , value
194+ return SQL ("{} >= %s" ).format (field ), value
199195
200196
201- def _less_than (field : str , value : Any ) -> tuple [str , Any ]:
197+ def _less_than (field : Composable , value : Any ) -> tuple [Composed , Any ]:
202198 if isinstance (value , str ):
203199 try :
204200 datetime .fromisoformat (value )
@@ -211,11 +207,10 @@ def _less_than(field: str, value: Any) -> tuple[str, Any]:
211207 if type (value ) in [list , Jsonb ]:
212208 msg = f"Filter value can't be of type { type (value )} using operators '>', '>=', '<', '<='"
213209 raise FilterError (msg )
210+ return SQL ("{} < %s" ).format (field ), value
214211
215- return f"{ field } < %s" , value
216212
217-
218- def _less_than_equal (field : str , value : Any ) -> tuple [str , Any ]:
213+ def _less_than_equal (field : Composable , value : Any ) -> tuple [Composed , Any ]:
219214 if isinstance (value , str ):
220215 try :
221216 datetime .fromisoformat (value )
@@ -228,39 +223,37 @@ def _less_than_equal(field: str, value: Any) -> tuple[str, Any]:
228223 if type (value ) in [list , Jsonb ]:
229224 msg = f"Filter value can't be of type { type (value )} using operators '>', '>=', '<', '<='"
230225 raise FilterError (msg )
231-
232- return f"{ field } <= %s" , value
226+ return SQL ("{} <= %s" ).format (field ), value
233227
234228
def _not_in(field: Composable, value: Any) -> tuple[Composed, list]:
    """
    Build the SQL fragment for the 'not in' comparison.

    A row matches when the field is NULL or when it differs from every element
    of `value`.

    :param field: SQL expression for the column or meta key being compared.
    :param value: List of values the field must not be equal to.
    :returns: The composed SQL fragment and a one-element list wrapping `value`,
        so the whole list is bound to the single `%s` placeholder.
    :raises FilterError: If `value` is not a list.
    """
    if not isinstance(value, list):
        msg = f"{field}'s value must be a list when using 'not in' comparator"
        raise FilterError(msg)

    # The fragment contains a bare OR, so it must be parenthesised: callers join
    # fragments with AND, which binds tighter than OR and would otherwise
    # regroup the expression as "(... AND field IS NULL) OR field != ALL(...)".
    return SQL("({} IS NULL OR {} != ALL(%s))").format(field, field), [value]
241234
242235
def _in(field: Composable, value: Any) -> tuple[Composed, list]:
    """
    Build the SQL fragment for the 'in' comparison.

    :param field: SQL expression for the column or meta key being compared.
    :param value: List of values the field may be equal to.
    :returns: The composed SQL fragment and a one-element list wrapping `value`,
        so the whole list is bound to the single `%s` placeholder.
    :raises FilterError: If `value` is not a list.
    """
    if not isinstance(value, list):
        msg = f"{field}'s value must be a list when using 'in' comparator"
        raise FilterError(msg)

    # The Python list is passed as one parameter and compared with ANY; see
    # https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation
    return SQL("{} = ANY(%s)").format(field), [value]
250243
251244
def _like(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build the SQL fragment for the LIKE comparison.

    :param field: SQL expression for the column or meta key being compared.
    :param value: Pattern string bound to the `%s` placeholder.
    :returns: The composed SQL fragment and the pattern for its placeholder.
    :raises FilterError: If `value` is not a string.
    """
    if isinstance(value, str):
        pattern_clause = SQL("{} LIKE %s").format(field)
        return pattern_clause, value

    msg = f"{field}'s value must be a str when using 'LIKE' "
    raise FilterError(msg)
257250
258251
def _not_like(field: Composable, value: Any) -> tuple[Composed, Any]:
    """
    Build the SQL fragment for the NOT LIKE comparison.

    :param field: SQL expression for the column or meta key being compared.
    :param value: Pattern string bound to the `%s` placeholder.
    :returns: The composed SQL fragment and the pattern for its placeholder.
    :raises FilterError: If `value` is not a string.
    """
    if not isinstance(value, str):
        # Name the operator actually in use so the error is not mistaken for a
        # failure of the plain LIKE comparison.
        msg = f"{field}'s value must be a str when using 'NOT LIKE' "
        raise FilterError(msg)
    return SQL("{} NOT LIKE %s").format(field), value
264257
265258
266259COMPARISON_OPERATORS = {
0 commit comments