Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Convert Haystack filter dictionaries to ArcadeDB SQL WHERE clauses."""

import re
from typing import Any


Expand Down Expand Up @@ -64,6 +65,14 @@ def _parse_condition(condition: dict[str, Any]) -> str:


def _comparison_to_sql(field: str, operator: str, value: Any) -> str:

if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_.\[\]"]*$', field):
msg = (
f"Invalid field name: {field}. Field names must start with a letter or underscore and can contain "
f"letters, digits, underscores, dots, brackets, and quotes."
)
raise ValueError(msg)

if operator == "==":
if value is None:
return f"{field} IS NULL"
Expand Down Expand Up @@ -106,7 +115,7 @@ def _comparison_to_sql(field: str, operator: str, value: Any) -> str:
def _sql_value(value: Any) -> str:
"""Format a Python value as an ArcadeDB SQL literal."""
if isinstance(value, str):
escaped = value.replace("'", "\\'")
escaped = value.replace("\\", "\\\\").replace("'", "\\'")
return f"'{escaped}'"
if isinstance(value, bool):
return "true" if value else "false"
Expand Down
30 changes: 30 additions & 0 deletions integrations/arcadedb/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,33 @@ def test_conversion_edge_cases(self, filter_dict, expected):
def test_invalid_filter_raises(self, filter_dict):
with pytest.raises(ValueError):
_convert_filters(filter_dict)

@pytest.mark.parametrize(
"field",
[
"x; DROP TABLE Documents",
"x OR 1=1",
"x--",
"x; SELECT *",
"'injected'",
"1field",
"field name",
],
)
def test_sql_injection_field_names_raise(self, field):
with pytest.raises(ValueError, match="Invalid field name"):
_convert_filters({"field": field, "operator": "==", "value": "v"})

def test_value_with_backslash(self):
# A single backslash must be doubled: \ → \\
result = _convert_filters({"field": "meta.x", "operator": "==", "value": "\\"})
assert result == "meta.x = '\\\\'"

def test_value_with_backslash_then_quote(self):
# \' in value → \\ (escaped backslash) + \' (escaped quote) in SQL
result = _convert_filters({"field": "meta.x", "operator": "==", "value": "a\\'b"})
assert result == "meta.x = 'a\\\\\\'b'"

def test_value_with_single_quote(self):
result = _convert_filters({"field": "meta.x", "operator": "==", "value": "it's"})
assert result == "meta.x = 'it\\'s'"
Loading