Skip to content

Commit 8e48a7b

Browse files
fix(cohere): use dataclass replace to avoid modifying input documents (#2755)
This PR fixes the Cohere document embedder and ranker so that they no longer modify input Documents in place when setting embeddings or scores. Instead of mutating the original documents (`doc.embedding = embeddings`, `doc.score = score`), we now create new document instances using dataclass `replace`: `replace(doc, embedding=embeddings)` and `replace(doc, score=score)`. This follows the established pattern from haystack-ai/haystack#9693 and aligns with other integrations (FastEmbed, Optimum, Nvidia, Bedrock). Related to: #2174
1 parent f3b768a commit 8e48a7b

4 files changed

Lines changed: 75 additions & 6 deletions

File tree

integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
from dataclasses import replace
45
from typing import Any
56

67
from haystack import Document, component, default_from_dict, default_to_dict
@@ -195,10 +196,11 @@ def run(self, documents: list[Document]) -> dict[str, list[Document] | dict[str,
195196
self.embedding_type,
196197
)
197198

199+
new_documents = []
198200
for doc, embeddings in zip(documents, all_embeddings, strict=True):
199-
doc.embedding = embeddings
201+
new_documents.append(replace(doc, embedding=embeddings))
200202

201-
return {"documents": documents, "meta": metadata}
203+
return {"documents": new_documents, "meta": metadata}
202204

203205
@component.output_types(documents=list[Document], meta=dict[str, Any])
204206
async def run_async(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]:
@@ -228,7 +230,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document]
228230
embedding_type=self.embedding_type,
229231
)
230232

233+
new_documents = []
231234
for doc, embeddings in zip(documents, all_embeddings, strict=True):
232-
doc.embedding = embeddings
235+
new_documents.append(replace(doc, embedding=embeddings))
233236

234-
return {"documents": documents, "meta": metadata}
237+
return {"documents": new_documents, "meta": metadata}

integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import replace
12
from typing import Any
23

34
from haystack import Document, component, default_from_dict, default_to_dict, logging
@@ -162,6 +163,5 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) -
162163
sorted_docs = []
163164
for idx, score in zip(indices, scores, strict=True):
164165
doc = documents[idx]
165-
doc.score = score
166-
sorted_docs.append(documents[idx])
166+
sorted_docs.append(replace(doc, score=score))
167167
return {"documents": sorted_docs}

integrations/cohere/tests/test_document_embedder.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,51 @@ async def test_run_async(self, mock_get_response):
195195
assert doc_with_embedding.meta == doc.meta
196196
assert doc_with_embedding.embedding == embedding
197197

198+
@patch("haystack_integrations.components.embedders.cohere.document_embedder.get_response")
199+
def test_run_does_not_modify_original_documents(self, mock_get_response):
200+
embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key"))
201+
202+
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
203+
mock_get_response.return_value = (embeddings, {"api_version": "1.0"})
204+
205+
docs = [
206+
Document(content="I love cheese", meta={"topic": "Cuisine"}),
207+
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
208+
]
209+
210+
result = embedder.run(docs)
211+
212+
# Check that the original documents are not modified
213+
for doc in docs:
214+
assert doc.embedding is None
215+
216+
# Check that the returned documents have embeddings
217+
for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True):
218+
assert doc_with_embedding.embedding == embedding
219+
220+
@pytest.mark.asyncio
221+
@patch("haystack_integrations.components.embedders.cohere.document_embedder.get_async_response")
222+
async def test_run_async_does_not_modify_original_documents(self, mock_get_response):
223+
embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key"))
224+
225+
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
226+
mock_get_response.return_value = (embeddings, {"api_version": "1.0"})
227+
228+
docs = [
229+
Document(content="I love cheese", meta={"topic": "Cuisine"}),
230+
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
231+
]
232+
233+
result = await embedder.run_async(docs)
234+
235+
# Check that the original documents are not modified
236+
for doc in docs:
237+
assert doc.embedding is None
238+
239+
# Check that the returned documents have embeddings
240+
for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True):
241+
assert doc_with_embedding.embedding == embedding
242+
198243
@pytest.mark.skipif(
199244
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
200245
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",

integrations/cohere/tests/test_ranker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,27 @@ def test_run_topk_set_in_init(self, monkeypatch, mock_ranker_response): # noqa:
295295
Document(id="efgh", content="doc2", score=0.95),
296296
]
297297

298+
def test_run_does_not_modify_original_documents(self, monkeypatch, mock_ranker_response): # noqa: ARG002
299+
monkeypatch.setenv("CO_API_KEY", "test-api-key")
300+
ranker = CohereRanker(top_k=2)
301+
query = "test"
302+
documents = [
303+
Document(id="abcd", content="doc1"),
304+
Document(id="efgh", content="doc2"),
305+
Document(id="ijkl", content="doc3"),
306+
]
307+
308+
ranker_results = ranker.run(query, documents)
309+
310+
# Check that the original documents are not modified
311+
for doc in documents:
312+
assert doc.score is None
313+
314+
# Check that the returned documents have scores
315+
reranked_docs = ranker_results["documents"]
316+
for doc in reranked_docs:
317+
assert doc.score is not None
318+
298319
@pytest.mark.skipif(
299320
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
300321
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",

0 commit comments

Comments
 (0)