Skip to content

Commit c8278f3

Browse files
feat(cohere): add timeout and max_retries to chat generator (#2873)
1 parent cd54e91 commit c8278f3

3 files changed

Lines changed: 61 additions & 15 deletions

File tree

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
)
2525
from haystack.utils import Secret, deserialize_secrets_inplace
2626
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
27+
from httpx import AsyncClient as AsyncHTTPXClient
28+
from httpx import AsyncHTTPTransport, HTTPTransport
29+
from httpx import Client as HTTPXClient
2730

2831
from cohere import AsyncClientV2, ChatResponse, ClientV2, StreamedChatResponseV2
2932

@@ -502,6 +505,9 @@ def __init__(
502505
api_base_url: str | None = None,
503506
generation_kwargs: dict[str, Any] | None = None,
504507
tools: ToolsType | None = None,
508+
*,
509+
timeout: float | None = None,
510+
max_retries: int | None = None,
505511
**kwargs: Any,
506512
):
507513
"""
@@ -526,6 +532,13 @@ def __init__(
526532
mean less random generations.
527533
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the model can use.
528534
Each tool should have a unique name.
535+
:param timeout:
536+
Timeout for Cohere client calls. If not set, it defaults to the default set by the Cohere client.
537+
:param max_retries:
538+
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
539+
the Cohere client.
540+
:param kwargs:
541+
Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.
529542
530543
"""
531544
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
@@ -534,23 +547,32 @@ def __init__(
534547
api_base_url = "https://api.cohere.com"
535548
if generation_kwargs is None:
536549
generation_kwargs = {}
550+
if kwargs:
551+
generation_kwargs = {**generation_kwargs, **kwargs}
537552
self.api_key = api_key
538553
self.model = model
539554
self.streaming_callback = streaming_callback
540555
self.api_base_url = api_base_url
541556
self.generation_kwargs = generation_kwargs
542557
self.tools = tools
543-
self.model_parameters = kwargs
544-
self.client = ClientV2(
545-
api_key=self.api_key.resolve_value(),
546-
base_url=self.api_base_url,
547-
client_name="haystack",
548-
)
549-
self.async_client = AsyncClientV2(
550-
api_key=self.api_key.resolve_value(),
551-
base_url=self.api_base_url,
552-
client_name="haystack",
553-
)
558+
self.timeout = timeout
559+
self.max_retries = max_retries
560+
561+
client_kwargs: dict[str, Any] = {
562+
"api_key": self.api_key.resolve_value(),
563+
"base_url": self.api_base_url,
564+
"client_name": "haystack",
565+
}
566+
if timeout is not None:
567+
client_kwargs["timeout"] = timeout
568+
if max_retries is not None:
569+
sync_httpx_client = HTTPXClient(transport=HTTPTransport(retries=max_retries))
570+
async_httpx_client = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries))
571+
self.client = ClientV2(**client_kwargs, httpx_client=sync_httpx_client)
572+
self.async_client = AsyncClientV2(**client_kwargs, httpx_client=async_httpx_client)
573+
else:
574+
self.client = ClientV2(**client_kwargs)
575+
self.async_client = AsyncClientV2(**client_kwargs)
554576

