2222 flatten_tools_or_toolsets ,
2323 serialize_tools_or_toolset ,
2424)
25- from haystack .utils import Secret , deserialize_secrets_inplace
25+ from haystack .utils import Secret
2626from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
2727from httpx import AsyncClient as AsyncHTTPXClient
2828from httpx import AsyncHTTPTransport , HTTPTransport
@@ -508,7 +508,6 @@ def __init__(
508508 * ,
509509 timeout : float | None = None ,
510510 max_retries : int | None = None ,
511- ** kwargs : Any ,
512511 ):
513512 """
514513 Initialize the CohereChatGenerator instance.
@@ -537,23 +536,18 @@ def __init__(
537536 :param max_retries:
538537 Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
539538 the Cohere client.
540- :param kwargs:
541- Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.
542539
543540 """
544541 _check_duplicate_tool_names (flatten_tools_or_toolsets (tools ))
545542
546543 if not api_base_url :
547544 api_base_url = "https://api.cohere.com"
548- if generation_kwargs is None :
549- generation_kwargs = {}
550- if kwargs :
551- generation_kwargs = {** generation_kwargs , ** kwargs }
545+
552546 self .api_key = api_key
553547 self .model = model
554548 self .streaming_callback = streaming_callback
555549 self .api_base_url = api_base_url
556- self .generation_kwargs = generation_kwargs
550+ self .generation_kwargs = generation_kwargs or {}
557551 self .tools = tools
558552 self .timeout = timeout
559553 self .max_retries = max_retries
@@ -565,14 +559,15 @@ def __init__(
565559 }
566560 if timeout is not None :
567561 client_kwargs ["timeout" ] = timeout
562+
563+ sync_kwargs = {** client_kwargs }
564+ async_kwargs = {** client_kwargs }
568565 if max_retries is not None :
569- sync_httpx_client = HTTPXClient (transport = HTTPTransport (retries = max_retries ))
570- async_httpx_client = AsyncHTTPXClient (transport = AsyncHTTPTransport (retries = max_retries ))
571- self .client = ClientV2 (** client_kwargs , httpx_client = sync_httpx_client )
572- self .async_client = AsyncClientV2 (** client_kwargs , httpx_client = async_httpx_client )
573- else :
574- self .client = ClientV2 (** client_kwargs )
575- self .async_client = AsyncClientV2 (** client_kwargs )
566+ sync_kwargs ["httpx_client" ] = HTTPXClient (transport = HTTPTransport (retries = max_retries ))
567+ async_kwargs ["httpx_client" ] = AsyncHTTPXClient (transport = AsyncHTTPTransport (retries = max_retries ))
568+
569+ self .client = ClientV2 (** sync_kwargs )
570+ self .async_client = AsyncClientV2 (** async_kwargs )
576571
577572 def _get_telemetry_data (self ) -> dict [str , Any ]:
578573 """
@@ -593,7 +588,7 @@ def to_dict(self) -> dict[str, Any]:
593588 model = self .model ,
594589 streaming_callback = callback_name ,
595590 api_base_url = self .api_base_url ,
596- api_key = self .api_key . to_dict () ,
591+ api_key = self .api_key ,
597592 generation_kwargs = self .generation_kwargs ,
598593 tools = serialize_tools_or_toolset (self .tools ),
599594 timeout = self .timeout ,
@@ -611,7 +606,6 @@ def from_dict(cls, data: dict[str, Any]) -> "CohereChatGenerator":
611606 Deserialized component.
612607 """
613608 init_params = data .get ("init_parameters" , {})
614- deserialize_secrets_inplace (init_params , ["api_key" ])
615609 deserialize_tools_or_toolset_inplace (init_params , key = "tools" )
616610 serialized_callback_handler = init_params .get ("streaming_callback" )
617611 if serialized_callback_handler :
0 commit comments