diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index a64689cfef..b8b9ef862d 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Any, Literal from google.genai import types @@ -281,10 +282,11 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]] | dict[str meta: dict[str, Any] embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size) + new_documents = [] for doc, emb in zip(documents, embeddings, strict=True): - doc.embedding = emb + new_documents.append(replace(doc, embedding=emb)) - return {"documents": documents, "meta": meta} + return {"documents": new_documents, "meta": meta} @component.output_types(documents=list[Document], meta=dict[str, Any]) async def run_async(self, documents: list[Document]) -> dict[str, list[Document]] | dict[str, Any]: @@ -310,7 +312,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document] embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self._batch_size) + new_documents = [] for doc, emb in zip(documents, embeddings, strict=True): - doc.embedding = emb + new_documents.append(replace(doc, embedding=emb)) - return {"documents": documents, "meta": meta} + return {"documents": new_documents, "meta": meta} diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 2579801bc6..d2f1af4985 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -201,6 +201,61 @@ def test_run_on_empty_list(self): assert result["documents"] is not None assert not result["documents"] # empty list + def test_run_does_not_modify_original_documents(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleGenAIDocumentEmbedder() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + # Mock the _embed_batch method to return fake embeddings + def mock_embed_batch(texts_to_embed, batch_size): + embeddings = [[0.1, 0.2, 0.3] for _ in texts_to_embed] + meta = {"model": "text-embedding-004"} + return embeddings, meta + + embedder._embed_batch = mock_embed_batch + + result = embedder.run(documents=docs) + + # Check that the original documents are not modified + for doc in docs: + assert doc.embedding is None + + # Check that the returned documents have embeddings + for doc_with_embedding in result["documents"]: + assert doc_with_embedding.embedding == [0.1, 0.2, 0.3] + + @pytest.mark.asyncio + async def test_run_async_does_not_modify_original_documents(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleGenAIDocumentEmbedder() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + # Mock the _embed_batch_async method to return fake embeddings + async def mock_embed_batch_async(texts_to_embed, batch_size): + embeddings = [[0.1, 0.2, 0.3] for _ in texts_to_embed] + meta = {"model": "text-embedding-004"} + return embeddings, meta + + embedder._embed_batch_async = mock_embed_batch_async + + result = await embedder.run_async(documents=docs) + + # Check that the original documents are not modified + for doc in docs: + assert doc.embedding is None + + # Check that the returned documents have embeddings + for doc_with_embedding in result["documents"]: + assert doc_with_embedding.embedding == [0.1, 0.2, 0.3] + @pytest.mark.skipif( not os.environ.get("GOOGLE_API_KEY", None), reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",