Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Literal

from google.genai import Client
from google.genai import Client, types
from haystack import logging
from haystack.utils import Secret

Expand All @@ -16,6 +16,8 @@ def _get_client(
api: Literal["gemini", "vertex"],
vertex_ai_project: str | None,
vertex_ai_location: str | None,
timeout: float | None = None,
max_retries: int | None = None,
) -> Client:
"""
Internal utility function to get a Google GenAI client.
Expand All @@ -31,6 +33,8 @@ def _get_client(
Application Default Credentials.
:param vertex_ai_location: Google Cloud location for Vertex AI (e.g., "us-central1", "europe-west1"). Required
when using Vertex AI with Application Default Credentials.
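:param timeout: Timeout for Google GenAI client calls, in seconds. It is converted to milliseconds before being
passed to the client's HTTP options.
:param max_retries: Maximum number of retries to attempt for failed requests. It is passed to the client as the
`attempts` value of its HTTP retry options.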

:returns: A Google GenAI client.

Expand All @@ -43,6 +47,18 @@ def _get_client(
raise ValueError(msg)

resolved_api_key = api_key.resolve_value()
timeout_ms: int | None = None
retry_options: types.HttpRetryOptions | None = None
http_options: types.HttpOptions | None = None

if timeout is not None:
timeout_ms = int(timeout * 1000)

if max_retries is not None:
retry_options = types.HttpRetryOptions(attempts=max_retries)

if timeout_ms is not None or retry_options is not None:
http_options = types.HttpOptions(timeout=timeout_ms, retry_options=retry_options)

if api == "vertex":
if not resolved_api_key and not (vertex_ai_project and vertex_ai_location):
Expand All @@ -54,16 +70,21 @@ def _get_client(

if vertex_ai_project and vertex_ai_location:
logger.info("Using vertex_ai_project and vertex_ai_location for authentication.")
return Client(vertexai=True, project=vertex_ai_project, location=vertex_ai_location)
return Client(
vertexai=True,
project=vertex_ai_project,
location=vertex_ai_location,
http_options=http_options,
)

logger.info(
"No vertex_ai_project or vertex_ai_location provided for Vertex AI. Using the API key for authentication."
)
return Client(vertexai=True, api_key=resolved_api_key)
return Client(vertexai=True, api_key=resolved_api_key, http_options=http_options)

# Gemini API
if not resolved_api_key:
msg = "To use Gemini API, you must export the GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
raise ValueError(msg)

return Client(api_key=resolved_api_key)
return Client(api_key=resolved_api_key, http_options=http_options)
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def __init__(
safety_settings: list[dict[str, Any]] | None = None,
streaming_callback: StreamingCallbackT | None = None,
tools: ToolsType | None = None,
timeout: float | None = None,
max_retries: int | None = None,
):
"""
Initialize a GoogleGenAIChatGenerator instance.
Expand Down Expand Up @@ -197,6 +199,12 @@ def __init__(
:param streaming_callback: A callback function that is called when a new token is received from the stream.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
:param timeout:
Timeout, in seconds, for Google GenAI client calls. If not set, it defaults to the default set by the Google
GenAI client.
:param max_retries:
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
the Google GenAI client.
"""
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))

Expand All @@ -205,6 +213,8 @@ def __init__(
api=api,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
timeout=timeout,
max_retries=max_retries,
)

self._api_key = api_key
Expand All @@ -216,6 +226,8 @@ def __init__(
self._safety_settings = safety_settings or []
self._streaming_callback = streaming_callback
self._tools = tools
self._timeout = timeout
self._max_retries = max_retries

def to_dict(self) -> dict[str, Any]:
"""
Expand All @@ -236,6 +248,8 @@ def to_dict(self) -> dict[str, Any]:
safety_settings=self._safety_settings,
streaming_callback=callback_name,
tools=serialized_tools,
timeout=self._timeout,
max_retries=self._max_retries,
)

@classmethod
Expand Down
8 changes: 8 additions & 0 deletions integrations/google_genai/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def test_init_default(self, monkeypatch):
assert component._tools is None
assert component._api_key is not None
assert component._api_key.resolve_value() == "test-api-key"
assert component._timeout is None
assert component._max_retries is None

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
Expand All @@ -81,6 +83,8 @@ def test_init_with_parameters(self, monkeypatch):
generation_kwargs={"temperature": 0.5, "max_output_tokens": 100},
safety_settings=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}],
tools=[tool],
timeout=30.0,
max_retries=5,
)
assert component._model == "gemini-2.5-flash"
assert component._streaming_callback is print_streaming_chunk
Expand All @@ -89,6 +93,8 @@ def test_init_with_parameters(self, monkeypatch):
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
]
assert component._tools == [tool]
assert component._timeout == 30.0
assert component._max_retries == 5

def test_init_with_toolset(self, tools, monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
Expand Down Expand Up @@ -140,6 +146,8 @@ def test_to_dict_with_toolset(self, tools, monkeypatch):
assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"
assert "tools" in data["init_parameters"]["tools"]["data"]
assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools)
assert data["init_parameters"]["timeout"] is None
assert data["init_parameters"]["max_retries"] is None

def test_from_dict_with_toolset(self, tools, monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
Expand Down
33 changes: 30 additions & 3 deletions integrations/google_genai/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import patch

import pytest
from google.genai import types
from haystack.utils import Secret

from haystack_integrations.components.common.google_genai.utils import _get_client
Expand Down Expand Up @@ -28,7 +29,9 @@ def test_get_client_vertex_project_and_location(monkeypatch):
client = _get_client(
api_key=api_key, api="vertex", vertex_ai_project="test-project", vertex_ai_location="test-location"
)
mock_client.assert_called_once_with(vertexai=True, project="test-project", location="test-location")
mock_client.assert_called_once_with(
vertexai=True, project="test-project", location="test-location", http_options=None
)
assert client is not None


Expand All @@ -38,7 +41,7 @@ def test_get_client_vertex_api_key(monkeypatch):

with patch("haystack_integrations.components.common.google_genai.utils.Client") as mock_client:
client = _get_client(api_key=api_key, api="vertex", vertex_ai_project=None, vertex_ai_location=None)
mock_client.assert_called_once_with(vertexai=True, api_key="test-api-key")
mock_client.assert_called_once_with(vertexai=True, api_key="test-api-key", http_options=None)
assert client is not None


Expand All @@ -48,7 +51,7 @@ def test_get_client_gemini_api_key(monkeypatch):

with patch("haystack_integrations.components.common.google_genai.utils.Client") as mock_client:
client = _get_client(api_key=api_key, api="gemini", vertex_ai_project=None, vertex_ai_location=None)
mock_client.assert_called_once_with(api_key="test-api-key")
mock_client.assert_called_once_with(api_key="test-api-key", http_options=None)
assert client is not None


Expand All @@ -57,3 +60,27 @@ def test_get_client_gemini_api_key_no_env_var_raises(monkeypatch):
api_key = Secret.from_env_var("GEMINI_API_KEY", strict=False)
with pytest.raises(ValueError):
_get_client(api_key=api_key, api="gemini", vertex_ai_project=None, vertex_ai_location=None)


def test_get_client_forwards_timeout_and_max_retries(monkeypatch):
    """
    Check that `timeout` and `max_retries` reach the `Client` constructor as `HttpOptions`.

    A timeout of 30.0 must arrive as 30000 on the HTTP options, and a `max_retries` of 5 must arrive as the
    `attempts` value of the retry options.
    """
    monkeypatch.setenv("GEMINI_API_KEY", "test-api-key")
    secret = Secret.from_env_var("GEMINI_API_KEY", strict=False)
    client_target = "haystack_integrations.components.common.google_genai.utils.Client"

    with patch(client_target) as mock_client:
        client = _get_client(
            api_key=secret,
            api="gemini",
            vertex_ai_project=None,
            vertex_ai_location=None,
            timeout=30.0,
            max_retries=5,
        )

        # The constructor must be invoked exactly once; inspect the keyword arguments of that call.
        mock_client.assert_called_once()
        passed_kwargs = mock_client.call_args.kwargs
        assert passed_kwargs["api_key"] == "test-api-key"
        assert "http_options" in passed_kwargs

        http_options = passed_kwargs["http_options"]
        assert isinstance(http_options, types.HttpOptions)
        assert http_options.timeout == 30000

        retry_options = http_options.retry_options
        assert retry_options is not None
        assert retry_options.attempts == 5
        assert client is not None