Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,185 @@ def delete_documents(self, document_ids: list[str]) -> None:
{"ids": document_ids},
)

def delete_all_documents(self) -> None:
    """
    Remove every document node from the graph.

    Issues one Cypher statement that matches all nodes carrying the store's
    ``node_label`` and detach-deletes them.
    """
    self._ensure_connected()
    wipe_query = f"MATCH (d:{self.node_label}) DETACH DELETE d"
    self.graph.query(wipe_query)

def delete_by_filter(self, filters: dict[str, Any]) -> int:
    """
    Delete every document node selected by the given filters.

    The matching nodes are counted with one query and then removed with a second
    one; the returned number comes from the counting query.

    :param filters: Haystack filter dict.
    :returns: Number of documents deleted.
    """
    self._ensure_connected()
    condition, query_args = _convert_filters(filters)
    # Both statements share the same MATCH ... WHERE selection.
    selection = f"MATCH (d:{self.node_label}) WHERE {condition}"

    counted = self.graph.query(f"{selection} RETURN count(d) AS n", query_args)
    deleted = 0
    if counted.result_set:
        deleted = int(counted.result_set[0][0])

    self.graph.query(f"{selection} DETACH DELETE d", query_args)
    return deleted

def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
    """
    Set metadata fields on every document node selected by the given filters.

    :param filters: Haystack filter dict selecting which documents to update.
    :param meta: Metadata fields to set. A leading `meta.` on a key is stripped before
        the value is written, so keys may be given with or without that prefix.
    :returns: Number of documents updated.
    """
    self._ensure_connected()
    condition, query_args = _convert_filters(filters)

    prefix = "meta."
    updates: dict[str, Any] = {}
    for name, value in meta.items():
        bare_name = name[len(prefix):] if name.startswith(prefix) else name
        updates[bare_name] = value
    # The whole map is passed as one parameter and merged into the node with `+=`.
    query_args["meta_update"] = updates

    cypher = f"MATCH (d:{self.node_label}) WHERE {condition} SET d += $meta_update RETURN count(d) AS n"
    table = self.graph.query(cypher, query_args).result_set
    if not table:
        return 0
    return int(table[0][0])

def count_documents_by_filter(self, filters: dict[str, Any]) -> int:
    """
    Count the document nodes selected by the given filters.

    :param filters: Haystack filter dict.
    :returns: Integer count of matching document nodes; `0` when the query yields no rows.
    """
    self._ensure_connected()
    condition, query_args = _convert_filters(filters)
    cypher = f"MATCH (d:{self.node_label}) WHERE {condition} RETURN count(d) AS n"
    table = self.graph.query(cypher, query_args).result_set
    if not table:
        return 0
    return int(table[0][0])

def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]:
    """
    Return the number of unique values for each metadata field among matching documents.

    One `count(DISTINCT ...)` query is issued per requested field.

    :param filters: Haystack filter dict. Pass an empty dict to count across all documents.
    :param metadata_fields: List of metadata field names. May include or omit the `meta.` prefix.
    :returns: Dict mapping each field name (without `meta.` prefix) to its unique value count.
    :raises ValueError: If a field name is empty or contains a backtick and therefore
        cannot be quoted as a Cypher property name.
    """
    self._ensure_connected()
    if filters:
        where_clause, params = _convert_filters(filters)
        match = f"MATCH (d:{self.node_label}) WHERE {where_clause}"
    else:
        params = {}
        match = f"MATCH (d:{self.node_label})"

    # Validate every name before running any query so a bad name cannot leave
    # the call half-executed.
    names: list[str] = []
    for field in metadata_fields:
        actual = field[5:] if field.startswith("meta.") else field
        if not actual or "`" in actual:
            msg = f"Invalid metadata field name: {field!r}"
            raise ValueError(msg)
        names.append(actual)

    result: dict[str, int] = {}
    for actual in names:
        # Field names cannot be bound as query parameters, so the name is
        # backtick-quoted to keep it from being parsed as Cypher syntax.
        res = self.graph.query(
            f"{match} RETURN count(DISTINCT d.`{actual}`) AS n",
            params,
        )
        rows = res.result_set
        result[actual] = int(rows[0][0]) if rows else 0
    return result

