Skip to content

Commit 03d9f0f

Browse files
authored
fix: prevent in-place mutation of documents in Document Embedders (#9693)
* fix: prevent in-place mutation of documents after embeddings by using deepcopy
* Add tests
* Use `dataclasses.replace` instead of deepcopy
* Address PR comments
1 parent 35e6936 commit 03d9f0f

10 files changed

Lines changed: 93 additions & 59 deletions

haystack/components/embedders/hugging_face_api_document_embedder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dataclasses import replace
56
from typing import Any, Optional, Union
67

78
from tqdm import tqdm
@@ -328,10 +329,11 @@ def run(self, documents: list[Document]):
328329

329330
embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
330331

332+
new_documents = []
331333
for doc, emb in zip(documents, embeddings):
332-
doc.embedding = emb
334+
new_documents.append(replace(doc, embedding=emb))
333335

334-
return {"documents": documents}
336+
return {"documents": new_documents}
335337

336338
@component.output_types(documents=list[Document])
337339
async def run_async(self, documents: list[Document]):
@@ -355,7 +357,8 @@ async def run_async(self, documents: list[Document]):
355357

356358
embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
357359

360+
new_documents = []
358361
for doc, emb in zip(documents, embeddings):
359-
doc.embedding = emb
362+
new_documents.append(replace(doc, embedding=emb))
360363

361-
return {"documents": documents}
364+
return {"documents": new_documents}

haystack/components/embedders/image/sentence_transformers_doc_image_embedder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from copy import copy
5+
from dataclasses import replace
66
from typing import Any, Literal, Optional
77

88
from haystack import Document, component, default_from_dict, default_to_dict
@@ -281,10 +281,12 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
281281

282282
docs_with_embeddings = []
283283
for doc, emb in zip(documents, embeddings):
284-
copied_doc = copy(doc)
285-
copied_doc.embedding = emb
286284
# we store this information for later inspection
287-
copied_doc.meta["embedding_source"] = {"type": "image", "file_path_meta_field": self.file_path_meta_field}
288-
docs_with_embeddings.append(copied_doc)
285+
new_meta = {
286+
**doc.meta,
287+
"embedding_source": {"type": "image", "file_path_meta_field": self.file_path_meta_field},
288+
}
289+
new_doc = replace(doc, meta=new_meta, embedding=emb)
290+
docs_with_embeddings.append(new_doc)
289291

290292
return {"documents": docs_with_embeddings}

haystack/components/embedders/openai_document_embedder.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6+
from dataclasses import replace
67
from typing import Any, Optional
78

89
from more_itertools import batched
@@ -307,11 +308,14 @@ def run(self, documents: list[Document]):
307308

308309
doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
309310

310-
doc_id_to_document = {doc.id: doc for doc in documents}
311-
for doc_id, emb in doc_ids_to_embeddings.items():
312-
doc_id_to_document[doc_id].embedding = emb
311+
new_documents = []
312+
for doc in documents:
313+
if doc.id in doc_ids_to_embeddings:
314+
new_documents.append(replace(doc, embedding=doc_ids_to_embeddings[doc.id]))
315+
else:
316+
new_documents.append(replace(doc))
313317

314-
return {"documents": list(doc_id_to_document.values()), "meta": meta}
318+
return {"documents": new_documents, "meta": meta}
315319

316320
@component.output_types(documents=list[Document], meta=dict[str, Any])
317321
async def run_async(self, documents: list[Document]):
@@ -338,8 +342,11 @@ async def run_async(self, documents: list[Document]):
338342
texts_to_embed=texts_to_embed, batch_size=self.batch_size
339343
)
340344

341-
doc_id_to_document = {doc.id: doc for doc in documents}
342-
for doc_id, emb in doc_ids_to_embeddings.items():
343-
doc_id_to_document[doc_id].embedding = emb
345+
new_documents = []
346+
for doc in documents:
347+
if doc.id in doc_ids_to_embeddings:
348+
new_documents.append(replace(doc, embedding=doc_ids_to_embeddings[doc.id]))
349+
else:
350+
new_documents.append(replace(doc))
344351

