diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 4c7132eeab..1658219143 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Any from haystack import Document, component, default_from_dict, default_to_dict @@ -195,10 +196,11 @@ def run(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, self.embedding_type, ) + new_documents = [] for doc, embeddings in zip(documents, all_embeddings, strict=True): - doc.embedding = embeddings + new_documents.append(replace(doc, embedding=embeddings)) - return {"documents": documents, "meta": metadata} + return {"documents": new_documents, "meta": metadata} @component.output_types(documents=list[Document], meta=dict[str, Any]) async def run_async(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]: @@ -228,7 +230,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document] embedding_type=self.embedding_type, ) + new_documents = [] for doc, embeddings in zip(documents, all_embeddings, strict=True): - doc.embedding = embeddings + new_documents.append(replace(doc, embedding=embeddings)) - return {"documents": documents, "meta": metadata} + return {"documents": new_documents, "meta": metadata} diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py index 58b81111aa..824e1feec8 100644 --- a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -1,3 +1,4 @@ +from dataclasses import replace from typing import Any from haystack import Document, component, default_from_dict, default_to_dict, logging @@ -162,6 +163,5 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) - sorted_docs = [] for idx, score in zip(indices, scores, strict=True): doc = documents[idx] - doc.score = score - sorted_docs.append(documents[idx]) + sorted_docs.append(replace(doc, score=score)) return {"documents": sorted_docs} diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index 674a277b32..6ae885e0e1 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -195,6 +195,51 @@ async def test_run_async(self, mock_get_response): assert doc_with_embedding.meta == doc.meta assert doc_with_embedding.embedding == embedding + @patch("haystack_integrations.components.embedders.cohere.document_embedder.get_response") + def test_run_does_not_modify_original_documents(self, mock_get_response): + embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + mock_get_response.return_value = (embeddings, {"api_version": "1.0"}) + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = embedder.run(docs) + + # Check that the original documents are not modified + for doc in docs: + assert doc.embedding is None + + # Check that the returned documents have embeddings + for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True): + assert doc_with_embedding.embedding == embedding + + @pytest.mark.asyncio + @patch("haystack_integrations.components.embedders.cohere.document_embedder.get_async_response") + async def test_run_async_does_not_modify_original_documents(self, mock_get_response): + embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key")) + + embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + mock_get_response.return_value = (embeddings, {"api_version": "1.0"}) + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = await embedder.run_async(docs) + + # Check that the original documents are not modified + for doc in docs: + assert doc.embedding is None + + # Check that the returned documents have embeddings + for doc_with_embedding, embedding in zip(result["documents"], embeddings, strict=True): + assert doc_with_embedding.embedding == embedding + @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", diff --git a/integrations/cohere/tests/test_ranker.py b/integrations/cohere/tests/test_ranker.py index f310715e16..f5db7bec66 100644 --- a/integrations/cohere/tests/test_ranker.py +++ b/integrations/cohere/tests/test_ranker.py @@ -295,6 +295,27 @@ def test_run_topk_set_in_init(self, monkeypatch, mock_ranker_response): # noqa: Document(id="efgh", content="doc2", score=0.95), ] + def test_run_does_not_modify_original_documents(self, monkeypatch, mock_ranker_response): # noqa: ARG002 + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker(top_k=2) + query = "test" + documents = [ + Document(id="abcd", content="doc1"), + Document(id="efgh", content="doc2"), + Document(id="ijkl", content="doc3"), + ] + + ranker_results = ranker.run(query, documents) + + # Check that the original documents are not modified + for doc in documents: + assert doc.score is None + + # Check that the returned documents have scores + reranked_docs = ranker_results["documents"] + for doc in reranked_docs: + assert doc.score is not None + @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",