Skip to content

Commit ceee0d5

Browse files
authored
refactor: Google GenAI embedders - adaptations for Gemini Embedding 2 general availability (#3251)
1 parent 78b139e commit ceee0d5

6 files changed

Lines changed: 66 additions & 58 deletions

File tree

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def __init__(
102102
The name of the model to use for calculating embeddings.
103103
The default model is `gemini-embedding-001`.
104104
:param prefix:
105-
A string to add at the beginning of each text.
105+
A string to add at the beginning of each text. It can be used to specify a task type for
106+
`gemini-embedding-2`. For available task types, see
107+
[Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
106108
:param suffix:
107109
A string to add at the end of each text.
108110
:param batch_size:
@@ -114,9 +116,12 @@ def __init__(
114116
:param embedding_separator:
115117
Separator used to concatenate the metadata fields to the document text.
116118
:param config:
117-
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
118-
If not specified, it defaults to `{"task_type": "SEMANTIC_SIMILARITY"}`.
119-
For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
119+
A dictionary of keyword arguments to configure embedding content configuration.
120+
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
121+
for the available options.
122+
A task type set in `config` has no effect with `gemini-embedding-2`; for that model, specify the task type via `prefix` instead.
123+
See [Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types) for more
124+
information.
120125
"""
121126
self._api_key = api_key
122127
self._api = api
@@ -129,7 +134,7 @@ def __init__(
129134
self._progress_bar = progress_bar
130135
self._meta_fields_to_embed = meta_fields_to_embed or []
131136
self._embedding_separator = embedding_separator
132-
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
137+
self._config = config
133138

134139
self._client = _get_client(
135140
api_key=api_key,
@@ -207,7 +212,8 @@ def _embed_batch(
207212
range(0, len(texts_to_embed), batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
208213
):
209214
batch = texts_to_embed[i : i + batch_size]
210-
args: dict[str, Any] = {"model": self._model, "contents": batch}
215+
contents = [types.Content(parts=[types.Part.from_text(text=t)]) for t in batch]
216+
args: dict[str, Any] = {"model": self._model, "contents": contents}
211217
if resolved_config:
212218
args["config"] = resolved_config
213219

@@ -239,7 +245,8 @@ async def _embed_batch_async(
239245
range(0, len(texts_to_embed), batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
240246
):
241247
batch = texts_to_embed[i : i + batch_size]
242-
args: dict[str, Any] = {"model": self._model, "contents": batch}
248+
contents = [types.Content(parts=[types.Part.from_text(text=t)]) for t in batch]
249+
args: dict[str, Any] = {"model": self._model, "contents": contents}
243250
if self._config:
244251
args["config"] = types.EmbedContentConfig(**self._config) if self._config else None
245252

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/multimodal_document_embedder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def __init__(
177177
file_path_meta_field: str = "file_path",
178178
root_path: str | None = None,
179179
image_size: tuple[int, int] | None = None,
180-
model: str = "gemini-embedding-2-preview",
180+
model: str = "gemini-embedding-2",
181181
batch_size: int = 6,
182182
progress_bar: bool = True,
183183
config: dict[str, Any] | None = None,
@@ -213,11 +213,10 @@ def __init__(
213213
:param progress_bar:
214214
If `True`, shows a progress bar when running.
215215
:param config:
216-
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
216+
A dictionary of keyword arguments to configure embedding content configuration.
217217
For example, you can set the output dimensionality of the embedding: `{"output_dimensionality": 768}`.
218-
It also allows customizing the task type. If the task type is not specified, it defaults to
219-
`{"task_type": "RETRIEVAL_DOCUMENT"}`.
220-
For more information, see the [Google AI documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
218+
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
219+
for the available options.
221220
"""
222221
self._api_key = api_key
223222
self._api = api
@@ -229,9 +228,6 @@ def __init__(
229228
self._image_size = image_size
230229
self._batch_size = batch_size
231230
self._progress_bar = progress_bar
232-
233-
config = config or {}
234-
config.setdefault("task_type", "RETRIEVAL_DOCUMENT")
235231
self._config = config
236232

237233
self._client = _get_client(

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,18 @@ def __init__(
9999
The name of the model to use for calculating embeddings.
100100
The default model is `gemini-embedding-001`.
101101
:param prefix:
102-
A string to add at the beginning of each text to embed.
102+
A string to add at the beginning of each text to embed. It can be used to specify a task type for
103+
`gemini-embedding-2`. For available task types, see
104+
[Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
103105
:param suffix:
104106
A string to add at the end of each text to embed.
105107
:param config:
106-
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
107-
If not specified, it defaults to `{"task_type": "SEMANTIC_SIMILARITY"}`.
108-
For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
108+
A dictionary of keyword arguments to configure embedding content configuration.
109+
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
110+
for the available options.
111+
A task type set in `config` has no effect with `gemini-embedding-2`; for that model, specify the task type via `prefix` instead.
112+
See [Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types) for more
113+
information.
109114
"""
110115

111116
self._api_key = api_key
@@ -115,7 +120,7 @@ def __init__(
115120
self._model_name = model
116121
self._prefix = prefix
117122
self._suffix = suffix
118-
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
123+
self._config = config
119124
self._client = _get_client(
120125
api_key=api_key,
121126
api=api,

integrations/google_genai/tests/test_document_embedder.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_init_default(self, monkeypatch):
3939
assert embedder._progress_bar is True
4040
assert embedder._meta_fields_to_embed == []
4141
assert embedder._embedding_separator == "\n"
42-
assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}
42+
assert embedder._config is None
4343

4444
def test_init_with_parameters(self, monkeypatch):
4545
embedder = GoogleGenAIDocumentEmbedder(
@@ -86,7 +86,7 @@ def test_to_dict(self, monkeypatch):
8686
"meta_fields_to_embed": [],
8787
"embedding_separator": "\n",
8888
"api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY", "GEMINI_API_KEY"], "strict": False},
89-
"config": {"task_type": "SEMANTIC_SIMILARITY"},
89+
"config": None,
9090
"api": "gemini",
9191
"vertex_ai_project": None,
9292
"vertex_ai_location": None,
@@ -277,8 +277,11 @@ def test_embed_batch_passes_full_texts(self, monkeypatch):
277277

278278
calls = embedder._client.models.embed_content.call_args_list
279279
assert len(calls) == 2
280-
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
281-
assert calls[1].kwargs["contents"] == ["third document text"]
280+
assert [c.parts[0].text for c in calls[0].kwargs["contents"]] == [
281+
"first document text",
282+
"second document text",
283+
]
284+
assert [c.parts[0].text for c in calls[1].kwargs["contents"]] == ["third document text"]
282285

283286
@pytest.mark.asyncio
284287
async def test_embed_batch_async_passes_full_texts(self, monkeypatch):
@@ -300,24 +303,32 @@ async def test_embed_batch_async_passes_full_texts(self, monkeypatch):
300303

301304
calls = embedder._client.aio.models.embed_content.call_args_list
302305
assert len(calls) == 2
303-
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
304-
assert calls[1].kwargs["contents"] == ["third document text"]
306+
assert [c.parts[0].text for c in calls[0].kwargs["contents"]] == [
307+
"first document text",
308+
"second document text",
309+
]
310+
assert [c.parts[0].text for c in calls[1].kwargs["contents"]] == ["third document text"]
305311

306312
@pytest.mark.skipif(
307313
not os.environ.get("GOOGLE_API_KEY", None),
308314
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
309315
)
310316
@pytest.mark.integration
311-
def test_run(self):
312-
model = "gemini-embedding-001"
313-
317+
@pytest.mark.parametrize(
318+
"model,doc_config,query_config",
319+
[
320+
("gemini-embedding-001", {"task_type": "RETRIEVAL_DOCUMENT"}, {"task_type": "RETRIEVAL_QUERY"}),
321+
("gemini-embedding-2", None, None),
322+
],
323+
)
324+
def test_run(self, model, doc_config, query_config):
314325
docs = [
315326
Document(content="The capybara is the largest rodent in the world and lives near rivers in South America."),
316327
Document(content="Dogs are domesticated mammals known for their loyalty and bond with humans."),
317328
Document(content="The tiger is the largest big cat, recognized by its orange coat with black stripes."),
318329
]
319330

320-
embedder = GoogleGenAIDocumentEmbedder(model=model, config={"task_type": "RETRIEVAL_DOCUMENT"})
331+
embedder = GoogleGenAIDocumentEmbedder(model=model, config=doc_config)
321332

322333
result = embedder.run(documents=docs)
323334
documents_with_embeddings = result["documents"]
@@ -331,7 +342,7 @@ def test_run(self):
331342

332343
assert result["meta"]["model"] == model
333344

334-
text_embedder = GoogleGenAITextEmbedder(model=model, config={"task_type": "RETRIEVAL_QUERY"})
345+
text_embedder = GoogleGenAITextEmbedder(model=model, config=query_config)
335346
query_embedding = text_embedder.run("capybara")["embedding"]
336347
query_vec = np.array(query_embedding)
337348

@@ -349,14 +360,13 @@ def test_run(self):
349360
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
350361
)
351362
@pytest.mark.integration
352-
async def test_run_async(self):
363+
@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
364+
async def test_run_async(self, model):
353365
docs = [
354366
Document(content="I love cheese", meta={"topic": "Cuisine"}),
355367
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
356368
]
357369

358-
model = "gemini-embedding-001"
359-
360370
embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")
361371

362372
result = await embedder.run_async(documents=docs)

integrations/google_genai/tests/test_multimodal_document_embedder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_to_dict(self, monkeypatch):
103103
"haystack_integrations.components.embedders.google_genai.multimodal_document_embedder.GoogleGenAIMultimodalDocumentEmbedder"
104104
),
105105
"init_parameters": {
106-
"model": "gemini-embedding-2-preview",
106+
"model": "gemini-embedding-2",
107107
"file_path_meta_field": "file_path",
108108
"root_path": None,
109109
"image_size": None,
@@ -123,7 +123,7 @@ def test_from_dict(self, monkeypatch):
123123
"haystack_integrations.components.embedders.google_genai.multimodal_document_embedder.GoogleGenAIMultimodalDocumentEmbedder"
124124
),
125125
"init_parameters": {
126-
"model": "gemini-embedding-2-preview",
126+
"model": "gemini-embedding-2",
127127
"file_path_meta_field": "file_path",
128128
"root_path": "some_root_path",
129129
"image_size": (1024, 1024),
@@ -140,7 +140,7 @@ def test_from_dict(self, monkeypatch):
140140

141141
embedder = component_from_dict(GoogleGenAIMultimodalDocumentEmbedder, data, "embedder")
142142
assert embedder._api_key.resolve_value() == "fake-api-key"
143-
assert embedder._model == "gemini-embedding-2-preview"
143+
assert embedder._model == "gemini-embedding-2"
144144
assert embedder._file_path_meta_field == "file_path"
145145
assert embedder._root_path == "some_root_path"
146146
assert embedder._image_size == (1024, 1024)
@@ -235,7 +235,7 @@ def test_run_with_mocked_client(self, test_files_path):
235235
assert len(result["documents"]) == 2
236236
for doc in result["documents"]:
237237
assert doc.embedding == [0.1, 0.2, 0.3]
238-
assert result["meta"]["model"] == "gemini-embedding-2-preview"
238+
assert result["meta"]["model"] == "gemini-embedding-2"
239239

240240
@pytest.mark.asyncio
241241
async def test_run_async_with_mocked_client(self, test_files_path):
@@ -253,7 +253,7 @@ async def test_run_async_with_mocked_client(self, test_files_path):
253253
result = await embedder.run_async(documents=docs)
254254
assert len(result["documents"]) == 1
255255
assert result["documents"][0].embedding == [0.4, 0.5, 0.6]
256-
assert result["meta"]["model"] == "gemini-embedding-2-preview"
256+
assert result["meta"]["model"] == "gemini-embedding-2"
257257

258258
@pytest.mark.integration
259259
@pytest.mark.skipif(
@@ -271,7 +271,7 @@ def test_live_run(self, test_files_path):
271271
assert len(result["documents"]) == 2
272272
for doc in result["documents"]:
273273
assert len(doc.embedding) == 3072
274-
assert result["meta"]["model"] == "gemini-embedding-2-preview"
274+
assert result["meta"]["model"] == "gemini-embedding-2"
275275

276276
@pytest.mark.integration
277277
@pytest.mark.asyncio
@@ -290,4 +290,4 @@ async def test_live_run_async(self, test_files_path):
290290
assert len(result["documents"]) == 2
291291
for doc in result["documents"]:
292292
assert len(doc.embedding) == 3072
293-
assert result["meta"]["model"] == "gemini-embedding-2-preview"
293+
assert result["meta"]["model"] == "gemini-embedding-2"

integrations/google_genai/tests/test_text_embedder.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66

77
import pytest
8-
from google.genai.types import ContentEmbedding, EmbedContentConfig, EmbedContentResponse
8+
from google.genai.types import ContentEmbedding, EmbedContentResponse
99
from haystack.utils.auth import Secret
1010

1111
from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder
@@ -20,7 +20,7 @@ def test_init_default(self, monkeypatch):
2020
assert embedder._model_name == "gemini-embedding-001"
2121
assert embedder._prefix == ""
2222
assert embedder._suffix == ""
23-
assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}
23+
assert embedder._config is None
2424
assert embedder._api == "gemini"
2525
assert embedder._vertex_ai_project is None
2626
assert embedder._vertex_ai_location is None
@@ -50,7 +50,7 @@ def test_to_dict(self, monkeypatch):
5050
"model": "gemini-embedding-001",
5151
"prefix": "",
5252
"suffix": "",
53-
"config": {"task_type": "SEMANTIC_SIMILARITY"},
53+
"config": None,
5454
"api": "gemini",
5555
"vertex_ai_project": None,
5656
"vertex_ai_location": None,
@@ -112,14 +112,6 @@ def test_prepare_input(self, monkeypatch):
112112
assert prepared_input == {
113113
"model": "gemini-embedding-001",
114114
"contents": "The food was delicious",
115-
"config": EmbedContentConfig(
116-
http_options=None,
117-
task_type="SEMANTIC_SIMILARITY",
118-
title=None,
119-
output_dimensionality=None,
120-
mime_type=None,
121-
auto_truncate=None,
122-
),
123115
}
124116

125117
def test_prepare_output(self, monkeypatch):
@@ -149,9 +141,8 @@ def test_run_wrong_input_format(self):
149141
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
150142
)
151143
@pytest.mark.integration
152-
def test_run(self):
153-
model = "gemini-embedding-001"
154-
144+
@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
145+
def test_run(self, model):
155146
embedder = GoogleGenAITextEmbedder(model=model)
156147
result = embedder.run(text="The food was delicious")
157148

@@ -166,9 +157,8 @@ def test_run(self):
166157
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
167158
)
168159
@pytest.mark.integration
169-
async def test_run_async(self):
170-
model = "gemini-embedding-001"
171-
160+
@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
161+
async def test_run_async(self, model):
172162
embedder = GoogleGenAITextEmbedder(model=model)
173163
result = await embedder.run_async(text="The food was delicious")
174164

0 commit comments

Comments
 (0)