Skip to content

Commit 7c13187

Browse files
committed
feat(cohere): add timeout and max_retries to chat generator
1 parent 6cd52fc commit 7c13187

3 files changed

Lines changed: 60 additions & 15 deletions

File tree

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 34 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
)
2525
from haystack.utils import Secret, deserialize_secrets_inplace
2626
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
27+
from httpx import AsyncClient as AsyncHTTPXClient
28+
from httpx import AsyncHTTPTransport, Client as HTTPXClient, HTTPTransport
2729

2830
from cohere import AsyncClientV2, ChatResponse, ClientV2, StreamedChatResponseV2
2931

@@ -502,6 +504,9 @@ def __init__(
502504
api_base_url: str | None = None,
503505
generation_kwargs: dict[str, Any] | None = None,
504506
tools: ToolsType | None = None,
507+
*,
508+
timeout: float | None = None,
509+
max_retries: int | None = None,
505510
**kwargs: Any,
506511
):
507512
"""
@@ -526,6 +531,13 @@ def __init__(
526531
mean less random generations.
527532
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the model can use.
528533
Each tool should have a unique name.
534+
:param timeout:
535+
Timeout for Cohere client calls. If not set, it defaults to the default set by the Cohere client.
536+
:param max_retries:
537+
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
538+
the Cohere client.
539+
:param kwargs:
540+
Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.
529541
530542
"""
531543
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
@@ -534,23 +546,32 @@ def __init__(
534546
api_base_url = "https://api.cohere.com"
535547
if generation_kwargs is None:
536548
generation_kwargs = {}
549+
if kwargs:
550+
generation_kwargs = {**generation_kwargs, **kwargs}
537551
self.api_key = api_key
538552
self.model = model
539553
self.streaming_callback = streaming_callback
540554
self.api_base_url = api_base_url
541555
self.generation_kwargs = generation_kwargs
542556
self.tools = tools
543-
self.model_parameters = kwargs
544-
self.client = ClientV2(
545-
api_key=self.api_key.resolve_value(),
546-
base_url=self.api_base_url,
547-
client_name="haystack",
548-
)
549-
self.async_client = AsyncClientV2(
550-
api_key=self.api_key.resolve_value(),
551-
base_url=self.api_base_url,
552-
client_name="haystack",
553-
)
557+
self.timeout = timeout
558+
self.max_retries = max_retries
559+
560+
client_kwargs: dict[str, Any] = {
561+
"api_key": self.api_key.resolve_value(),
562+
"base_url": self.api_base_url,
563+
"client_name": "haystack",
564+
}
565+
if timeout is not None:
566+
client_kwargs["timeout"] = timeout
567+
if max_retries is not None:
568+
sync_httpx_client = HTTPXClient(transport=HTTPTransport(retries=max_retries))
569+
async_httpx_client = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries))
570+
self.client = ClientV2(**client_kwargs, httpx_client=sync_httpx_client)
571+
self.async_client = AsyncClientV2(**client_kwargs, httpx_client=async_httpx_client)
572+
else:
573+
self.client = ClientV2(**client_kwargs)
574+
self.async_client = AsyncClientV2(**client_kwargs)
554575

