Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions test/unit/connectors/test_mongodb_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from unittest.mock import MagicMock

import pytest
from pymongo.errors import (
AutoReconnect,
BulkWriteError,
OperationFailure,
ServerSelectionTimeoutError,
)

from unstructured_ingest.error import (
DestinationConnectionError,
QuotaError,
TimeoutError,
WriteError,
)
from unstructured_ingest.processes.connectors.mongodb import (
MongoDBConnectionConfig,
MongoDBUploader,
MongoDBUploaderConfig,
)


def _make_uploader():
connection_config = MagicMock(spec=MongoDBConnectionConfig)
connection_config.host = "test_host"
upload_config = MagicMock(spec=MongoDBUploaderConfig)
upload_config.record_id_key = "record_id"
upload_config.database = "test_db"
upload_config.collection = "test_collection"
upload_config.batch_size = 100
return MongoDBUploader(
connection_config=connection_config,
upload_config=upload_config,
)


def _mock_client(uploader, collection_side_effect=None):
mock_client = MagicMock()
uploader.connection_config.get_client.return_value.__enter__ = MagicMock(
return_value=mock_client
)
uploader.connection_config.get_client.return_value.__exit__ = MagicMock(return_value=False)
mock_collection = mock_client.__getitem__("test_db").__getitem__("test_collection")
# Make can_delete return False so we skip delete_by_record_id and go straight to insert
mock_collection.list_indexes.return_value = []
if collection_side_effect:
mock_collection.insert_many.side_effect = collection_side_effect
return mock_collection


class TestRunDataErrorHandling:
def test_operation_failure_quota_raises_quota_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
_mock_client(uploader, OperationFailure("quota exceeded for writes"))

with pytest.raises(QuotaError):
uploader.run_data(data=[{"key": "value"}], file_data=file_data)

def test_operation_failure_other_raises_destination_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
_mock_client(uploader, OperationFailure("some other failure"))

with pytest.raises(DestinationConnectionError):
uploader.run_data(data=[{"key": "value"}], file_data=file_data)

def test_server_selection_timeout_raises_timeout_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
_mock_client(uploader, ServerSelectionTimeoutError("timeout"))

with pytest.raises(TimeoutError):
uploader.run_data(data=[{"key": "value"}], file_data=file_data)

def test_bulk_write_error_raises_write_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
_mock_client(uploader, BulkWriteError({"writeErrors": [{"errmsg": "fail"}]}))

with pytest.raises(WriteError):
uploader.run_data(data=[{"key": "value"}], file_data=file_data)

def test_auto_reconnect_raises_destination_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
_mock_client(uploader, AutoReconnect("connection lost"))

with pytest.raises(DestinationConnectionError):
uploader.run_data(data=[{"key": "value"}], file_data=file_data)


class TestDeleteByRecordIdErrorHandling:
def test_operation_failure_quota_raises_quota_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
collection = MagicMock()
collection.delete_many.side_effect = OperationFailure("quota exceeded")

with pytest.raises(QuotaError):
uploader.delete_by_record_id(collection=collection, file_data=file_data)

def test_server_selection_timeout_raises_timeout_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
collection = MagicMock()
collection.delete_many.side_effect = ServerSelectionTimeoutError("timeout")

with pytest.raises(TimeoutError):
uploader.delete_by_record_id(collection=collection, file_data=file_data)

def test_auto_reconnect_raises_destination_error(self):
uploader = _make_uploader()
file_data = MagicMock()
file_data.identifier = "test_id"
collection = MagicMock()
collection.delete_many.side_effect = AutoReconnect("connection lost")