555577
def _get_telemetry_data(self) -> dict[str, Any]:
556578
"""
@@ -574,6 +596,8 @@ def to_dict(self) -> dict[str, Any]:
574596
api_key=self.api_key.to_dict(),
575597
generation_kwargs=self.generation_kwargs,
576598
tools=serialize_tools_or_toolset(self.tools),
599+
timeout=self.timeout,
600+
max_retries=self.max_retries,
577601
)
578602

579603
@classmethod

integrations/cohere/tests/test_chat_generator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def test_init_default(self, monkeypatch):
157157
assert component.streaming_callback is None
158158
assert component.api_base_url == "https://api.cohere.com"
159159
assert not component.generation_kwargs
160+
assert component.timeout is None
161+
assert component.max_retries is None
160162

161163
def test_init_fail_wo_api_key(self, monkeypatch):
162164
monkeypatch.delenv("COHERE_API_KEY", raising=False)
@@ -174,6 +176,8 @@ def test_init_with_parameters(self):
174176
"max_tokens": 10,
175177
"some_test_param": "test-params",
176178
},
179+
timeout=30.0,
180+
max_retries=5,
177181
)
178182
assert component.api_key == Secret.from_token("test-api-key")
179183
assert component.model == "command-nightly"
@@ -183,6 +187,8 @@ def test_init_with_parameters(self):
183187
"max_tokens": 10,
184188
"some_test_param": "test-params",
185189
}
190+
assert component.timeout == 30.0
191+
assert component.max_retries == 5
186192

187193
def test_to_dict_default(self, monkeypatch):
188194
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
@@ -201,6 +207,8 @@ def test_to_dict_default(self, monkeypatch):
201207
"api_base_url": "https://api.cohere.com",
202208
"generation_kwargs": {},
203209
"tools": None,
210+
"timeout": None,
211+
"max_retries": None,
204212
},
205213
}
206214

@@ -234,6 +242,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
234242
"some_test_param": "test-params",
235243
},
236244
"tools": None,
245+
"timeout": None,
246+
"max_retries": None,
237247
},
238248
}
239249

@@ -255,6 +265,8 @@ def test_from_dict(self, monkeypatch):
255265
"max_tokens": 10,
256266
"some_test_param": "test-params",
257267
},
268+
"timeout": None,
269+
"max_retries": None,
258270
},
259271
}
260272
component = CohereChatGenerator.from_dict(data)
@@ -265,6 +277,8 @@ def test_from_dict(self, monkeypatch):
265277
"max_tokens": 10,
266278
"some_test_param": "test-params",
267279
}
280+
assert component.timeout is None
281+
assert component.max_retries is None
268282

269283
def test_from_dict_fail_wo_env_var(self, monkeypatch):
270284
monkeypatch.delenv("COHERE_API_KEY", raising=False)
@@ -284,6 +298,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
284298
"max_tokens": 10,
285299
"some_test_param": "test-params",
286300
},
301+
"timeout": None,
302+
"max_retries": None,
287303
},
288304
}
289305
with pytest.raises(ValueError):
@@ -327,6 +343,8 @@ def test_serde_in_pipeline(self, monkeypatch):
327343
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
328344
"api_base_url": "https://api.cohere.com",
329345
"generation_kwargs": {"temperature": 0.7},
346+
"timeout": None,
347+
"max_retries": None,
330348
"tools": [
331349
{
332350
"type": "haystack.tools.tool.Tool",

integrations/cohere/tests/test_generator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_init_default(self, monkeypatch):
2222
assert component.model == "command-a-03-2025"
2323
assert component.streaming_callback is None
2424
assert component.api_base_url == COHERE_API_URL
25-
assert component.model_parameters == {}
25+
assert component.generation_kwargs == {}
2626

2727
def test_init_with_parameters(self):
2828
callback = lambda x: x # noqa: E731
@@ -38,7 +38,7 @@ def test_init_with_parameters(self):
3838
assert component.model == "command-light"
3939
assert component.streaming_callback == callback
4040
assert component.api_base_url == "test-base-url"
41-
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
41+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
4242

4343
def test_to_dict_default(self, monkeypatch):
4444
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
@@ -53,6 +53,8 @@ def test_to_dict_default(self, monkeypatch):
5353
"api_base_url": COHERE_API_URL,
5454
"generation_kwargs": {},
5555
"tools": None,
56+
"timeout": None,
57+
"max_retries": None,
5658
},
5759
}
5860

@@ -75,8 +77,10 @@ def test_to_dict_with_parameters(self, monkeypatch):
7577
"api_base_url": "test-base-url",
7678
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
7779
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
78-
"generation_kwargs": {},
80+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
7981
"tools": None,
82+
"timeout": None,
83+
"max_retries": None,
8084
},
8185
}
8286

@@ -100,7 +104,7 @@ def test_from_dict(self, monkeypatch):
100104
assert component.model == "command-a-03-2025"
101105
assert component.streaming_callback == print_streaming_chunk
102106
assert component.api_base_url == "test-base-url"
103-
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
107+
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
104108

105109
@pytest.mark.skipif(
106110
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),

0 commit comments

Comments
 (0)