diff --git a/integrations/falkordb/src/haystack_integrations/document_stores/falkordb/document_store.py b/integrations/falkordb/src/haystack_integrations/document_stores/falkordb/document_store.py
index 49ebeec193..e35abe98df 100644
--- a/integrations/falkordb/src/haystack_integrations/document_stores/falkordb/document_store.py
+++ b/integrations/falkordb/src/haystack_integrations/document_stores/falkordb/document_store.py
@@ -449,6 +449,238 @@ def delete_documents(self, document_ids: list[str]) -> None:
             {"ids": document_ids},
         )
 
+    # Node properties that store Document fields rather than user metadata.
+    _STANDARD_FIELDS = frozenset({"id", "content", "embedding", "score", "sparse_embedding"})
+
+    @staticmethod
+    def _strip_meta_prefix(field: str) -> str:
+        """
+        Return the field name without a leading `meta.` prefix.
+        """
+        return field.removeprefix("meta.")
+
+    @staticmethod
+    def _property_ref(field: str) -> str:
+        """
+        Build a Cypher reference to a property of the matched node `d`.
+
+        Property keys cannot be passed as query parameters, so the name is spliced into the query
+        text. Backtick-quoting keeps it a single identifier: names containing spaces, dashes or
+        dots work, and a crafted name cannot inject Cypher.
+
+        :param field: Property name without the `meta.` prefix.
+        :returns: Quoted property reference, for example ``d.`year```.
+        :raises ValueError: If the name is empty or contains a backtick.
+        """
+        if not field or "`" in field:
+            msg = f"Invalid metadata field name: {field!r}"
+            raise ValueError(msg)
+        return f"d.`{field}`"
+
+    def delete_all_documents(self) -> None:
+        """
+        Delete all documents from the graph.
+
+        Every node carrying the store's node label is removed together with its relationships.
+        """
+        self._ensure_connected()
+        self.graph.query(f"MATCH (d:{self.node_label}) DETACH DELETE d")
+
+    def delete_by_filter(self, filters: dict[str, Any]) -> int:
+        """
+        Delete all documents that match the provided filters.
+
+        :param filters: Haystack filter dict.
+        :returns: Number of documents deleted.
+        """
+        self._ensure_connected()
+        where_clause, params = _convert_filters(filters)
+        # The matches are counted first and deleted by a second query, so the returned number can
+        # differ from the number actually deleted if the graph is modified in between.
+        count_result = self.graph.query(
+            f"MATCH (d:{self.node_label}) WHERE {where_clause} RETURN count(d) AS n",
+            params,
+        )
+        count = int(count_result.result_set[0][0]) if count_result.result_set else 0
+        self.graph.query(
+            f"MATCH (d:{self.node_label}) WHERE {where_clause} DETACH DELETE d",
+            params,
+        )
+        return count
+
+    def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
+        """
+        Update metadata fields on all documents that match the provided filters.
+
+        :param filters: Haystack filter dict selecting which documents to update.
+        :param meta: Metadata fields to set. Keys may include or omit the `meta.` prefix.
+        :returns: Number of documents updated.
+        :raises ValueError: If `meta` would overwrite a standard document field such as `id`.
+        """
+        self._ensure_connected()
+        flat_meta = {self._strip_meta_prefix(key): value for key, value in meta.items()}
+        # `SET d += ...` writes straight onto the node, so a key like `id` or `content` would
+        # silently replace the document field of the same name.
+        reserved = self._STANDARD_FIELDS.intersection(flat_meta)
+        if reserved:
+            msg = f"Cannot update standard document fields through metadata: {sorted(reserved)}"
+            raise ValueError(msg)
+        where_clause, params = _convert_filters(filters)
+        params["meta_update"] = flat_meta
+        result = self.graph.query(
+            f"MATCH (d:{self.node_label}) WHERE {where_clause} SET d += $meta_update RETURN count(d) AS n",
+            params,
+        )
+        rows = result.result_set
+        return int(rows[0][0]) if rows else 0
+
+    def count_documents_by_filter(self, filters: dict[str, Any]) -> int:
+        """
+        Return the number of documents that match the provided filters.
+
+        :param filters: Haystack filter dict.
+        :returns: Integer count of matching document nodes.
+        """
+        self._ensure_connected()
+        where_clause, params = _convert_filters(filters)
+        result = self.graph.query(
+            f"MATCH (d:{self.node_label}) WHERE {where_clause} RETURN count(d) AS n",
+            params,
+        )
+        rows = result.result_set
+        return int(rows[0][0]) if rows else 0
+
+    def count_unique_metadata_by_filter(self, filters: dict[str, Any], metadata_fields: list[str]) -> dict[str, int]:
+        """
+        Return the number of unique values for each metadata field among matching documents.
+
+        :param filters: Haystack filter dict. Pass an empty dict to count across all documents.
+        :param metadata_fields: List of metadata field names. May include or omit the `meta.` prefix.
+        :returns: Dict mapping each field name (without `meta.` prefix) to its unique value count.
+        :raises ValueError: If a field name is empty or contains a backtick.
+        """
+        self._ensure_connected()
+        if filters:
+            where_clause, params = _convert_filters(filters)
+            match = f"MATCH (d:{self.node_label}) WHERE {where_clause}"
+        else:
+            params = {}
+            match = f"MATCH (d:{self.node_label})"
+
+        result: dict[str, int] = {}
+        for field in metadata_fields:
+            actual = self._strip_meta_prefix(field)
+            res = self.graph.query(
+                f"{match} RETURN count(DISTINCT {self._property_ref(actual)}) AS n",
+                params,
+            )
+            rows = res.result_set
+            result[actual] = int(rows[0][0]) if rows else 0
+        return result
+
+    def get_metadata_fields_info(self) -> dict[str, dict[str, str]]:
+        """
+        Return type information for each metadata field present on document nodes.
+
+        The type of a field is taken from the first non-null value the graph returns for it.
+
+        :returns: Dict mapping field names to a `{"type": <type_name>}` dict.
+            Type names are `"str"`, `"int"`, `"float"`, or `"bool"`.
+        """
+        self._ensure_connected()
+        result = self.graph.query(f"MATCH (d:{self.node_label}) RETURN keys(d)")
+        all_keys: set[str] = set()
+        for row in result.result_set:
+            all_keys.update(row[0])
+        all_keys -= self._STANDARD_FIELDS
+
+        info: dict[str, dict[str, str]] = {}
+        for key in sorted(all_keys):
+            prop = self._property_ref(key)
+            res = self.graph.query(f"MATCH (d:{self.node_label}) WHERE {prop} IS NOT NULL RETURN {prop} LIMIT 1")
+            if not res.result_set:
+                continue
+            val = res.result_set[0][0]
+            # bool is checked before int because bool is a subclass of int in Python.
+            if isinstance(val, bool):
+                type_name = "bool"
+            elif isinstance(val, int):
+                type_name = "int"
+            elif isinstance(val, float):
+                type_name = "float"
+            else:
+                type_name = "str"
+            info[key] = {"type": type_name}
+        return info
+
+    def get_metadata_field_min_max(self, metadata_field: str) -> dict[str, Any]:
+        """
+        Return the minimum and maximum values for the given metadata field.
+
+        :param metadata_field: Metadata field name. May include or omit the `meta.` prefix.
+        :returns: Dict with keys `"min"` and `"max"`. Values are `None` when no documents
+            have a non-null value for the field.
+        :raises ValueError: If the field name is empty or contains a backtick.
+        """
+        self._ensure_connected()
+        prop = self._property_ref(self._strip_meta_prefix(metadata_field))
+        result = self.graph.query(
+            f"MATCH (d:{self.node_label}) WHERE {prop} IS NOT NULL RETURN min({prop}), max({prop})"
+        )
+        if not result.result_set:
+            return {"min": None, "max": None}
+        row = result.result_set[0]
+        return {"min": row[0], "max": row[1]}
+
+    def get_metadata_field_unique_values(
+        self,
+        metadata_field: str,
+        search_term: str | None = None,
+        size: int | None = 10000,
+        after: dict[str, Any] | None = None,
+    ) -> tuple[list[Any], dict[str, Any] | None]:
+        """
+        Return distinct values for the given metadata field with optional filtering and pagination.
+
+        :param metadata_field: Metadata field name. May include or omit the `meta.` prefix.
+        :param search_term: Optional substring filter applied to string field values.
+        :param size: Maximum number of values to return per page. Defaults to 10 000.
+        :param after: Pagination cursor returned by a previous call. Pass `None` for the first page.
+        :returns: Tuple of `(values, next_cursor)`. `next_cursor` is `None` on the last page.
+        :raises ValueError: If the field name is invalid, or if `size` or the cursor offset is negative.
+        """
+        self._ensure_connected()
+        prop = self._property_ref(self._strip_meta_prefix(metadata_field))
+        # SKIP and LIMIT are spliced into the query text, so force both to plain integers first.
+        offset = int(after.get("offset", 0)) if after else 0
+        limit = int(size) if size is not None else 10000
+        if offset < 0 or limit < 0:
+            msg = f"Pagination offset and size must not be negative (offset={offset}, size={limit})"
+            raise ValueError(msg)
+        if limit == 0:
+            # A zero-sized page could never advance the cursor, so report the end of pagination instead.
+            return [], None
+
+        query_params: dict[str, Any] = {}
+        where_parts = [f"{prop} IS NOT NULL"]
+        if search_term:
+            where_parts.append(f"toString({prop}) CONTAINS $search_term")
+            query_params["search_term"] = search_term
+
+        where = " AND ".join(where_parts)
+        # One row more than the page size is requested to learn whether another page exists.
+        cypher = (
+            f"MATCH (d:{self.node_label}) WHERE {where} "
+            f"RETURN DISTINCT {prop} AS val "
+            f"ORDER BY val "
+            f"SKIP {offset} LIMIT {limit + 1}"
+        )
+        result = self.graph.query(cypher, query_params)
+        rows = result.result_set
+        values = [row[0] for row in rows[:limit]]
+        next_cursor: dict[str, Any] | None = {"offset": offset + limit} if len(rows) > limit else None
+        return values, next_cursor
+
     # ------------------------------------------------------------------
     # Internal retrieval helpers (called by retriever components)
     # ------------------------------------------------------------------
