Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def __init__(
The name of the model to use for calculating embeddings.
The default model is `gemini-embedding-001`.
:param prefix:
A string to add at the beginning of each text.
A string to add at the beginning of each text. It can be used to specify a task type for
`gemini-embedding-2`. For available task types, see
[Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
:param suffix:
A string to add at the end of each text.
:param batch_size:
Expand All @@ -114,9 +116,12 @@ def __init__(
:param embedding_separator:
Separator used to concatenate the metadata fields to the document text.
:param config:
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
If not specified, it defaults to `{"task_type": "SEMANTIC_SIMILARITY"}`.
For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
A dictionary of keyword arguments to configure embedding content configuration.
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
for the available options.
Setting a task type in `config` has no effect for `gemini-embedding-2`; for that model, specify the task
type through `prefix` instead.
See the [Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types) for more
information.
"""
self._api_key = api_key
self._api = api
Expand All @@ -129,7 +134,7 @@ def __init__(
self._progress_bar = progress_bar
self._meta_fields_to_embed = meta_fields_to_embed or []
self._embedding_separator = embedding_separator
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
self._config = config

self._client = _get_client(
api_key=api_key,
Expand Down Expand Up @@ -207,7 +212,8 @@ def _embed_batch(
range(0, len(texts_to_embed), batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
args: dict[str, Any] = {"model": self._model, "contents": batch}
contents = [types.Content(parts=[types.Part.from_text(text=t)]) for t in batch]
args: dict[str, Any] = {"model": self._model, "contents": contents}
if resolved_config:
args["config"] = resolved_config

Expand Down Expand Up @@ -239,7 +245,8 @@ async def _embed_batch_async(
range(0, len(texts_to_embed), batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
args: dict[str, Any] = {"model": self._model, "contents": batch}
contents = [types.Content(parts=[types.Part.from_text(text=t)]) for t in batch]
args: dict[str, Any] = {"model": self._model, "contents": contents}
if self._config:
args["config"] = types.EmbedContentConfig(**self._config) if self._config else None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
file_path_meta_field: str = "file_path",
root_path: str | None = None,
image_size: tuple[int, int] | None = None,
model: str = "gemini-embedding-2-preview",
model: str = "gemini-embedding-2",
batch_size: int = 6,
progress_bar: bool = True,
config: dict[str, Any] | None = None,
Expand Down Expand Up @@ -213,11 +213,10 @@ def __init__(
:param progress_bar:
If `True`, shows a progress bar when running.
:param config:
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
A dictionary of keyword arguments to configure embedding content configuration.
You can, for example, set the output dimensionality of the embedding: `{"output_dimensionality": 768}`.
It also allows customizing the task type. If the task type is not specified, it defaults to
`{"task_type": "RETRIEVAL_DOCUMENT"}`.
For more information, see the [Google AI documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
for the available options.
"""
self._api_key = api_key
self._api = api
Expand All @@ -229,9 +228,6 @@ def __init__(
self._image_size = image_size
self._batch_size = batch_size
self._progress_bar = progress_bar

config = config or {}
config.setdefault("task_type", "RETRIEVAL_DOCUMENT")
self._config = config

self._client = _get_client(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,18 @@ def __init__(
The name of the model to use for calculating embeddings.
The default model is `gemini-embedding-001`.
:param prefix:
A string to add at the beginning of each text to embed.
A string to add at the beginning of each text. It can be used to specify a task type for
`gemini-embedding-2`. For available task types, see
[Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
:param suffix:
A string to add at the end of each text to embed.
:param config:
A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`.
If not specified, it defaults to `{"task_type": "SEMANTIC_SIMILARITY"}`.
For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types).
A dictionary of keyword arguments to configure embedding content configuration.
See [Google API documentation](https://googleapis.github.io/python-genai/genai.html#genai.types.EmbedContentConfig)
for the available options.
Setting a task type in `config` has no effect for `gemini-embedding-2`; for that model, specify the task
type through `prefix` instead.
See the [Gemini documentation](https://ai.google.dev/gemini-api/docs/embeddings#task-types) for more
information.
"""

self._api_key = api_key
Expand All @@ -115,7 +120,7 @@ def __init__(
self._model_name = model
self._prefix = prefix
self._suffix = suffix
self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"}
self._config = config
self._client = _get_client(
api_key=api_key,
api=api,
Expand Down
38 changes: 24 additions & 14 deletions integrations/google_genai/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_init_default(self, monkeypatch):
assert embedder._progress_bar is True
assert embedder._meta_fields_to_embed == []
assert embedder._embedding_separator == "\n"
assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}
assert embedder._config is None

def test_init_with_parameters(self, monkeypatch):
embedder = GoogleGenAIDocumentEmbedder(
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_to_dict(self, monkeypatch):
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY", "GEMINI_API_KEY"], "strict": False},
"config": {"task_type": "SEMANTIC_SIMILARITY"},
"config": None,
"api": "gemini",
"vertex_ai_project": None,
"vertex_ai_location": None,
Expand Down Expand Up @@ -277,8 +277,11 @@ def test_embed_batch_passes_full_texts(self, monkeypatch):

calls = embedder._client.models.embed_content.call_args_list
assert len(calls) == 2
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
assert calls[1].kwargs["contents"] == ["third document text"]
assert [c.parts[0].text for c in calls[0].kwargs["contents"]] == [
"first document text",
"second document text",
]
assert [c.parts[0].text for c in calls[1].kwargs["contents"]] == ["third document text"]

@pytest.mark.asyncio
async def test_embed_batch_async_passes_full_texts(self, monkeypatch):
Expand All @@ -300,24 +303,32 @@ async def test_embed_batch_async_passes_full_texts(self, monkeypatch):

calls = embedder._client.aio.models.embed_content.call_args_list
assert len(calls) == 2
assert calls[0].kwargs["contents"] == ["first document text", "second document text"]
assert calls[1].kwargs["contents"] == ["third document text"]
assert [c.parts[0].text for c in calls[0].kwargs["contents"]] == [
"first document text",
"second document text",
]
assert [c.parts[0].text for c in calls[1].kwargs["contents"]] == ["third document text"]

@pytest.mark.skipif(
not os.environ.get("GOOGLE_API_KEY", None),
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
)
@pytest.mark.integration
def test_run(self):
model = "gemini-embedding-001"

@pytest.mark.parametrize(
"model,doc_config,query_config",
[
("gemini-embedding-001", {"task_type": "RETRIEVAL_DOCUMENT"}, {"task_type": "RETRIEVAL_QUERY"}),
("gemini-embedding-2", None, None),
],
)
def test_run(self, model, doc_config, query_config):
docs = [
Document(content="The capybara is the largest rodent in the world and lives near rivers in South America."),
Document(content="Dogs are domesticated mammals known for their loyalty and bond with humans."),
Document(content="The tiger is the largest big cat, recognized by its orange coat with black stripes."),
]

embedder = GoogleGenAIDocumentEmbedder(model=model, config={"task_type": "RETRIEVAL_DOCUMENT"})
embedder = GoogleGenAIDocumentEmbedder(model=model, config=doc_config)

result = embedder.run(documents=docs)
documents_with_embeddings = result["documents"]
Expand All @@ -331,7 +342,7 @@ def test_run(self):

assert result["meta"]["model"] == model

text_embedder = GoogleGenAITextEmbedder(model=model, config={"task_type": "RETRIEVAL_QUERY"})
text_embedder = GoogleGenAITextEmbedder(model=model, config=query_config)
query_embedding = text_embedder.run("capybara")["embedding"]
query_vec = np.array(query_embedding)

Expand All @@ -349,14 +360,13 @@ def test_run(self):
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
)
@pytest.mark.integration
async def test_run_async(self):
@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
async def test_run_async(self, model):
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]

model = "gemini-embedding-001"

embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")

result = await embedder.run_async(documents=docs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_to_dict(self, monkeypatch):
"haystack_integrations.components.embedders.google_genai.multimodal_document_embedder.GoogleGenAIMultimodalDocumentEmbedder"
),
"init_parameters": {
"model": "gemini-embedding-2-preview",
"model": "gemini-embedding-2",
"file_path_meta_field": "file_path",
"root_path": None,
"image_size": None,
Expand All @@ -123,7 +123,7 @@ def test_from_dict(self, monkeypatch):
"haystack_integrations.components.embedders.google_genai.multimodal_document_embedder.GoogleGenAIMultimodalDocumentEmbedder"
),
"init_parameters": {
"model": "gemini-embedding-2-preview",
"model": "gemini-embedding-2",
"file_path_meta_field": "file_path",
"root_path": "some_root_path",
"image_size": (1024, 1024),
Expand All @@ -140,7 +140,7 @@ def test_from_dict(self, monkeypatch):

embedder = component_from_dict(GoogleGenAIMultimodalDocumentEmbedder, data, "embedder")
assert embedder._api_key.resolve_value() == "fake-api-key"
assert embedder._model == "gemini-embedding-2-preview"
assert embedder._model == "gemini-embedding-2"
assert embedder._file_path_meta_field == "file_path"
assert embedder._root_path == "some_root_path"
assert embedder._image_size == (1024, 1024)
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_run_with_mocked_client(self, test_files_path):
assert len(result["documents"]) == 2
for doc in result["documents"]:
assert doc.embedding == [0.1, 0.2, 0.3]
assert result["meta"]["model"] == "gemini-embedding-2-preview"
assert result["meta"]["model"] == "gemini-embedding-2"

@pytest.mark.asyncio
async def test_run_async_with_mocked_client(self, test_files_path):
Expand All @@ -253,7 +253,7 @@ async def test_run_async_with_mocked_client(self, test_files_path):
result = await embedder.run_async(documents=docs)
assert len(result["documents"]) == 1
assert result["documents"][0].embedding == [0.4, 0.5, 0.6]
assert result["meta"]["model"] == "gemini-embedding-2-preview"
assert result["meta"]["model"] == "gemini-embedding-2"

@pytest.mark.integration
@pytest.mark.skipif(
Expand All @@ -271,7 +271,7 @@ def test_live_run(self, test_files_path):
assert len(result["documents"]) == 2
for doc in result["documents"]:
assert len(doc.embedding) == 3072
assert result["meta"]["model"] == "gemini-embedding-2-preview"
assert result["meta"]["model"] == "gemini-embedding-2"

@pytest.mark.integration
@pytest.mark.asyncio
Expand All @@ -290,4 +290,4 @@ async def test_live_run_async(self, test_files_path):
assert len(result["documents"]) == 2
for doc in result["documents"]:
assert len(doc.embedding) == 3072
assert result["meta"]["model"] == "gemini-embedding-2-preview"
assert result["meta"]["model"] == "gemini-embedding-2"
24 changes: 7 additions & 17 deletions integrations/google_genai/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os

import pytest
from google.genai.types import ContentEmbedding, EmbedContentConfig, EmbedContentResponse
from google.genai.types import ContentEmbedding, EmbedContentResponse
from haystack.utils.auth import Secret

from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder
Expand All @@ -20,7 +20,7 @@ def test_init_default(self, monkeypatch):
assert embedder._model_name == "gemini-embedding-001"
assert embedder._prefix == ""
assert embedder._suffix == ""
assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"}
assert embedder._config is None
assert embedder._api == "gemini"
assert embedder._vertex_ai_project is None
assert embedder._vertex_ai_location is None
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_to_dict(self, monkeypatch):
"model": "gemini-embedding-001",
"prefix": "",
"suffix": "",
"config": {"task_type": "SEMANTIC_SIMILARITY"},
"config": None,
"api": "gemini",
"vertex_ai_project": None,
"vertex_ai_location": None,
Expand Down Expand Up @@ -112,14 +112,6 @@ def test_prepare_input(self, monkeypatch):
assert prepared_input == {
"model": "gemini-embedding-001",
"contents": "The food was delicious",
"config": EmbedContentConfig(
http_options=None,
task_type="SEMANTIC_SIMILARITY",
title=None,
output_dimensionality=None,
mime_type=None,
auto_truncate=None,
),
}

def test_prepare_output(self, monkeypatch):
Expand Down Expand Up @@ -149,9 +141,8 @@ def test_run_wrong_input_format(self):
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
)
@pytest.mark.integration
def test_run(self):
model = "gemini-embedding-001"

@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
def test_run(self, model):
embedder = GoogleGenAITextEmbedder(model=model)
result = embedder.run(text="The food was delicious")

Expand All @@ -166,9 +157,8 @@ def test_run(self):
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
)
@pytest.mark.integration
async def test_run_async(self):
model = "gemini-embedding-001"

@pytest.mark.parametrize("model", ["gemini-embedding-001", "gemini-embedding-2"])
async def test_run_async(self, model):
embedder = GoogleGenAITextEmbedder(model=model)
result = await embedder.run_async(text="The food was delicious")

Expand Down
Loading