345-
return {"documents": list(doc_id_to_document.values()), "meta": meta}
352+
return {"documents": new_documents, "meta": meta}

haystack/components/embedders/sentence_transformers_document_embedder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dataclasses import replace
56
from typing import Any, Literal, Optional
67

78
from haystack import Document, component, default_from_dict, default_to_dict
@@ -257,7 +258,8 @@ def run(self, documents: list[Document]):
257258
**(self.encode_kwargs if self.encode_kwargs else {}),
258259
)
259260

261+
new_documents = []
260262
for doc, emb in zip(documents, embeddings):
261-
doc.embedding = emb
263+
new_documents.append(replace(doc, embedding=emb))
262264

263-
return {"documents": documents}
265+
return {"documents": new_documents}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Prevented in-place mutation of input `Document` objects in all `DocumentEmbedder` components
5+
by returning new copies created with `dataclasses.replace` when attaching embeddings.

test/components/embedders/image/test_sentence_transformers_doc_image_embedder.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,16 @@ def test_run(self, test_files_path):
209209

210210
assert isinstance(result["documents"], list)
211211
assert len(result["documents"]) == len(documents)
212-
for doc in result["documents"]:
213-
assert isinstance(doc, Document)
214-
assert isinstance(doc.embedding, list)
215-
assert isinstance(doc.embedding[0], float)
216-
assert "embedding_source" in doc.meta
217-
assert doc.meta["embedding_source"]["type"] == "image"
218-
assert "file_path_meta_field" in doc.meta["embedding_source"]
212+
for doc, new_doc in zip(documents, result["documents"]):
213+
assert doc.embedding is None
214+
assert new_doc is not doc
215+
assert isinstance(new_doc, Document)
216+
assert isinstance(new_doc.embedding, list)
217+
assert isinstance(new_doc.embedding[0], float)
218+
assert "embedding_source" not in doc.meta
219+
assert "embedding_source" in new_doc.meta
220+
assert new_doc.meta["embedding_source"]["type"] == "image"
221+
assert "file_path_meta_field" in new_doc.meta["embedding_source"]
219222

220223
def test_run_no_warmup(self):
221224
embedder = SentenceTransformersDocumentImageEmbedder(model="model")
@@ -338,11 +341,14 @@ def test_live_run(self, test_files_path, monkeypatch):
338341

339342
result = embedder.run(documents=documents)
340343
assert len(result["documents"]) == len(documents)
341-
for doc in result["documents"]:
342-
assert isinstance(doc, Document)
343-
assert isinstance(doc.embedding, list)
344-
assert len(doc.embedding) == 512
345-
assert all(isinstance(x, float) for x in doc.embedding)
346-
assert "embedding_source" in doc.meta
347-
assert doc.meta["embedding_source"]["type"] == "image"
348-
assert "file_path_meta_field" in doc.meta["embedding_source"]
344+
for doc, new_doc in zip(documents, result["documents"]):
345+
assert doc.embedding is None
346+
assert new_doc is not doc
347+
assert isinstance(new_doc, Document)
348+
assert isinstance(new_doc.embedding, list)
349+
assert len(new_doc.embedding) == 512
350+
assert all(isinstance(x, float) for x in new_doc.embedding)
351+
assert "embedding_source" not in doc.meta
352+
assert "embedding_source" in new_doc.meta
353+
assert new_doc.meta["embedding_source"]["type"] == "image"
354+
assert "file_path_meta_field" in new_doc.meta["embedding_source"]

test/components/embedders/test_azure_document_embedder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,13 @@ def test_run(self):
265265

266266
assert isinstance(documents_with_embeddings, list)
267267
assert len(documents_with_embeddings) == len(docs)
268-
for doc in documents_with_embeddings:
269-
assert isinstance(doc, Document)
270-
assert isinstance(doc.embedding, list)
271-
assert len(doc.embedding) == 1536
272-
assert all(isinstance(x, float) for x in doc.embedding)
268+
for doc, new_doc in zip(docs, documents_with_embeddings):
269+
assert doc.embedding is None
270+
assert new_doc is not doc
271+
assert isinstance(new_doc, Document)
272+
assert isinstance(new_doc.embedding, list)
273+
assert len(new_doc.embedding) == 1536
274+
assert all(isinstance(x, float) for x in new_doc.embedding)
273275

