diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
index 6d097202f5..ab84bd82c7 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
@@ -1,3 +1,7 @@
+# SPDX-FileCopyrightText: 2023-present Anant Corporation
+#
+# SPDX-License-Identifier: Apache-2.0
+
 import json
 from typing import Any, Optional, Union
 from warnings import warn
@@ -356,3 +360,22 @@ def count_documents(self, upper_bound: int = 10000) -> int:
         :returns: the number of documents in the index
         """
         return self._astra_db_collection.count_documents({}, upper_bound=upper_bound)
+
+    def update(
+        self,
+        *,
+        filters: dict[str, Union[str, float, int, bool, list, dict]],
+        update: dict[str, Any],
+    ) -> int:
+        """
+        Update multiple documents in the Astra index that match the filter.
+
+        :param filters: the filter to match documents to update
+        :param update: the update operations to apply (e.g., {"$set": {...}})
+
+        :returns:
+            The number of documents updated
+        """
+        update_result = self._astra_db_collection.update_many(filter=filters, update=update, upsert=False)
+
+        return update_result.update_info["nModified"]
diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
index d0565733d1..f3bc8b1f02 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present Anant Corporation
 #
 # SPDX-License-Identifier: Apache-2.0
+
 from typing import Any, Optional, Union
 
 from haystack import default_from_dict, default_to_dict, logging
@@ -400,7 +401,6 @@ def delete_documents(self, document_ids: list[str]) -> None:
         Deletes documents from the document store.
 
         :param document_ids: IDs of the documents to delete.
-        :param delete_all: if `True`, delete all documents.
         :raises MissingDocumentError: if no document was deleted but document IDs were provided.
         """
         if self.index.find_one_document({"filter": {}}) is not None:
@@ -420,7 +420,6 @@ def delete_all_documents(self) -> None:
         """
         Deletes all documents from the document store.
         """
-
         deletion_counter = 0
         try:
             deletion_counter = self.index.delete_all_documents()
@@ -432,3 +431,63 @@ def delete_all_documents(self) -> None:
             logger.info("All documents deleted")
         else:
             logger.error("Could not delete all documents")
+
+    def delete_by_filter(self, filters: dict[str, Any]) -> int:
+        """
+        Deletes documents that match the provided filters.
+
+        :param filters: The filters to apply to find documents to delete.
+        :returns: The number of documents deleted.
+        :raises AstraDocumentStoreFilterError: if the filter is invalid or not supported.
+        """
+        if not isinstance(filters, dict):
+            msg = "Filters must be a dictionary"
+            raise AstraDocumentStoreFilterError(msg)
+
+        if "id" in filters:
+            # rename on a shallow copy so the caller's filter dict is not mutated
+            filters = dict(filters)
+            filters["_id"] = filters.pop("id")
+
+        converted_filters = _convert_filters(filters)
+        deletion_count = self.index.delete(filters=converted_filters)
+
+        logger.info(f"{deletion_count} documents deleted by filter")
+        return deletion_count
+
+    def update_by_filter(self, filters: dict[str, Any], meta: dict[str, Any]) -> int:
+        """
+        Updates documents that match the provided filters with the given metadata.
+
+        :param filters: The filters to apply to find documents to update.
+        :param meta: The metadata fields to update. This will be merged with existing metadata.
+
+        :returns:
+            The number of documents updated.
+
+        :raises AstraDocumentStoreFilterError: if `filters` or `meta` is not a dictionary, or if the filter is
+            invalid or not supported.
+        """
+        if not isinstance(filters, dict):
+            msg = "Filters must be a dictionary"
+            raise AstraDocumentStoreFilterError(msg)
+
+        if not isinstance(meta, dict):
+            msg = "Meta must be a dictionary"
+            raise AstraDocumentStoreFilterError(msg)
+
+        if "id" in filters:
+            # rename on a shallow copy so the caller's filter dict is not mutated
+            filters = dict(filters)
+            filters["_id"] = filters.pop("id")
+
+        converted_filters = _convert_filters(filters)
+
+        # use dot notation to update nested fields in the meta-object - ensures fields are created if they don't exist
+        update_fields = {f"meta.{key}": value for key, value in meta.items()}
+        update_operation = {"$set": update_fields}
+        update_count = self.index.update(filters=converted_filters, update=update_operation)  # type: ignore
+
+        logger.info(f"{update_count} documents updated by filter")
+
+        return update_count
diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py
index 493f629177..1150f1b4c3 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present Anant Corporation
 #
 # SPDX-License-Identifier: Apache-2.0
