diff --git a/test/unit/connectors/test_mongodb_errors.py b/test/unit/connectors/test_mongodb_errors.py
new file mode 100644
index 000000000..801611f69
--- /dev/null
+++ b/test/unit/connectors/test_mongodb_errors.py
@@ -0,0 +1,148 @@
+"""Unit tests for how MongoDBUploader translates pymongo errors into ingest errors."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pymongo.errors import (
+    AutoReconnect,
+    BulkWriteError,
+    OperationFailure,
+    ServerSelectionTimeoutError,
+)
+
+from unstructured_ingest.error import (
+    DestinationConnectionError,
+    QuotaError,
+    TimeoutError,
+    WriteError,
+)
+from unstructured_ingest.processes.connectors.mongodb import (
+    MongoDBConnectionConfig,
+    MongoDBUploader,
+    MongoDBUploaderConfig,
+)
+
+
+def _make_uploader():
+    """Build a MongoDBUploader whose connection and upload configs are mocks."""
+    connection_config = MagicMock(spec=MongoDBConnectionConfig)
+    connection_config.host = "test_host"
+    upload_config = MagicMock(spec=MongoDBUploaderConfig)
+    upload_config.record_id_key = "record_id"
+    upload_config.database = "test_db"
+    upload_config.collection = "test_collection"
+    upload_config.batch_size = 100
+    return MongoDBUploader(
+        connection_config=connection_config,
+        upload_config=upload_config,
+    )
+
+
+def _mock_client(uploader, collection_side_effect=None):
+    """Wire a mock client into ``uploader`` and return its target collection mock.
+
+    ``collection_side_effect``, when given, is raised by ``insert_many``.
+    """
+    mock_client = MagicMock()
+    uploader.connection_config.get_client.return_value.__enter__ = MagicMock(
+        return_value=mock_client
+    )
+    uploader.connection_config.get_client.return_value.__exit__ = MagicMock(return_value=False)
+    mock_collection = mock_client.__getitem__("test_db").__getitem__("test_collection")
+    # Make can_delete return False so we skip delete_by_record_id and go straight to insert
+    mock_collection.list_indexes.return_value = []
+    if collection_side_effect is not None:
+        mock_collection.insert_many.side_effect = collection_side_effect
+    return mock_collection
+
+
+class TestRunDataErrorHandling:
+    def test_operation_failure_quota_raises_quota_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        _mock_client(uploader, OperationFailure("quota exceeded for writes"))
+
+        with pytest.raises(QuotaError):
+            uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+    def test_operation_failure_other_raises_destination_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        _mock_client(uploader, OperationFailure("some other failure"))
+
+        with pytest.raises(DestinationConnectionError):
+            uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+    def test_server_selection_timeout_raises_timeout_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        _mock_client(uploader, ServerSelectionTimeoutError("timeout"))
+
+        with pytest.raises(TimeoutError):
+            uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+    def test_bulk_write_error_raises_write_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        _mock_client(uploader, BulkWriteError({"writeErrors": [{"errmsg": "fail"}]}))
+
+        with pytest.raises(WriteError):
+            uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+    def test_auto_reconnect_raises_destination_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        _mock_client(uploader, AutoReconnect("connection lost"))
+
+        with pytest.raises(DestinationConnectionError):
+            uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+    def test_translated_error_from_delete_is_not_rewrapped(self):
+        """A QuotaError raised during the delete step must reach the caller unchanged."""
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        mock_collection = _mock_client(uploader)
+        mock_collection.delete_many.side_effect = OperationFailure("quota exceeded")
+
+        # Force the delete path; can_delete's index inspection is not under test here.
+        with patch.object(MongoDBUploader, "can_delete", return_value=True):
+            with pytest.raises(QuotaError):
+                uploader.run_data(data=[{"key": "value"}], file_data=file_data)
+
+
+class TestDeleteByRecordIdErrorHandling:
+    def test_operation_failure_quota_raises_quota_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        collection = MagicMock()
+        collection.delete_many.side_effect = OperationFailure("quota exceeded")
+
+        with pytest.raises(QuotaError):
+            uploader.delete_by_record_id(collection=collection, file_data=file_data)
+
+    def test_server_selection_timeout_raises_timeout_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        collection = MagicMock()
+        collection.delete_many.side_effect = ServerSelectionTimeoutError("timeout")
+
+        with pytest.raises(TimeoutError):
+            uploader.delete_by_record_id(collection=collection, file_data=file_data)
+
+    def test_auto_reconnect_raises_destination_error(self):
+        uploader = _make_uploader()
+        file_data = MagicMock()
+        file_data.identifier = "test_id"
+        collection = MagicMock()
+        collection.delete_many.side_effect = AutoReconnect("connection lost")
+
+        with pytest.raises(DestinationConnectionError):
+            uploader.delete_by_record_id(collection=collection, file_data=file_data)
diff --git a/unstructured_ingest/processes/connectors/mongodb.py b/unstructured_ingest/processes/connectors/mongodb.py
index 7e02e8ac2..3abfbef1d 100644
--- a/unstructured_ingest/processes/connectors/mongodb.py
+++ b/unstructured_ingest/processes/connectors/mongodb.py
@@ -17,8 +17,11 @@
 from unstructured_ingest.error import (
     ConnectionError,
     DestinationConnectionError,
+    QuotaError,
     SourceConnectionError,
+    TimeoutError,
     ValueError,
+    WriteError,
 )
 from unstructured_ingest.interfaces import (
     AccessConfig,
@@ -341,18 +344,61 @@ def can_delete(self, collection: "Collection") -> bool:
         return self.upload_config.record_id_key in indexed_keys
 
     def delete_by_record_id(self, collection: "Collection", file_data: FileData) -> None:
+        """Delete every document whose record id matches ``file_data.identifier``.
+
+        Raises:
+            QuotaError: An ``OperationFailure`` whose message mentions "quota".
+            DestinationConnectionError: Any other ``OperationFailure``, or an
+                ``AutoReconnect``.
+            TimeoutError: A ``ServerSelectionTimeoutError``.
+        """
+        from pymongo.errors import (
+            AutoReconnect,
+            OperationFailure,
+            ServerSelectionTimeoutError,
+        )
+
         logger.debug(
             f"deleting any content with metadata "
             f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from collection: {collection.name}"
         )
         query = {self.upload_config.record_id_key: file_data.identifier}
-        delete_results = collection.delete_many(filter=query)
+        try:
+            delete_results = collection.delete_many(filter=query)
+        except OperationFailure as e:
+            if "quota" in str(e).lower():
+                raise QuotaError(f"MongoDB quota exceeded: {e}") from e
+            raise DestinationConnectionError(f"MongoDB operation failed: {e}") from e
+        except ServerSelectionTimeoutError as e:
+            # Must precede AutoReconnect: ServerSelectionTimeoutError is a subclass of it.
+            raise TimeoutError(f"MongoDB server unreachable: {e}") from e
+        except AutoReconnect as e:
+            raise DestinationConnectionError(f"MongoDB connection lost: {e}") from e
         logger.info(
             f"deleted {delete_results.deleted_count} records from collection {collection.name}"
         )
 
     def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
+        """Insert ``data`` into the configured collection in batches.
+
+        Each element is stamped with ``file_data.identifier`` under the record id
+        key. Existing documents for that identifier are deleted first when
+        ``can_delete`` allows it.
+
+        Raises:
+            WriteError: A ``BulkWriteError``.
+            QuotaError: An ``OperationFailure`` whose message mentions "quota".
+            TimeoutError: A ``ServerSelectionTimeoutError``.
+            DestinationConnectionError: Any other failure.
+        """
+        from pymongo.errors import (
+            AutoReconnect,
+            BulkWriteError,
+            OperationFailure,
+            ServerSelectionTimeoutError,
+        )
+
         logger.info(
             f"writing {len(data)} objects to destination "
             f"db, {self.upload_config.database}, "
@@ -363,15 +409,34 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
         # is done, setting the record id field in the uploader
         for element in data:
             element[self.upload_config.record_id_key] = file_data.identifier
-        with self.connection_config.get_client() as client:
-            db = client[self.upload_config.database]
-            collection = db[self.upload_config.collection]
-            if self.can_delete(collection=collection):
-                self.delete_by_record_id(file_data=file_data, collection=collection)
-            else:
-                logger.warning("criteria for deleting previous content not met, skipping")
-            for chunk in batch_generator(data, self.upload_config.batch_size):
-                collection.insert_many(chunk)
+        try:
+            with self.connection_config.get_client() as client:
+                db = client[self.upload_config.database]
+                collection = db[self.upload_config.collection]
+                if self.can_delete(collection=collection):
+                    self.delete_by_record_id(file_data=file_data, collection=collection)
+                else:
+                    logger.warning("criteria for deleting previous content not met, skipping")
+                for chunk in batch_generator(data, self.upload_config.batch_size):
+                    collection.insert_many(chunk)
+        except (DestinationConnectionError, QuotaError, TimeoutError, WriteError):
+            # Already translated (e.g. by delete_by_record_id); re-raise unchanged so the
+            # catch-all below does not downgrade it to DestinationConnectionError.
+            raise
+        except BulkWriteError as e:
+            # Must precede OperationFailure: BulkWriteError is a subclass of it.
+            raise WriteError(f"MongoDB bulk write failed: {e}") from e
+        except OperationFailure as e:
+            if "quota" in str(e).lower():
+                raise QuotaError(f"MongoDB quota exceeded: {e}") from e
+            raise DestinationConnectionError(f"MongoDB operation failed: {e}") from e
+        except ServerSelectionTimeoutError as e:
+            # Must precede AutoReconnect: ServerSelectionTimeoutError is a subclass of it.
+            raise TimeoutError(f"MongoDB server unreachable: {e}") from e
+        except AutoReconnect as e:
+            raise DestinationConnectionError(f"MongoDB connection lost: {e}") from e
+        except Exception as e:
+            raise DestinationConnectionError(f"MongoDB error: {e}") from e
 
 
 mongodb_destination_entry = DestinationRegistryEntry(