274276
assert metadata["usage"]["prompt_tokens"] == 15
275277
assert metadata["usage"]["total_tokens"] == 15

test/components/embedders/test_hugging_face_api_document_embedder.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def test_run(self, mock_check_valid_model):
287287
Document(content="I love cheese", meta={"topic": "Cuisine"}),
288288
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
289289
]
290-
291290
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
292291
mock_embedding_patch.side_effect = mock_embedding_generation
293292

@@ -316,11 +315,13 @@ def test_run(self, mock_check_valid_model):
316315

317316
assert isinstance(documents_with_embeddings, list)
318317
assert len(documents_with_embeddings) == len(docs)
319-
for doc in documents_with_embeddings:
320-
assert isinstance(doc, Document)
321-
assert isinstance(doc.embedding, list)
322-
assert len(doc.embedding) == 384
323-
assert all(isinstance(x, float) for x in doc.embedding)
318+
for doc, new_doc in zip(docs, documents_with_embeddings):
319+
assert doc.embedding is None
320+
assert new_doc is not doc
321+
assert isinstance(new_doc, Document)
322+
assert isinstance(new_doc.embedding, list)
323+
assert len(new_doc.embedding) == 384
324+
assert all(isinstance(x, float) for x in new_doc.embedding)
324325

325326
def test_run_custom_batch_size(self, mock_check_valid_model):
326327
docs = [

test/components/embedders/test_openai_document_embedder.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,13 @@ def test_run(self):
281281

282282
assert isinstance(documents_with_embeddings, list)
283283
assert len(documents_with_embeddings) == len(docs)
284-
for doc in documents_with_embeddings:
285-
assert isinstance(doc, Document)
286-
assert isinstance(doc.embedding, list)
287-
assert len(doc.embedding) == 1536
288-
assert all(isinstance(x, float) for x in doc.embedding)
284+
for doc, new_doc in zip(docs, documents_with_embeddings):
285+
assert doc.embedding is None
286+
assert new_doc is not doc
287+
assert isinstance(new_doc, Document)
288+
assert isinstance(new_doc.embedding, list)
289+
assert len(new_doc.embedding) == 1536
290+
assert all(isinstance(x, float) for x in new_doc.embedding)
289291

290292
assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
291293
"The model name does not contain 'text' and 'ada'"
@@ -311,11 +313,13 @@ async def test_run_async(self):
311313

312314
assert isinstance(documents_with_embeddings, list)
313315
assert len(documents_with_embeddings) == len(docs)
314-
for doc in documents_with_embeddings:
315-
assert isinstance(doc, Document)
316-
assert isinstance(doc.embedding, list)
317-
assert len(doc.embedding) == 1536
318-
assert all(isinstance(x, float) for x in doc.embedding)
316+
for doc, new_doc in zip(docs, documents_with_embeddings):
317+
assert doc.embedding is None
318+
assert new_doc is not doc
319+
assert isinstance(new_doc, Document)
320+
assert isinstance(new_doc.embedding, list)
321+
assert len(new_doc.embedding) == 1536
322+
assert all(isinstance(x, float) for x in new_doc.embedding)
319323

320324
assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
321325
"The model name does not contain 'text' and 'ada'"

test/components/embedders/test_sentence_transformers_document_embedder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,12 @@ def test_run(self):
293293

294294
assert isinstance(result["documents"], list)
295295
assert len(result["documents"]) == len(documents)
296-
for doc in result["documents"]:
297-
assert isinstance(doc, Document)
298-
assert isinstance(doc.embedding, list)
299-
assert isinstance(doc.embedding[0], float)
296+
for doc, new_doc in zip(documents, result["documents"]):
297+
assert new_doc is not doc
298+
assert doc.embedding is None
299+
assert isinstance(new_doc, Document)
300+
assert isinstance(new_doc.embedding, list)
301+
assert isinstance(new_doc.embedding[0], float)
300302

301303
def test_run_wrong_input_format(self):
302304
embedder = SentenceTransformersDocumentEmbedder(model="model")

0 commit comments

Comments
 (0)