+
 from haystack.document_stores.errors import DocumentStoreError
 from haystack.errors import FilterError
 
diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py
index 81dc97f528..7905cca813 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py
@@ -1,3 +1,7 @@
+# SPDX-FileCopyrightText: 2023-present Anant Corporation
+#
+# SPDX-License-Identifier: Apache-2.0
+
 from typing import Any, Optional
 
 from haystack.errors import FilterError
diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py
index 69b6d8d50f..3d2695d543 100644
--- a/integrations/astra/tests/test_document_store.py
+++ b/integrations/astra/tests/test_document_store.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2023-present Anant Corporation
 #
 # SPDX-License-Identifier: Apache-2.0
+
 import operator
 import os
 from unittest import mock
@@ -210,6 +211,69 @@ def test_delete_all_documents(self, document_store: AstraDocumentStore):
         document_store.delete_all_documents()
         assert document_store.count_documents() == 0
 
+    def test_delete_by_filter(self, document_store: AstraDocumentStore, filterable_docs):
+        """Delete-by-filter removes only the matching documents and returns their count."""
+        document_store.write_documents(filterable_docs)
+        initial_count = document_store.count_documents()
+        assert initial_count > 0
+
+        # count documents that match the filter before deletion
+        matching_docs = [d for d in filterable_docs if d.meta.get("chapter") == "intro"]
+        expected_deleted_count = len(matching_docs)
+
+        # delete all documents with chapter="intro"
+        deleted_count = document_store.delete_by_filter(
+            filters={"field": "meta.chapter", "operator": "==", "value": "intro"}
+        )
+
+        assert deleted_count == expected_deleted_count
+        assert document_store.count_documents() == initial_count - deleted_count
+
+        # remaining documents don't have chapter="intro"
+        remaining_docs = document_store.filter_documents()
+        for doc in remaining_docs:
+            assert doc.meta.get("chapter") != "intro"
+
+        # all documents with chapter="intro" were deleted
+        intro_docs = document_store.filter_documents(
+            filters={"field": "meta.chapter", "operator": "==", "value": "intro"}
+        )
+        assert len(intro_docs) == 0
+
+    def test_update_by_filter(self, document_store: AstraDocumentStore, filterable_docs):
+        """Update-by-filter changes metadata only on matching documents and returns their count."""
+        document_store.write_documents(filterable_docs)
+        initial_count = document_store.count_documents()
+        assert initial_count > 0
+
+        # count documents that match the filter before update
+        matching_docs = [d for d in filterable_docs if d.meta.get("chapter") == "intro"]
+        expected_updated_count = len(matching_docs)
+
+        # update all documents with chapter="intro" to have status="updated"
+        updated_count = document_store.update_by_filter(
+            filters={"field": "meta.chapter", "operator": "==", "value": "intro"},
+            meta={"status": "updated"},
+        )
+
+        assert updated_count == expected_updated_count
+        assert document_store.count_documents() == initial_count
+
+        # verify the updated documents have the new metadata
+        updated_docs = document_store.filter_documents(
+            filters={"field": "meta.status", "operator": "==", "value": "updated"}
+        )
+        assert len(updated_docs) == expected_updated_count
+        for doc in updated_docs:
+            assert doc.meta.get("chapter") == "intro"
+            assert doc.meta.get("status") == "updated"
+
+        # verify other documents weren't affected
+        all_docs = document_store.filter_documents()
+        for doc in all_docs:
+            if doc.meta.get("chapter") != "intro":
+                assert doc.meta.get("status") != "updated"
+
     @pytest.mark.skip(reason="Unsupported filter operator not.")
     def test_not_operator(self, document_store, filterable_docs):
         pass