diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 26d07fd312..5ea706ac0b 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -24,6 +24,9 @@ ) from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from httpx import AsyncClient as AsyncHTTPXClient +from httpx import AsyncHTTPTransport, HTTPTransport +from httpx import Client as HTTPXClient from cohere import AsyncClientV2, ChatResponse, ClientV2, StreamedChatResponseV2 @@ -502,6 +505,9 @@ def __init__( api_base_url: str | None = None, generation_kwargs: dict[str, Any] | None = None, tools: ToolsType | None = None, + *, + timeout: float | None = None, + max_retries: int | None = None, **kwargs: Any, ): """ @@ -526,6 +532,13 @@ def __init__( mean less random generations. :param tools: A list of Tool and/or Toolset objects, or a single Toolset that the model can use. Each tool should have a unique name. + :param timeout: + Timeout for Cohere client calls. If not set, it defaults to the default set by the Cohere client. + :param max_retries: + Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by + the Cohere client. + :param kwargs: + Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility. """ _check_duplicate_tool_names(flatten_tools_or_toolsets(tools)) @@ -534,23 +547,32 @@ def __init__( api_base_url = "https://api.cohere.com" if generation_kwargs is None: generation_kwargs = {} + if kwargs: + generation_kwargs = {**generation_kwargs, **kwargs} self.api_key = api_key self.model = model self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs self.tools = tools - self.model_parameters = kwargs - self.client = ClientV2( - api_key=self.api_key.resolve_value(), - base_url=self.api_base_url, - client_name="haystack", - ) - self.async_client = AsyncClientV2( - api_key=self.api_key.resolve_value(), - base_url=self.api_base_url, - client_name="haystack", - ) + self.timeout = timeout + self.max_retries = max_retries + + client_kwargs: dict[str, Any] = { + "api_key": self.api_key.resolve_value(), + "base_url": self.api_base_url, + "client_name": "haystack", + } + if timeout is not None: + client_kwargs["timeout"] = timeout + if max_retries is not None: + sync_httpx_client = HTTPXClient(transport=HTTPTransport(retries=max_retries)) + async_httpx_client = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries)) + self.client = ClientV2(**client_kwargs, httpx_client=sync_httpx_client) + self.async_client = AsyncClientV2(**client_kwargs, httpx_client=async_httpx_client) + else: + self.client = ClientV2(**client_kwargs) + self.async_client = AsyncClientV2(**client_kwargs) def _get_telemetry_data(self) -> dict[str, Any]: """ @@ -574,6 +596,8 @@ def to_dict(self) -> dict[str, Any]: api_key=self.api_key.to_dict(), generation_kwargs=self.generation_kwargs, tools=serialize_tools_or_toolset(self.tools), + timeout=self.timeout, + max_retries=self.max_retries, ) @classmethod diff --git a/integrations/cohere/tests/test_chat_generator.py b/integrations/cohere/tests/test_chat_generator.py index e9ad5d8725..9ecf0a7740 100644 --- a/integrations/cohere/tests/test_chat_generator.py +++ b/integrations/cohere/tests/test_chat_generator.py @@ -157,6 +157,8 @@ def test_init_default(self, monkeypatch): assert component.streaming_callback is None assert component.api_base_url == "https://api.cohere.com" assert not component.generation_kwargs + assert component.timeout is None + assert component.max_retries is None def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) @@ -174,6 +176,8 @@ def test_init_with_parameters(self): "max_tokens": 10, "some_test_param": "test-params", }, + timeout=30.0, + max_retries=5, ) assert component.api_key == Secret.from_token("test-api-key") assert component.model == "command-nightly" @@ -183,6 +187,8 @@ def test_init_with_parameters(self): "max_tokens": 10, "some_test_param": "test-params", } + assert component.timeout == 30.0 + assert component.max_retries == 5 def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") @@ -201,6 +207,8 @@ def test_to_dict_default(self, monkeypatch): "api_base_url": "https://api.cohere.com", "generation_kwargs": {}, "tools": None, + "timeout": None, + "max_retries": None, }, } @@ -234,6 +242,8 @@ def test_to_dict_with_parameters(self, monkeypatch): "some_test_param": "test-params", }, "tools": None, + "timeout": None, + "max_retries": None, }, } @@ -255,6 +265,8 @@ def test_from_dict(self, monkeypatch): "max_tokens": 10, "some_test_param": "test-params", }, + "timeout": None, + "max_retries": None, }, } component = CohereChatGenerator.from_dict(data) @@ -265,6 +277,8 @@ def test_from_dict(self, monkeypatch): "max_tokens": 10, "some_test_param": "test-params", } + assert component.timeout is None + assert component.max_retries is None def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) @@ -284,6 +298,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "max_tokens": 10, "some_test_param": "test-params", }, + "timeout": None, + "max_retries": None, }, } with pytest.raises(ValueError): @@ -327,6 +343,8 @@ def test_serde_in_pipeline(self, monkeypatch): "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "api_base_url": "https://api.cohere.com", "generation_kwargs": {"temperature": 0.7}, + "timeout": None, + "max_retries": None, "tools": [ { "type": "haystack.tools.tool.Tool", diff --git a/integrations/cohere/tests/test_generator.py b/integrations/cohere/tests/test_generator.py index 919908cffc..73048557cc 100644 --- a/integrations/cohere/tests/test_generator.py +++ b/integrations/cohere/tests/test_generator.py @@ -22,7 +22,7 @@ def test_init_default(self, monkeypatch): assert component.model == "command-a-03-2025" assert component.streaming_callback is None assert component.api_base_url == COHERE_API_URL - assert component.model_parameters == {} + assert component.generation_kwargs == {} def test_init_with_parameters(self): callback = lambda x: x # noqa: E731 @@ -38,7 +38,7 @@ def test_init_with_parameters(self): assert component.model == "command-light" assert component.streaming_callback == callback assert component.api_base_url == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-api-key") @@ -53,6 +53,8 @@ def test_to_dict_default(self, monkeypatch): "api_base_url": COHERE_API_URL, "generation_kwargs": {}, "tools": None, + "timeout": None, + "max_retries": None, }, } @@ -75,8 +77,10 @@ def test_to_dict_with_parameters(self, monkeypatch): "api_base_url": "test-base-url", "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {}, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "tools": None, + "timeout": None, + "max_retries": None, }, } @@ -100,7 +104,7 @@ def test_from_dict(self, monkeypatch): assert component.model == "command-a-03-2025" assert component.streaming_callback == print_streaming_chunk assert component.api_base_url == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),