Skip to content

Commit b9a34df

Browse files
abdokaseb and anakin87 authored
Fix: prevent in-place mutation of documents in Document Classifiers and Extractors (#9703)
* Modify Document Classifiers and Extractors to not make in-place changes
* Add e2e test for NER
* Add unit test for NER
* Fixes + refinements
---------
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
1 parent f8d3a82 commit b9a34df

11 files changed

Lines changed: 102 additions & 22 deletions

e2e/pipelines/test_named_entity_extractor.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
# Note: We only test the Spacy backend in this module, which is executed in the e2e environment.
6+
# We don't test Spacy in test/components/extractors/test_named_entity_extractor.py, which is executed in the
7+
# test environment. Spacy is not installed in the test environment to keep the CI fast.
8+
59
import os
610
import pytest
711

@@ -113,7 +117,15 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, monkey
113117
def _extract_and_check_predictions(extractor, texts, expected, batch_size):
114118
docs = [Document(content=text) for text in texts]
115119
outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
116-
assert all(id(a) == id(b) for a, b in zip(docs, outputs))
120+
for original_doc, output_doc in zip(docs, outputs):
121+
# we don't modify documents in place
122+
assert original_doc is not output_doc
123+
124+
# apart from meta, the documents should be identical
125+
output_doc_dict = output_doc.to_dict(flatten=False)
126+
output_doc_dict.pop("meta", None)
127+
assert original_doc.to_dict() == output_doc_dict
128+
117129
predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
118130

119131
_check_predictions(predicted, expected)

haystack/components/classifiers/document_language_classifier.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dataclasses import replace
56
from typing import Optional
67

78
from haystack import Document, component, logging
@@ -89,14 +90,17 @@ def run(self, documents: list[Document]):
8990
output: dict[str, list[Document]] = {language: [] for language in self.languages}
9091
output["unmatched"] = []
9192

93+
new_documents = []
9294
for document in documents:
9395
detected_language = self._detect_language(document)
96+
new_meta = {**document.meta}
9497
if detected_language in self.languages:
95-
document.meta["language"] = detected_language
98+
new_meta["language"] = detected_language
9699
else:
97-
document.meta["language"] = "unmatched"
100+
new_meta["language"] = "unmatched"
101+
new_documents.append(replace(document, meta=new_meta))
98102

99-
return {"documents": documents}
103+
return {"documents": new_documents}
100104

101105
def _detect_language(self, document: Document) -> Optional[str]:
102106
language = None

haystack/components/classifiers/zero_shot_document_classifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dataclasses import replace
56
from typing import Any, Optional
67

78
from haystack import Document, component, default_from_dict, default_to_dict
@@ -232,12 +233,14 @@ def run(self, documents: list[Document], batch_size: int = 1):
232233

233234
predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)
234235

236+
new_documents = []
235237
for prediction, document in zip(predictions, documents):
236238
formatted_prediction = {
237239
"label": prediction["labels"][0],
238240
"score": prediction["scores"][0],
239241
"details": dict(zip(prediction["labels"], prediction["scores"])),
240242
}
241-
document.meta["classification"] = formatted_prediction
243+
new_meta = {**document.meta, "classification": formatted_prediction}
244+
new_documents.append(replace(document, meta=new_meta))
242245

243-
return {"documents": documents}
246+
return {"documents": new_documents}

haystack/components/extractors/llm_metadata_extractor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import json
77
from concurrent.futures import ThreadPoolExecutor
8+
from dataclasses import replace
89
from typing import Any, Optional, Union
910

