11import json
2- from typing import Any , Callable , Dict , Generator , List , Optional
2+ from typing import Any , Callable , Dict , Generator , List , Optional , Union
33
44from haystack import component , default_from_dict , default_to_dict , logging
55from haystack .dataclasses import ChatMessage , StreamingChunk , ToolCall
66from haystack .lazy_imports import LazyImport
7- from haystack .tools import Tool , _check_duplicate_tool_names
7+ from haystack .tools import (
8+ Tool ,
9+ Toolset ,
10+ _check_duplicate_tool_names ,
11+ deserialize_tools_or_toolset_inplace ,
12+ serialize_tools_or_toolset ,
13+ )
814from haystack .utils import Secret , deserialize_secrets_inplace
915from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
1016
11- # Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
12- try :
13- from haystack .tools import deserialize_tools_or_toolset_inplace
14- except ImportError :
15- from haystack .tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
16-
1717from cohere import ChatResponse
1818
1919with LazyImport (message = "Run 'pip install cohere'" ) as cohere_import :
@@ -300,7 +300,7 @@ def __init__(
300300 streaming_callback : Optional [Callable [[StreamingChunk ], None ]] = None ,
301301 api_base_url : Optional [str ] = None ,
302302 generation_kwargs : Optional [Dict [str , Any ]] = None ,
303- tools : Optional [List [Tool ]] = None ,
303+ tools : Optional [Union [ List [Tool ], Toolset ]] = None ,
304304 ** kwargs ,
305305 ):
306306 """
@@ -323,10 +323,11 @@ def __init__(
323323 `accurate` results or `fast` results.
324324 - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures
325325 mean less random generations.
326- :param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
326+ :param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
327+
327328 """
328329 cohere_import .check ()
329- _check_duplicate_tool_names (tools )
330+ _check_duplicate_tool_names (list ( tools or [])) # handles Toolset as well
330331
331332 if not api_base_url :
332333 api_base_url = "https://api.cohere.com"
@@ -357,15 +358,14 @@ def to_dict(self) -> Dict[str, Any]:
357358 Dictionary with serialized data.
358359 """
359360 callback_name = serialize_callable (self .streaming_callback ) if self .streaming_callback else None
360- serialized_tools = [tool .to_dict () for tool in self .tools ] if self .tools else None
361361 return default_to_dict (
362362 self ,
363363 model = self .model ,
364364 streaming_callback = callback_name ,
365365 api_base_url = self .api_base_url ,
366366 api_key = self .api_key .to_dict (),
367367 generation_kwargs = self .generation_kwargs ,
368- tools = serialized_tools ,
368+ tools = serialize_tools_or_toolset ( self . tools ) ,
369369 )
370370
371371 @classmethod
@@ -391,7 +391,7 @@ def run(
391391 self ,
392392 messages : List [ChatMessage ],
393393 generation_kwargs : Optional [Dict [str , Any ]] = None ,
394- tools : Optional [List [Tool ]] = None ,
394+ tools : Optional [Union [ List [Tool ], Toolset ]] = None ,
395395 ):
396396 """
397397 Invoke the chat endpoint based on the provided messages and generation parameters.
@@ -401,7 +401,7 @@ def run(
401401 potentially override the parameters passed in the __init__ method.
402402 For more details on the parameters supported by the Cohere API, refer to the
403403 Cohere [documentation](https://docs.cohere.com/reference/chat).
404- :param tools: A list of tools for which the model can prepare calls. If set, it will override
404+ :param tools: A list of tools or a Toolset for which the model can prepare calls. If set, it will override
405405 the `tools` parameter set during component initialization.
406406 :returns: A dictionary with the following keys:
407407 - `replies`: a list of `ChatMessage` instances representing the generated responses.
@@ -411,6 +411,8 @@ def run(
411411
412412 # Handle tools
413413 tools = tools or self .tools
414+ if isinstance (tools , Toolset ):
415+ tools = list (tools )
414416 if tools :
415417 _check_duplicate_tool_names (tools )
416418 generation_kwargs ["tools" ] = [_format_tool (tool ) for tool in tools ]
0 commit comments