diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 0cd0f1d21e..aa16cc2e2b 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Any import requests @@ -219,7 +220,8 @@ def run(self, documents: list[Document]) -> dict[str, Any]: texts_to_embed=texts_to_embed, batch_size=self.batch_size, parameters=parameters ) + new_documents: list[Document] = [] for doc, emb in zip(documents, embeddings, strict=True): - doc.embedding = emb + new_documents.append(replace(doc, embedding=emb)) - return {"documents": documents, "meta": metadata} + return {"documents": new_documents, "meta": metadata} diff --git a/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py b/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py index bfb0ff16f4..b8b1867df5 100644 --- a/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py +++ b/integrations/jina/src/haystack_integrations/components/rankers/jina/ranker.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Any import requests @@ -161,12 +162,12 @@ def run( relevance_score = result["relevance_score"] doc = documents[index] if top_k is None or len(ranked_docs) < top_k: - doc.score = relevance_score + scored_doc = replace(doc, score=relevance_score) if score_threshold is not None: if relevance_score >= score_threshold: - ranked_docs.append(doc) + ranked_docs.append(scored_doc) else: - ranked_docs.append(doc) + ranked_docs.append(scored_doc) else: break diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index 247b95effc..c2b2f9ca18 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -204,6 +204,29 @@ def test_run(self): assert all(isinstance(x, float) for x in doc.embedding) assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} + def test_run_does_not_modify_original_documents(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "jina-embeddings-v2-base-en" + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), + model=model, + ) + + result = embedder.run(documents=docs) + + # originals remain unchanged + for doc in docs: + assert doc.embedding is None + + # returned docs carry embeddings + for doc_with_embedding in result["documents"]: + assert doc_with_embedding.embedding is not None + def test_run_custom_batch_size(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/integrations/jina/tests/test_ranker.py b/integrations/jina/tests/test_ranker.py index e7fdaade03..a411abbe95 100644 --- a/integrations/jina/tests/test_ranker.py +++ b/integrations/jina/tests/test_ranker.py @@ -105,6 +105,32 @@ def test_run(self): assert doc.score == len(ranked_documents) - i assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} + def test_run_does_not_modify_original_documents(self): + docs = [ + Document(content="I love cheese"), + Document(content="A transformer is a deep learning architecture"), + Document(content="A transformer is something"), + Document(content="A transformer is not good"), + ] + query = "What is a transformer?" + + model = "jina-ranker" + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + ranker = JinaRanker( + api_key=Secret.from_token("fake-api-key"), + model=model, + ) + + result = ranker.run(query=query, documents=docs) + + # originals remain unchanged + for doc in docs: + assert doc.score is None + + # returned docs carry scores + for doc in result["documents"]: + assert doc.score is not None + def test_run_wrong_input_format(self): ranker = JinaRanker(api_key=Secret.from_token("fake-api-key"))