1011
from jinja2 import meta
@@ -318,24 +319,25 @@ def run(self, documents: list[Document], page_range: Optional[list[Union[str, in
318319
successful_documents = []
319320
failed_documents = []
320321
for document, result in zip(documents, results):
322+
new_meta = {**document.meta}
321323
if "error" in result:
322-
document.meta["metadata_extraction_error"] = result["error"]
323-
document.meta["metadata_extraction_response"] = None
324-
failed_documents.append(document)
324+
new_meta["metadata_extraction_error"] = result["error"]
325+
new_meta["metadata_extraction_response"] = None
326+
failed_documents.append(replace(document, meta=new_meta))
325327
continue
326328

327329
parsed_metadata = self._extract_metadata(result["replies"][0].text)
328330
if "error" in parsed_metadata:
329-
document.meta["metadata_extraction_error"] = parsed_metadata["error"]
330-
document.meta["metadata_extraction_response"] = result["replies"][0]
331-
failed_documents.append(document)
331+
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
332+
new_meta["metadata_extraction_response"] = result["replies"][0]
333+
failed_documents.append(replace(document, meta=new_meta))
332334
continue
333335

334336
for key in parsed_metadata:
335-
document.meta[key] = parsed_metadata[key]
337+
new_meta[key] = parsed_metadata[key]
336338
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
337-
document.meta.pop("metadata_extraction_error", None)
338-
document.meta.pop("metadata_extraction_response", None)
339-
successful_documents.append(document)
339+
new_meta.pop("metadata_extraction_error", None)
340+
new_meta.pop("metadata_extraction_response", None)
341+
successful_documents.append(replace(document, meta=new_meta))
340342

341343
return {"documents": successful_documents, "failed_documents": failed_documents}

haystack/components/extractors/named_entity_extractor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from abc import ABC, abstractmethod
66
from contextlib import contextmanager
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, replace
88
from enum import Enum
99
from typing import Any, Optional, Union
1010

@@ -204,10 +204,12 @@ def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
204204
f"got {len(annotations)} but expected {len(documents)}"
205205
)
206206

207+
new_documents = []
207208
for doc, doc_annotations in zip(documents, annotations):
208-
doc.meta[self._METADATA_KEY] = doc_annotations
209+
new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations}
210+
new_documents.append(replace(doc, meta=new_meta))
209211

210-
return {"documents": documents}
212+
return {"documents": new_documents}
211213

212214
def to_dict(self) -> dict[str, Any]:
213215
"""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ extra-dependencies = [
173173
]
174174

175175
[tool.hatch.envs.e2e.scripts]
176-
test = "pytest e2e"
176+
test = "pytest {args:e2e}"
177177

178178
[tool.hatch.envs.readme]
179179
installer = "uv"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Prevented in-place mutation of input `Document` objects in all `Extractor` and `Classifier` components
5+
by returning new copies, created with `dataclasses.replace` and carrying the updated metadata, instead of modifying the inputs.

test/components/classifiers/test_document_language_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def test_classify_as_en_and_unmatched(self):
4242
result = classifier.run(documents=[english_document, german_document])
4343
assert result["documents"][0].meta["language"] == "en"
4444
assert result["documents"][1].meta["language"] == "unmatched"
45+
assert "language" not in english_document.meta
46+
assert "language" not in german_document.meta
4547

4648
def test_warning_if_no_language_detected(self, caplog):
4749
with caplog.at_level(logging.WARNING):

test/components/classifiers/test_zero_shot_document_classifier.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def test_run_unit(self, hf_pipeline_mock):
135135
assert component.pipeline is not None
136136
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
137137
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
138+
assert "classification" not in positive_document.to_dict()
139+
assert "classification" not in negative_document.to_dict()
138140

139141
@pytest.mark.integration
140142
@pytest.mark.slow
@@ -150,6 +152,8 @@ def test_run(self, monkeypatch):
150152
assert component.pipeline is not None
151153
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
152154
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
155+
assert "classification" not in positive_document.to_dict()
156+
assert "classification" not in negative_document.to_dict()
153157

154158
def test_serialization_and_deserialization_pipeline(self):
155159
pipeline = Pipeline()

test/components/extractors/test_llm_metadata_extractor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,13 @@ def test_run_with_document_content_none(self, monkeypatch):
251251
assert failed_doc_none.id == doc_with_none_content.id
252252
assert "metadata_extraction_error" in failed_doc_none.meta
253253
assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
254+
assert "metadata_extraction_error" not in doc_with_none_content.meta
254255

255256
failed_doc_empty = result["failed_documents"][1]
256257
assert failed_doc_empty.id == doc_with_empty_content.id
257258
assert "metadata_extraction_error" in failed_doc_empty.meta
258259
assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
260+
assert "metadata_extraction_error" not in doc_with_empty_content.meta
259261

260262
# Ensure no attempt was made to call the LLM
261263
mock_chat_generator.run.assert_not_called()

0 commit comments

Comments
 (0)