Skip to content

Commit ed6f209

Browse files
fix(google_genai): use dataclass replace to avoid modifying input documents (#2762)
This PR fixes the Google GenAI document embedder so that it no longer modifies input Documents in place when setting embeddings. Instead of mutating the original documents with `doc.embedding = embeddings`, we now create new document instances using dataclass `replace`: `replace(doc, embedding=embeddings)`. This follows the established pattern from haystack-ai/haystack#9693 and aligns with other integrations (FastEmbed, Optimum, Nvidia, Bedrock, Cohere).

Related to: #2174

Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent 65672d2 commit ed6f209

2 files changed

Lines changed: 62 additions & 4 deletions

File tree

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dataclasses import replace
56
from typing import Any, Literal
67

78
from google.genai import types
@@ -281,10 +282,11 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]] | dict[str
281282
meta: dict[str, Any]
282283
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size)
283284

285+
new_documents = []
284286
for doc, emb in zip(documents, embeddings, strict=True):
285-
doc.embedding = emb
287+
new_documents.append(replace(doc, embedding=emb))
286288

287-
return {"documents": documents, "meta": meta}
289+
return {"documents": new_documents, "meta": meta}
288290

289291
@component.output_types(documents=list[Document], meta=dict[str, Any])
290292
async def run_async(self, documents: list[Document]) -> dict[str, list[Document]] | dict[str, Any]:
@@ -310,7 +312,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document]
310312

311313
embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self._batch_size)
312314

315+
new_documents = []
313316
for doc, emb in zip(documents, embeddings, strict=True):
314-
doc.embedding = emb
317+
new_documents.append(replace(doc, embedding=emb))
315318

316-
return {"documents": documents, "meta": meta}
319+
return {"documents": new_documents, "meta": meta}

integrations/google_genai/tests/test_document_embedder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,61 @@ def test_run_on_empty_list(self):
201201
assert result["documents"] is not None
202202
assert not result["documents"] # empty list
203203

204+
def test_run_does_not_modify_original_documents(self, monkeypatch):
205+
monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")
206+
embedder = GoogleGenAIDocumentEmbedder()
207+
208+
docs = [
209+
Document(content="I love cheese", meta={"topic": "Cuisine"}),
210+
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
211+
]
212+
213+
# Mock the _embed_batch method to return fake embeddings
214+
def mock_embed_batch(texts_to_embed, batch_size):
215+
embeddings = [[0.1, 0.2, 0.3] for _ in texts_to_embed]
216+
meta = {"model": "text-embedding-004"}
217+
return embeddings, meta
218+
219+
embedder._embed_batch = mock_embed_batch
220+
221+
result = embedder.run(documents=docs)
222+
223+
# Check that the original documents are not modified
224+
for doc in docs:
225+
assert doc.embedding is None
226+
227+
# Check that the returned documents have embeddings
228+
for doc_with_embedding in result["documents"]:
229+
assert doc_with_embedding.embedding == [0.1, 0.2, 0.3]
230+
231+
@pytest.mark.asyncio
232+
async def test_run_async_does_not_modify_original_documents(self, monkeypatch):
233+
monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")
234+
embedder = GoogleGenAIDocumentEmbedder()
235+
236+
docs = [
237+
Document(content="I love cheese", meta={"topic": "Cuisine"}),
238+
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
239+
]
240+
241+
# Mock the _embed_batch_async method to return fake embeddings
242+
async def mock_embed_batch_async(texts_to_embed, batch_size):
243+
embeddings = [[0.1, 0.2, 0.3] for _ in texts_to_embed]
244+
meta = {"model": "text-embedding-004"}
245+
return embeddings, meta
246+
247+
embedder._embed_batch_async = mock_embed_batch_async
248+
249+
result = await embedder.run_async(documents=docs)
250+
251+
# Check that the original documents are not modified
252+
for doc in docs:
253+
assert doc.embedding is None
254+
255+
# Check that the returned documents have embeddings
256+
for doc_with_embedding in result["documents"]:
257+
assert doc_with_embedding.embedding == [0.1, 0.2, 0.3]
258+
204259
@pytest.mark.skipif(
205260
not os.environ.get("GOOGLE_API_KEY", None),
206261
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",

0 commit comments

Comments
 (0)