2424)
2525from haystack .utils import Secret , deserialize_secrets_inplace
2626from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
27+ from httpx import AsyncClient as AsyncHTTPXClient
28+ from httpx import AsyncHTTPTransport , HTTPTransport
29+ from httpx import Client as HTTPXClient
2730
2831from cohere import AsyncClientV2 , ChatResponse , ClientV2 , StreamedChatResponseV2
2932
@@ -502,6 +505,9 @@ def __init__(
502505 api_base_url : str | None = None ,
503506 generation_kwargs : dict [str , Any ] | None = None ,
504507 tools : ToolsType | None = None ,
508+ * ,
509+ timeout : float | None = None ,
510+ max_retries : int | None = None ,
505511 ** kwargs : Any ,
506512 ):
507513 """
@@ -526,6 +532,13 @@ def __init__(
526532 mean less random generations.
527533 :param tools: A list of Tool and/or Toolset objects, or a single Toolset that the model can use.
528534 Each tool should have a unique name.
535+ :param timeout:
536+ Timeout for Cohere client calls. If not set, it defaults to the default set by the Cohere client.
537+ :param max_retries:
538+ Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
539+ the Cohere client.
540+ :param kwargs:
541+ Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.
529542
530543 """
531544 _check_duplicate_tool_names (flatten_tools_or_toolsets (tools ))
@@ -534,23 +547,32 @@ def __init__(
534547 api_base_url = "https://api.cohere.com"
535548 if generation_kwargs is None :
536549 generation_kwargs = {}
550+ if kwargs :
551+ generation_kwargs = {** generation_kwargs , ** kwargs }
537552 self .api_key = api_key
538553 self .model = model
539554 self .streaming_callback = streaming_callback
540555 self .api_base_url = api_base_url
541556 self .generation_kwargs = generation_kwargs
542557 self .tools = tools
543- self .model_parameters = kwargs
544- self .client = ClientV2 (
545- api_key = self .api_key .resolve_value (),
546- base_url = self .api_base_url ,
547- client_name = "haystack" ,
548- )
549- self .async_client = AsyncClientV2 (
550- api_key = self .api_key .resolve_value (),
551- base_url = self .api_base_url ,
552- client_name = "haystack" ,
553- )
558+ self .timeout = timeout
559+ self .max_retries = max_retries
560+
561+ client_kwargs : dict [str , Any ] = {
562+ "api_key" : self .api_key .resolve_value (),
563+ "base_url" : self .api_base_url ,
564+ "client_name" : "haystack" ,
565+ }
566+ if timeout is not None :
567+ client_kwargs ["timeout" ] = timeout
568+ if max_retries is not None :
569+ sync_httpx_client = HTTPXClient (transport = HTTPTransport (retries = max_retries ))
570+ async_httpx_client = AsyncHTTPXClient (transport = AsyncHTTPTransport (retries = max_retries ))
571+ self .client = ClientV2 (** client_kwargs , httpx_client = sync_httpx_client )
572+ self .async_client = AsyncClientV2 (** client_kwargs , httpx_client = async_httpx_client )
573+ else :
574+ self .client = ClientV2 (** client_kwargs )
575+ self .async_client = AsyncClientV2 (** client_kwargs )
554576
555577 def _get_telemetry_data (self ) -> dict [str , Any ]:
556578 """
@@ -574,6 +596,8 @@ def to_dict(self) -> dict[str, Any]:
574596 api_key = self .api_key .to_dict (),
575597 generation_kwargs = self .generation_kwargs ,
576598 tools = serialize_tools_or_toolset (self .tools ),
599+ timeout = self .timeout ,
600+ max_retries = self .max_retries ,
577601 )
578602
579603 @classmethod
0 commit comments