with pytest.raises(DestinationConnectionError):
uploader.delete_by_record_id(collection=collection, file_data=file_data)
58 changes: 48 additions & 10 deletions unstructured_ingest/processes/connectors/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
from unstructured_ingest.error import (
ConnectionError,
DestinationConnectionError,
QuotaError,
SourceConnectionError,
TimeoutError,
ValueError,
WriteError,
)
from unstructured_ingest.interfaces import (
AccessConfig,
Expand Down Expand Up @@ -341,18 +344,40 @@ def can_delete(self, collection: "Collection") -> bool:
return self.upload_config.record_id_key in indexed_keys

def delete_by_record_id(self, collection: "Collection", file_data: FileData) -> None:
from pymongo.errors import (
AutoReconnect,
OperationFailure,
ServerSelectionTimeoutError,
)

logger.debug(
f"deleting any content with metadata "
f"{self.upload_config.record_id_key}={file_data.identifier} "
f"from collection: {collection.name}"
)
query = {self.upload_config.record_id_key: file_data.identifier}
delete_results = collection.delete_many(filter=query)
try:
delete_results = collection.delete_many(filter=query)
except OperationFailure as e:
if "quota" in str(e).lower():
raise QuotaError(f"MongoDB quota exceeded: {e}") from e
raise DestinationConnectionError(f"MongoDB operation failed: {e}") from e
except ServerSelectionTimeoutError as e:
raise TimeoutError(f"MongoDB server unreachable: {e}") from e
except AutoReconnect as e:
raise DestinationConnectionError(f"MongoDB connection lost: {e}") from e
logger.info(
f"deleted {delete_results.deleted_count} records from collection {collection.name}"
)

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
from pymongo.errors import (
AutoReconnect,
BulkWriteError,
OperationFailure,
ServerSelectionTimeoutError,
)

logger.info(
f"writing {len(data)} objects to destination "
f"db, {self.upload_config.database}, "
Expand All @@ -363,15 +388,28 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
# is done, setting the record id field in the uploader
for element in data:
element[self.upload_config.record_id_key] = file_data.identifier
with self.connection_config.get_client() as client:
db = client[self.upload_config.database]
collection = db[self.upload_config.collection]
if self.can_delete(collection=collection):
self.delete_by_record_id(file_data=file_data, collection=collection)
else:
logger.warning("criteria for deleting previous content not met, skipping")
for chunk in batch_generator(data, self.upload_config.batch_size):
collection.insert_many(chunk)
try:
with self.connection_config.get_client() as client:
db = client[self.upload_config.database]
collection = db[self.upload_config.collection]
if self.can_delete(collection=collection):
self.delete_by_record_id(file_data=file_data, collection=collection)
else:
logger.warning("criteria for deleting previous content not met, skipping")
for chunk in batch_generator(data, self.upload_config.batch_size):
collection.insert_many(chunk)
except BulkWriteError as e:
raise WriteError(f"MongoDB bulk write failed: {e}") from e
except OperationFailure as e:
if "quota" in str(e).lower():
raise QuotaError(f"MongoDB quota exceeded: {e}") from e
raise DestinationConnectionError(f"MongoDB operation failed: {e}") from e
except ServerSelectionTimeoutError as e:
raise TimeoutError(f"MongoDB server unreachable: {e}") from e
except AutoReconnect as e:
raise DestinationConnectionError(f"MongoDB connection lost: {e}") from e
except Exception as e:
raise DestinationConnectionError(f"MongoDB error: {e}") from e
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catch-all swallows correctly classified errors from delete_by_record_id

High Severity

The except Exception catch-all in run_data re-wraps already-classified exceptions from delete_by_record_id as DestinationConnectionError. When delete_by_record_id raises QuotaError or TimeoutError, those propagate into run_data's try block, skip all the pymongo-specific except handlers, and get caught by the generic except Exception — converting them back to DestinationConnectionError. This directly undermines the PR's goal of correct error classification. The catch-all needs to re-raise UnstructuredIngestError subclasses before falling through to the generic handler.

Additional Locations (1)
Fix in Cursor Fix in Web



mongodb_destination_entry = DestinationRegistryEntry(
Expand Down
Loading