def get_metadata_fields_info(self) -> dict[str, dict[str, str]]:
    """
    Return type information for each metadata field present on document nodes.

    Property keys are collected from all document nodes, the standard document
    fields are removed, and one sample non-null value per remaining key is
    fetched to decide the type name.

    :returns: Dict mapping field names to a `{"type": <typename>}` dict.
        Type names are `"str"`, `"int"`, `"float"`, or `"bool"`; any sampled value that
        is not a bool, int or float is reported as `"str"`.
    :raises ValueError: If a stored property key contains a backtick and therefore
        cannot be quoted as a Cypher property name.
    """
    self._ensure_connected()
    standard_fields = {"id", "content", "embedding", "score", "sparse_embedding"}
    result = self.graph.query(f"MATCH (d:{self.node_label}) RETURN keys(d)")
    all_keys: set[str] = set()
    for row in result.result_set:
        all_keys.update(row[0])
    all_keys -= standard_fields

    info: dict[str, dict[str, str]] = {}
    for key in sorted(all_keys):
        if "`" in key:
            msg = f"Invalid metadata field name: {key!r}"
            raise ValueError(msg)
        # Stored keys are not guaranteed to be bare identifiers (e.g. `my-field`),
        # so the property name is backtick-quoted instead of interpolated raw.
        prop = f"d.`{key}`"
        res = self.graph.query(
            f"MATCH (d:{self.node_label}) WHERE {prop} IS NOT NULL RETURN {prop} LIMIT 1"
        )
        if not res.result_set:
            continue
        val = res.result_set[0][0]
        # bool is tested first because bool is a subclass of int in Python.
        if isinstance(val, bool):
            type_name = "bool"
        elif isinstance(val, int):
            type_name = "int"
        elif isinstance(val, float):
            type_name = "float"
        else:
            type_name = "str"
        info[key] = {"type": type_name}
    return info

def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
    """
    Return the minimum and maximum values for the given metadata field.

    :param metadata_field: Metadata field name. May include or omit the `meta.` prefix.
    :returns: Dict with keys `"min"` and `"max"`. Both are `None` when the query yields
        no rows; otherwise they are the values reported by the graph's `min`/`max`
        aggregates over nodes whose field is not null.
    :raises ValueError: If the field name is empty or contains a backtick and therefore
        cannot be quoted as a Cypher property name.
    """
    self._ensure_connected()
    field = metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field
    if not field or "`" in field:
        msg = f"Invalid metadata field name: {metadata_field!r}"
        raise ValueError(msg)
    # Field names cannot be bound as query parameters, so the name is
    # backtick-quoted to keep it from being parsed as Cypher syntax.
    prop = f"d.`{field}`"
    result = self.graph.query(
        f"MATCH (d:{self.node_label}) WHERE {prop} IS NOT NULL RETURN min({prop}), max({prop})"
    )
    if not result.result_set:
        return {"min": None, "max": None}
    row = result.result_set[0]
    return {"min": row[0], "max": row[1]}

def get_metadata_field_unique_values(
self,
metadata_field: str,
search_term: str | None = None,
size: int | None = 10000,
after: dict[str, Any] | None = None,
) -> tuple[list[Any], dict[str, Any] | None]:
"""
Return distinct values for the given metadata field with optional filtering and pagination.

:param metadata_field: Metadata field name. May include or omit the `meta.` prefix.
:param search_term: Optional substring filter applied to string field values.
:param size: Maximum number of values to return per page. Defaults to 10 000.
:param after: Pagination cursor returned by a previous call. Pass `None` for the first page.
:returns: Tuple of `(values, next_cursor)`. `next_cursor` is `None` on the last page.
"""
self._ensure_connected()
field = metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field
offset = after.get("offset", 0) if after else 0
limit = size if size is not None else 10000