555576
def _get_telemetry_data(self) -> dict[str, Any]:
556577
"""
@@ -574,6 +595,8 @@ def to_dict(self) -> dict[str, Any]:
574595
api_key=self.api_key.to_dict(),
575596
generation_kwargs=self.generation_kwargs,
576597
tools=serialize_tools_or_toolset(self.tools),
598+
timeout=self.timeout,
599+
max_retries=self.max_retries,
577600
)
578601

579602
@classmethod

integrations/cohere/tests/test_chat_generator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def test_init_default(self, monkeypatch):
157157
assert component.streaming_callback is None
158158
assert component.api_base_url == "https://api.cohere.com"
159159
assert not component.generation_kwargs
160+
assert component.timeout is None
161+
assert component.max_retries is None
160162

161163
def test_init_fail_wo_api_key(self, monkeypatch):
162164
monkeypatch.delenv("COHERE_API_KEY", raising=False)
@@ -174,6 +176,8 @@ def test_init_with_parameters(self):
174176
"max_tokens": 10,
175177
"some_test_param": "test-params",
176178
},
179+
timeout=30.0,
180+
max_retries=5,
177181
)
178182
assert component.api_key == Secret.from_token("test-api-key")
179183
assert component.model == "command-nightly"
@@ -183,6 +187,8 @@ def test_init_with_parameters(self):
183187
"max_tokens": 10,
184188
"some_test_param": "test-params",
185189
}
190+
assert component.timeout == 30.0
191+
assert component.max_retries == 5
186192

187193
def test_to_dict_default(self, monkeypatch):
188194
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
@@ -201,6 +207,8 @@ def test_to_dict_default(self, monkeypatch):
201207
"api_base_url": "https://api.cohere.com",
202208
"generation_kwargs": {},
203209
"tools": None,
210+
"timeout": None,
211+
"max_retries": None,
204212
},
205213
}
206214

@@ -234,6 +242,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
234242
"some_test_param": "test-params",
235243
},
236244
"tools": None,
245+
"timeout": None,
246+
"max_retries": None,
237247
},
238248
}
239249

@@ -255,6 +265,8 @@ def test_from_dict(self, monkeypatch):
255265
"max_tokens": 10,
256266
"some_test_param": "test-params",
257267
},
268+
"timeout": None,
269+
"max_retries": None,
258270
},
259271
}
260272
component = CohereChatGenerator.from_dict(data)
@@ -265,6 +277,8 @@ def test_from_dict(self, monkeypatch):
265277
"max_tokens": 10,
266278
"some_test_param": "test-params",
267279
}
280+
assert component.timeout is None
281+
assert component.max_retries is None
268282

269283
def test_from_dict_fail_wo_env_var(self, monkeypatch):
270284
monkeypatch.delenv("COHERE_API_KEY", raising=False)
@@ -284,6 +298,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
284298
"max_tokens": 10,
285299
"some_test_param": "test-params",
286300
},
301+
"timeout": None,
302+
"max_retries": None,
287303
},
288304
}
289305
with pytest.raises(ValueError):
@@ -327,6 +343,8 @@ def test_serde_in_pipeline(self, monkeypatch):
327343
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
328344
"api_base_url": "https://api.cohere.com",
329345
"generation_kwargs": {"temperature": 0.7},
346+
"timeout": None,
347+
"max_retries": None,
330348
"tools": [
331349
{
332350
"type": "haystack.tools.tool.Tool",

integrations/cohere/tests/test_generator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_init_default(self, monkeypatch):
2222
assert component.model == "command-a-03-2025"
2323
assert component.streaming_callback is None
2424
assert component.api_base_url == COHERE_API_URL
25-
assert component.model_parameters == {}
25+
assert component.generation_kwargs == {}
2626

2727
def test_init_with_parameters(self):
2828
callback = lambda x: x # noqa: E731
@@ -38,7 +38,7 @@ def test_init_with_parameters(self):
3838
assert component.model == "command-light"
3939
assert component.streaming_callback == callback
4040
assert component.api_base_url == "test-base-url"
41-
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
41+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
4242

4343
def test_to_dict_default(self, monkeypatch):
4444
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
@@ -53,6 +53,8 @@ def test_to_dict_default(self, monkeypatch):
5353
"api_base_url": COHERE_API_URL,
5454
"generation_kwargs": {},
5555
"tools": None,
56+
"timeout": None,
57+
"max_retries": None,
5658
},
5759
}
5860

@@ -75,8 +77,10 @@ def test_to_dict_with_parameters(self, monkeypatch):
7577
"api_base_url": "test-base-url",
7678
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
7779
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
78-
"generation_kwargs": {},
80+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
7981
"tools": None,
82+
"timeout": None,
83+
"max_retries": None,
8084
},
8185
}
8286

@@ -100,7 +104,7 @@ def test_from_dict(self, monkeypatch):
100104
assert component.model == "command-a-03-2025"
101105
assert component.streaming_callback == print_streaming_chunk
102106
assert component.api_base_url == "test-base-url"
103-
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
107+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
104108

105109
@pytest.mark.skipif(
106110
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),

0 commit comments

Comments
 (0)