diff --git a/integrations/falkordb/tests/test_document_store.py b/integrations/falkordb/tests/test_document_store.py
index a0168e8b42..dd4cee4fd6 100644
--- a/integrations/falkordb/tests/test_document_store.py
+++ b/integrations/falkordb/tests/test_document_store.py
@@ -13,7 +13,17 @@
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
 from haystack.errors import FilterError
-from haystack.testing.document_store import DocumentStoreBaseTests
+from haystack.testing.document_store import (
+    CountDocumentsByFilterTest,
+    CountUniqueMetadataByFilterTest,
+    DeleteAllTest,
+    DeleteByFilterTest,
+    DocumentStoreBaseTests,
+    GetMetadataFieldMinMaxTest,
+    GetMetadataFieldsInfoTest,
+    GetMetadataFieldUniqueValuesTest,
+    UpdateByFilterTest,
+)
 
 from haystack_integrations.components.retrievers.falkordb import (
     FalkorDBCypherRetriever,
@@ -267,9 +277,140 @@ def test_write_documents_wraps_errors(self, mock_falkordb):
         with pytest.raises(DocumentStoreError, match="Failed to write documents"):
             FalkorDBDocumentStore().write_documents([Document(id="a", content="x")], policy=DuplicatePolicy.OVERWRITE)
 
+    def test_delete_all_documents(self, mock_falkordb):
+        """The last query issued by `delete_all_documents` is a DETACH DELETE."""
+        _, _, graph = mock_falkordb
+        # NOTE(review): the leading `_result([])` entries in these tests are presumably consumed by
+        # queries the store issues while connecting; confirm against `_ensure_connected`.
+        graph.query.side_effect = [_result([]), _result([]), _result([])]
+        FalkorDBDocumentStore().delete_all_documents()
+        assert "DETACH DELETE" in graph.query.call_args_list[-1].args[0]
+
+    def test_delete_by_filter_returns_count(self, mock_falkordb):
+        """`delete_by_filter` returns the number from the count query and then issues the delete."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[3]]), _result([])]
+        count = FalkorDBDocumentStore().delete_by_filter({"field": "year", "operator": "==", "value": 2024})
+        assert count == 3
+        assert "DETACH DELETE" in graph.query.call_args_list[-1].args[0]
+
+    def test_delete_by_filter_empty_result(self, mock_falkordb):
+        """`delete_by_filter` returns 0 when the count query yields no rows."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([]), _result([])]
+        assert FalkorDBDocumentStore().delete_by_filter({"field": "year", "operator": "==", "value": 2024}) == 0
+
+    def test_update_by_filter_returns_count(self, mock_falkordb):
+        """`update_by_filter` returns the count from the query result and uses `SET d +=`."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[2]])]
+        count = FalkorDBDocumentStore().update_by_filter(
+            {"field": "year", "operator": "==", "value": 2024}, {"status": "published"}
+        )
+        assert count == 2
+        assert "SET d +=" in graph.query.call_args_list[-1].args[0]
+
+    def test_update_by_filter_strips_meta_prefix(self, mock_falkordb):
+        """The `meta.` prefix is removed from keys before they are sent as the `meta_update` parameter."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[1]])]
+        FalkorDBDocumentStore().update_by_filter(
+            {"field": "year", "operator": "==", "value": 2024}, {"meta.status": "published"}
+        )
+        assert graph.query.call_args_list[-1].args[1]["meta_update"] == {"status": "published"}
+
+    def test_update_by_filter_rejects_standard_fields(self, mock_falkordb):
+        """Metadata updates must not overwrite standard document fields such as `id`."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[1]])]
+        with pytest.raises(ValueError, match="standard document fields"):
+            FalkorDBDocumentStore().update_by_filter(
+                {"field": "year", "operator": "==", "value": 2024}, {"meta.id": "other"}
+            )
+
+    @pytest.mark.parametrize("rows, expected", [([[5]], 5), ([], 0)])
+    def test_count_documents_by_filter(self, mock_falkordb, rows, expected):
+        """`count_documents_by_filter` returns the first cell of the result, or 0 for an empty result."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result(rows)]
+        count = FalkorDBDocumentStore().count_documents_by_filter({"field": "year", "operator": "==", "value": 2024})
+        assert count == expected
+
+    def test_count_unique_metadata_by_filter(self, mock_falkordb):
+        """Each requested field gets its own query result; counts are keyed by field name."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[3]]), _result([[2]])]
+        result = FalkorDBDocumentStore().count_unique_metadata_by_filter({}, ["category", "status"])
+        assert result == {"category": 3, "status": 2}
+
+    def test_get_metadata_fields_info(self, mock_falkordb):
+        """Field types are inferred from a sample value fetched for each key."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [
+            _result([]),
+            _result([]),
+            _result([[["category", "year"]]]),
+            _result([["A"]]),
+            _result([[2024]]),
+        ]
+        info = FalkorDBDocumentStore().get_metadata_fields_info()
+        assert info["category"] == {"type": "str"}
+        assert info["year"] == {"type": "int"}
+
+    @pytest.mark.parametrize(
+        "rows, expected",
+        [([[2020, 2024]], {"min": 2020, "max": 2024}), ([], {"min": None, "max": None})],
+    )
+    def test_get_metadata_field_min_max(self, mock_falkordb, rows, expected):
+        """The result row is mapped to `min`/`max`; an empty result gives `None` for both."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result(rows)]
+        assert FalkorDBDocumentStore().get_metadata_field_min_max("year") == expected
+
+    def test_get_metadata_field_min_max_quotes_field_name(self, mock_falkordb):
+        """Field names are backtick-quoted so they cannot change the structure of the query."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([[1, 2]])]
+        FalkorDBDocumentStore().get_metadata_field_min_max("meta.pub year")
+        assert "d.`pub year`" in graph.query.call_args_list[-1].args[0]
+
+    def test_get_metadata_field_unique_values(self, mock_falkordb):
+        """With no more rows than `size`, all values are returned and there is no next cursor."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([["A"], ["B"], ["C"]])]
+        values, cursor = FalkorDBDocumentStore().get_metadata_field_unique_values("category", size=10)
+        assert values == ["A", "B", "C"]
+        assert cursor is None
+
+    def test_get_metadata_field_unique_values_pagination(self, mock_falkordb):
+        """When the query returns more rows than `size`, the page is truncated and a cursor is returned."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([["A"], ["B"], ["C"]])]
+        values, cursor = FalkorDBDocumentStore().get_metadata_field_unique_values("category", size=2)
+        assert values == ["A", "B"]
+        assert cursor == {"offset": 2}
+
+    def test_get_metadata_field_unique_values_zero_size(self, mock_falkordb):
+        """A zero page size yields an empty last page instead of a cursor that never advances."""
+        _, _, graph = mock_falkordb
+        graph.query.side_effect = [_result([]), _result([]), _result([["A"]])]
+        values, cursor = FalkorDBDocumentStore().get_metadata_field_unique_values("category", size=0)
+        assert values == []
+        assert cursor is None
+
 
 @pytest.mark.integration
-class TestDocumentStore(DocumentStoreBaseTests):
+class TestDocumentStore(
+    DocumentStoreBaseTests,
+    DeleteAllTest,
+    DeleteByFilterTest,
+    UpdateByFilterTest,
+    CountDocumentsByFilterTest,
+    CountUniqueMetadataByFilterTest,
+    GetMetadataFieldsInfoTest,
+    GetMetadataFieldMinMaxTest,
+    GetMetadataFieldUniqueValuesTest,
+):
     """
     Test FalkorDBDocumentStore against the standard Haystack DocumentStore tests.
     """