query_params: dict[str, Any] = {}
where_parts = [f"d.{field} IS NOT NULL"]
if search_term:
where_parts.append(f"toString(d.{field}) CONTAINS $search_term")
query_params["search_term"] = search_term

where = " AND ".join(where_parts)
cypher = (
f"MATCH (d:{self.node_label}) WHERE {where} "
f"RETURN DISTINCT d.{field} AS val "
f"ORDER BY val "
f"SKIP {offset} LIMIT {limit + 1}"
)
result = self.graph.query(cypher, query_params)
rows = result.result_set
values = [row[0] for row in rows[:limit]]
next_cursor: dict[str, Any] | None = {"offset": offset + limit} if len(rows) > limit else None
return values, next_cursor

# ------------------------------------------------------------------
# Internal retrieval helpers (called by retriever components)
# ------------------------------------------------------------------
Expand Down
108 changes: 106 additions & 2 deletions integrations/falkordb/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.errors import FilterError
from haystack.testing.document_store import DocumentStoreBaseTests
from haystack.testing.document_store import (
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
DeleteAllTest,
DeleteByFilterTest,
DocumentStoreBaseTests,
GetMetadataFieldMinMaxTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldUniqueValuesTest,
UpdateByFilterTest,
)

from haystack_integrations.components.retrievers.falkordb import (
FalkorDBCypherRetriever,
Expand Down Expand Up @@ -267,9 +277,103 @@ def test_write_documents_wraps_errors(self, mock_falkordb):
with pytest.raises(DocumentStoreError, match="Failed to write documents"):
FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE)

def test_delete_all_documents(self, mock_falkordb):
    """The last query issued by `delete_all_documents` must contain `DETACH DELETE`."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]) for _ in range(3)]
    store = FalkorDBDocumentStore()
    store.delete_all_documents()
    final_query = graph_mock.query.call_args_list[-1].args[0]
    assert "DETACH DELETE" in final_query

def test_delete_by_filter_returns_count(self, mock_falkordb):
    """`delete_by_filter` returns the mocked count and finishes with a `DETACH DELETE` query."""
    _client, _db, graph_mock = mock_falkordb
    # Queued results: two empty ones, the count row, then the delete result.
    graph_mock.query.side_effect = [_result([]), _result([]), _result([[3]]), _result([])]
    year_filter = {"field": "year", "operator": "==", "value": 2024}
    deleted = FalkorDBDocumentStore().delete_by_filter(year_filter)
    assert deleted == 3
    final_query = graph_mock.query.call_args_list[-1].args[0]
    assert "DETACH DELETE" in final_query

def test_delete_by_filter_empty_result(self, mock_falkordb):
    """`delete_by_filter` returns 0 when the count query yields no rows."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]) for _ in range(4)]
    year_filter = {"field": "year", "operator": "==", "value": 2024}
    deleted = FalkorDBDocumentStore().delete_by_filter(year_filter)
    assert deleted == 0

def test_update_by_filter_returns_count(self, mock_falkordb):
    """`update_by_filter` returns the mocked count and issues a `SET d +=` query last."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result([[2]])]
    year_filter = {"field": "year", "operator": "==", "value": 2024}
    new_meta = {"status": "published"}
    updated = FalkorDBDocumentStore().update_by_filter(year_filter, new_meta)
    assert updated == 2
    final_query = graph_mock.query.call_args_list[-1].args[0]
    assert "SET d +=" in final_query

def test_update_by_filter_strips_meta_prefix(self, mock_falkordb):
    """A `meta.`-prefixed key must reach the query parameters without the prefix."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result([[1]])]
    year_filter = {"field": "year", "operator": "==", "value": 2024}
    prefixed_meta = {"meta.status": "published"}
    FalkorDBDocumentStore().update_by_filter(year_filter, prefixed_meta)
    sent_params = graph_mock.query.call_args_list[-1].args[1]
    assert sent_params["meta_update"] == {"status": "published"}

