Skip to content

Commit 85eeab0

Browse files
authored
fix: fix wrong batching in Google GenAI Document Embedder (#2951)
* fix: fix wrong batching in Google GenAI Document Embedder * unrelated: update supported models
1 parent 288b779 commit 85eeab0

3 files changed

Lines changed: 71 additions & 8 deletions

File tree

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _embed_batch(
207207
for batch in tqdm(
208208
batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
209209
):
210-
args: dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]}
210+
args: dict[str, Any] = {"model": self._model, "contents": list(batch)}
211211
if resolved_config:
212212
args["config"] = resolved_config
213213

@@ -238,7 +238,7 @@ async def _embed_batch_async(
238238
for batch in tqdm(
239239
batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
240240
):
241-
args: dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]}
241+
args: dict[str, Any] = {"model": self._model, "contents": list(batch)}
242242
if self._config:
243243
args["config"] = types.EmbedContentConfig(**self._config) if self._config else None
244244

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ def weather_function(city: str):
153153
"""
154154

155155
SUPPORTED_MODELS: ClassVar[list[str]] = [
156+
"gemini-3.1-pro-preview",
157+
"gemini-3-flash-preview",
158+
"gemini-3.1-flash-lite-preview",
156159
"gemini-2.5-pro",
157160
"gemini-2.5-flash",
158161
"gemini-2.5-flash-lite",

integrations/google_genai/tests/test_document_embedder.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
import os
66
import random
7+
from unittest.mock import AsyncMock, MagicMock
78

9+
import numpy as np
810
import pytest
911
from haystack import Document
1012
from haystack.utils.auth import Secret
1113

12-
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder
14+
from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder, GoogleGenAITextEmbedder
1315

1416

1517
def mock_google_response(contents: list[str], model: str = "gemini-embedding-001", **kwargs) -> dict:
@@ -256,20 +258,66 @@ async def mock_embed_batch_async(texts_to_embed, batch_size):
256258
for doc_with_embedding in result["documents"]:
257259
assert doc_with_embedding.embedding == [0.1, 0.2, 0.3]
258260

261+
def test_embed_batch_passes_full_texts(self, monkeypatch):
262+
monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")
263+
embedder = GoogleGenAIDocumentEmbedder(batch_size=2)
264+
265+
texts = ["first document text", "second document text", "third document text"]
266+
267+
mock_embedding = MagicMock()
268+
mock_embedding.values = [0.1, 0.2, 0.3]
269+
270+
mock_response = MagicMock()
271+
mock_response.embeddings = [mock_embedding]
272+
273+
embedder._client = MagicMock()
274+
embedder._client.models.embed_content.return_value = mock_response
275+
276+
embedder._embed_batch(texts, batch_size=2)
277+
278+
calls = embedder._client.models.embed_content.call_args_list
279+
assert len(calls) == 2
280+
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
281+
assert calls[1].kwargs["contents"] == ["third document text"]
282+
283+
@pytest.mark.asyncio
284+
async def test_embed_batch_async_passes_full_texts(self, monkeypatch):
285+
monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key")
286+
embedder = GoogleGenAIDocumentEmbedder(batch_size=2)
287+
288+
texts = ["first document text", "second document text", "third document text"]
289+
290+
mock_embedding = MagicMock()
291+
mock_embedding.values = [0.1, 0.2, 0.3]
292+
293+
mock_response = MagicMock()
294+
mock_response.embeddings = [mock_embedding]
295+
296+
embedder._client = MagicMock()
297+
embedder._client.aio.models.embed_content = AsyncMock(return_value=mock_response)
298+
299+
await embedder._embed_batch_async(texts, batch_size=2)
300+
301+
calls = embedder._client.aio.models.embed_content.call_args_list
302+
assert len(calls) == 2
303+
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
304+
assert calls[1].kwargs["contents"] == ["third document text"]
305+
259306
@pytest.mark.skipif(
260307
not os.environ.get("GOOGLE_API_KEY", None),
261308
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
262309
)
263310
@pytest.mark.integration
264311
def test_run(self):
312+
model = "gemini-embedding-001"
313+
265314
docs = [
266-
Document(content="I love cheese", meta={"topic": "Cuisine"}),
267-
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
315+
Document(content="The capybara is the largest rodent in the world and lives near rivers in South America."),
316+
Document(content="Dogs are domesticated mammals known for their loyalty and bond with humans."),
317+
Document(content="The tiger is the largest big cat, recognized by its orange coat with black stripes."),
268318
]
269319

270-
model = "gemini-embedding-001"
271-
272-
embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")
320+
embedder = GoogleGenAIDocumentEmbedder(model=model, config={"task_type": "RETRIEVAL_DOCUMENT"})
273321

274322
result = embedder.run(documents=docs)
275323
documents_with_embeddings = result["documents"]
@@ -283,6 +331,18 @@ def test_run(self):
283331

284332
assert result["meta"]["model"] == model
285333

334+
text_embedder = GoogleGenAITextEmbedder(model=model, config={"task_type": "RETRIEVAL_QUERY"})
335+
query_embedding = text_embedder.run("capybara")["embedding"]
336+
query_vec = np.array(query_embedding)
337+
338+
similarities = []
339+
for doc in documents_with_embeddings:
340+
doc_vec = np.array(doc.embedding)
341+
cosine_sim = np.dot(query_vec, doc_vec) / (np.linalg.norm(query_vec) * np.linalg.norm(doc_vec))
342+
similarities.append(cosine_sim)
343+
344+
assert similarities[0] == max(similarities)
345+
286346
@pytest.mark.asyncio
287347
@pytest.mark.skipif(
288348
not os.environ.get("GOOGLE_API_KEY", None),

0 commit comments

Comments
 (0)