Skip to content

Commit 60a8154

Browse files
committed
feat(google-genai): add timeout and max_retries to chat generator
1 parent 6cd52fc commit 60a8154

4 files changed

Lines changed: 75 additions & 1 deletion

File tree

integrations/google_genai/src/haystack_integrations/components/common/google_genai/utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Literal
66

7-
from google.genai import Client
7+
from google.genai import Client, types
88
from haystack import logging
99
from haystack.utils import Secret
1010

@@ -16,6 +16,8 @@ def _get_client(
1616
api: Literal["gemini", "vertex"],
1717
vertex_ai_project: str | None,
1818
vertex_ai_location: str | None,
19+
timeout: float | None = None,
20+
max_retries: int | None = None,
1921
) -> Client:
2022
"""
2123
Internal utility function to get a Google GenAI client.
@@ -31,6 +33,8 @@ def _get_client(
3133
Application Default Credentials.
3234
:param vertex_ai_location: Google Cloud location for Vertex AI (e.g., "us-central1", "europe-west1"). Required
3335
when using Vertex AI with Application Default Credentials.
36+
:param timeout: Timeout in seconds for Google GenAI client calls (converted to milliseconds for the client).
37+
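:param max_retries: Maximum number of retries to attempt for failed requests. Passed as `attempts` to
`types.HttpRetryOptions` (NOTE(review): confirm whether `attempts` also counts the initial request).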
3438
3539
:returns: A Google GenAI client.
3640
@@ -43,6 +47,18 @@ def _get_client(
4347
raise ValueError(msg)
4448

4549
resolved_api_key = api_key.resolve_value()
50+
timeout_ms: int | None = None
51+
retry_options: types.HttpRetryOptions | None = None
52+
http_options: types.HttpOptions | None = None
53+
54+
if timeout is not None:
55+
timeout_ms = int(timeout * 1000)
56+
57+
if max_retries is not None:
58+
retry_options = types.HttpRetryOptions(attempts=max_retries)
59+
60+
if timeout_ms is not None or retry_options is not None:
61+
http_options = types.HttpOptions(timeout=timeout_ms, retry_options=retry_options)
4662

4763
if api == "vertex":
4864
if not resolved_api_key and not (vertex_ai_project and vertex_ai_location):
@@ -54,16 +70,27 @@ def _get_client(
5470

5571
if vertex_ai_project and vertex_ai_location:
5672
logger.info("Using vertex_ai_project and vertex_ai_location for authentication.")
73+
if http_options is not None:
74+
return Client(
75+
vertexai=True,
76+
project=vertex_ai_project,
77+
location=vertex_ai_location,
78+
http_options=http_options,
79+
)
5780
return Client(vertexai=True, project=vertex_ai_project, location=vertex_ai_location)
5881

5982
logger.info(
6083
"No vertex_ai_project or vertex_ai_location provided for Vertex AI. Using the API key for authentication."
6184
)
85+
if http_options is not None:
86+
return Client(vertexai=True, api_key=resolved_api_key, http_options=http_options)
6287
return Client(vertexai=True, api_key=resolved_api_key)
6388

6489
# Gemini API
6590
if not resolved_api_key:
6691
msg = "To use Gemini API, you must export the GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
6792
raise ValueError(msg)
6893

94+
if http_options is not None:
95+
return Client(api_key=resolved_api_key, http_options=http_options)
6996
return Client(api_key=resolved_api_key)

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def __init__(
164164
safety_settings: list[dict[str, Any]] | None = None,
165165
streaming_callback: StreamingCallbackT | None = None,
166166
tools: ToolsType | None = None,
167+
timeout: float | None = None,
168+
max_retries: int | None = None,
167169
):
168170
"""
169171
Initialize a GoogleGenAIChatGenerator instance.
@@ -197,6 +199,12 @@ def __init__(
197199
:param streaming_callback: A callback function that is called when a new token is received from the stream.
198200
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
199201
Each tool should have a unique name.
202+
:param timeout:
203+
Timeout in seconds for Google GenAI client calls. If not set, it uses the default of the Google GenAI
204+
client.
205+
:param max_retries:
206+
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
207+
the Google GenAI client.
200208
"""
201209
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
202210

@@ -205,6 +213,8 @@ def __init__(
205213
api=api,
206214
vertex_ai_project=vertex_ai_project,
207215
vertex_ai_location=vertex_ai_location,
216+
timeout=timeout,
217+
max_retries=max_retries,
208218
)
209219

210220
self._api_key = api_key
@@ -216,6 +226,8 @@ def __init__(
216226
self._safety_settings = safety_settings or []
217227
self._streaming_callback = streaming_callback
218228
self._tools = tools
229+
self._timeout = timeout
230+
self._max_retries = max_retries
219231

220232
def to_dict(self) -> dict[str, Any]:
221233
"""
@@ -236,6 +248,8 @@ def to_dict(self) -> dict[str, Any]:
236248
safety_settings=self._safety_settings,
237249
streaming_callback=callback_name,
238250
tools=serialized_tools,
251+
timeout=self._timeout,
252+
max_retries=self._max_retries,
239253
)
240254

241255
@classmethod

integrations/google_genai/tests/test_chat_generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def test_init_default(self, monkeypatch):
5959
assert component._tools is None
6060
assert component._api_key is not None
6161
assert component._api_key.resolve_value() == "test-api-key"
62+
assert component._timeout is None
63+
assert component._max_retries is None
6264

6365
def test_init_fail_wo_api_key(self, monkeypatch):
6466
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
@@ -81,6 +83,8 @@ def test_init_with_parameters(self, monkeypatch):
8183
generation_kwargs={"temperature": 0.5, "max_output_tokens": 100},
8284
safety_settings=[{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}],
8385
tools=[tool],
86+
timeout=30.0,
87+
max_retries=5,
8488
)
8589
assert component._model == "gemini-2.5-flash"
8690
assert component._streaming_callback is print_streaming_chunk
@@ -89,6 +93,8 @@ def test_init_with_parameters(self, monkeypatch):
8993
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
9094
]
9195
assert component._tools == [tool]
96+
assert component._timeout == 30.0
97+
assert component._max_retries == 5
9298

9399
def test_init_with_toolset(self, tools, monkeypatch):
94100
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
@@ -140,6 +146,8 @@ def test_to_dict_with_toolset(self, tools, monkeypatch):
140146
assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset"
141147
assert "tools" in data["init_parameters"]["tools"]["data"]
142148
assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools)
149+
assert data["init_parameters"]["timeout"] is None
150+
assert data["init_parameters"]["max_retries"] is None
143151

144152
def test_from_dict_with_toolset(self, tools, monkeypatch):
145153
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")

integrations/google_genai/tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import patch
22

33
import pytest
4+
from google.genai import types
45
from haystack.utils import Secret
56

67
from haystack_integrations.components.common.google_genai.utils import _get_client
@@ -57,3 +58,27 @@ def test_get_client_gemini_api_key_no_env_var_raises(monkeypatch):
5758
api_key = Secret.from_env_var("GEMINI_API_KEY", strict=False)
5859
with pytest.raises(ValueError):
5960
_get_client(api_key=api_key, api="gemini", vertex_ai_project=None, vertex_ai_location=None)
61+
62+
63+
def test_get_client_forwards_timeout_and_max_retries(monkeypatch):
64+
monkeypatch.setenv("GEMINI_API_KEY", "test-api-key")
65+
api_key = Secret.from_env_var("GEMINI_API_KEY", strict=False)
66+
67+
with patch("haystack_integrations.components.common.google_genai.utils.Client") as mock_client:
68+
client = _get_client(
69+
api_key=api_key,
70+
api="gemini",
71+
vertex_ai_project=None,
72+
vertex_ai_location=None,
73+
timeout=30.0,
74+
max_retries=5,
75+
)
76+
mock_client.assert_called_once()
77+
_, kwargs = mock_client.call_args
78+
assert kwargs["api_key"] == "test-api-key"
79+
assert "http_options" in kwargs
80+
assert isinstance(kwargs["http_options"], types.HttpOptions)
81+
assert kwargs["http_options"].timeout == 30000
82+
assert kwargs["http_options"].retry_options is not None
83+
assert kwargs["http_options"].retry_options.attempts == 5
84+
assert client is not None

0 commit comments

Comments
 (0)