@pytest.mark.parametrize("rows, expected", [([[5]], 5), ([], 0)])
def test_count_documents_by_filter(self, mock_falkordb, rows, expected):
    """`count_documents_by_filter` returns the counted value, or 0 for an empty result set."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result(rows)]
    year_filter = {"field": "year", "operator": "==", "value": 2024}
    counted = FalkorDBDocumentStore().count_documents_by_filter(year_filter)
    assert counted == expected

def test_count_unique_metadata_by_filter(self, mock_falkordb):
    """With empty filters, each requested field maps to the count from its own query."""
    _client, _db, graph_mock = mock_falkordb
    # One mocked count row per requested field, in request order.
    graph_mock.query.side_effect = [_result([]), _result([]), _result([[3]]), _result([[2]])]
    requested_fields = ["category", "status"]
    unique_counts = FalkorDBDocumentStore().count_unique_metadata_by_filter({}, requested_fields)
    assert unique_counts == {"category": 3, "status": 2}

def test_get_metadata_fields_info(self, mock_falkordb):
    """A string sample is reported as `str` and an integer sample as `int`."""
    _client, _db, graph_mock = mock_falkordb
    keys_row = _result([[["category", "year"]]])
    category_sample = _result([["A"]])
    year_sample = _result([[2024]])
    graph_mock.query.side_effect = [
        _result([]),
        _result([]),
        keys_row,
        category_sample,
        year_sample,
    ]
    fields_info = FalkorDBDocumentStore().get_metadata_fields_info()
    assert fields_info["category"] == {"type": "str"}
    assert fields_info["year"] == {"type": "int"}

@pytest.mark.parametrize(
    "rows, expected",
    [([[2020, 2024]], {"min": 2020, "max": 2024}), ([], {"min": None, "max": None})],
)
def test_get_metadata_field_min_max(self, mock_falkordb, rows, expected):
    """`get_metadata_field_min_max` maps the result row to min/max, or to `None`s when empty."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result(rows)]
    bounds = FalkorDBDocumentStore().get_metadata_field_min_max("year")
    assert bounds == expected

def test_get_metadata_field_unique_values(self, mock_falkordb):
    """When fewer rows than the page size come back, all values are returned and no cursor."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result([["A"], ["B"], ["C"]])]
    store = FalkorDBDocumentStore()
    page, next_cursor = store.get_metadata_field_unique_values("category", size=10)
    assert page == ["A", "B", "C"]
    assert next_cursor is None

def test_get_metadata_field_unique_values_pagination(self, mock_falkordb):
    """When more rows than the page size come back, the page is truncated and a cursor returned."""
    _client, _db, graph_mock = mock_falkordb
    graph_mock.query.side_effect = [_result([]), _result([]), _result([["A"], ["B"], ["C"]])]
    store = FalkorDBDocumentStore()
    page, next_cursor = store.get_metadata_field_unique_values("category", size=2)
    assert page == ["A", "B"]
    assert next_cursor == {"offset": 2}


@pytest.mark.integration
class TestDocumentStore(DocumentStoreBaseTests):
class TestDocumentStore(
DocumentStoreBaseTests,
DeleteAllTest,
DeleteByFilterTest,
UpdateByFilterTest,
CountDocumentsByFilterTest,
CountUniqueMetadataByFilterTest,
GetMetadataFieldsInfoTest,
GetMetadataFieldMinMaxTest,
GetMetadataFieldUniqueValuesTest,
):
"""
Test FalkorDBDocumentStore against the standard Haystack DocumentStore tests.
"""
Expand Down
Loading