Skip to content

Commit 734cfda

Browse files
authored
feat: Add Toolset support to CohereChatGenerator (#1700)
1 parent e830011 commit 734cfda

2 files changed

Lines changed: 18 additions & 16 deletions

File tree

integrations/cohere/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai", "cohere==5.*"]
26+
dependencies = ["haystack-ai>=2.13.1", "cohere==5.*"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme"

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import json
2-
from typing import Any, Callable, Dict, Generator, List, Optional
2+
from typing import Any, Callable, Dict, Generator, List, Optional, Union
33

44
from haystack import component, default_from_dict, default_to_dict, logging
55
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
66
from haystack.lazy_imports import LazyImport
7-
from haystack.tools import Tool, _check_duplicate_tool_names
7+
from haystack.tools import (
8+
Tool,
9+
Toolset,
10+
_check_duplicate_tool_names,
11+
deserialize_tools_or_toolset_inplace,
12+
serialize_tools_or_toolset,
13+
)
814
from haystack.utils import Secret, deserialize_secrets_inplace
915
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1016

11-
# Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
12-
try:
13-
from haystack.tools import deserialize_tools_or_toolset_inplace
14-
except ImportError:
15-
from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
16-
1717
from cohere import ChatResponse
1818

1919
with LazyImport(message="Run 'pip install cohere'") as cohere_import:
@@ -300,7 +300,7 @@ def __init__(
300300
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
301301
api_base_url: Optional[str] = None,
302302
generation_kwargs: Optional[Dict[str, Any]] = None,
303-
tools: Optional[List[Tool]] = None,
303+
tools: Optional[Union[List[Tool], Toolset]] = None,
304304
**kwargs,
305305
):
306306
"""
@@ -323,10 +323,11 @@ def __init__(
323323
`accurate` results or `fast` results.
324324
- 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures
325325
mean less random generations.
326-
:param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
326+
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
327+
327328
"""
328329
cohere_import.check()
329-
_check_duplicate_tool_names(tools)
330+
_check_duplicate_tool_names(list(tools or [])) # handles Toolset as well
330331

331332
if not api_base_url:
332333
api_base_url = "https://api.cohere.com"
@@ -357,15 +358,14 @@ def to_dict(self) -> Dict[str, Any]:
357358
Dictionary with serialized data.
358359
"""
359360
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
360-
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
361361
return default_to_dict(
362362
self,
363363
model=self.model,
364364
streaming_callback=callback_name,
365365
api_base_url=self.api_base_url,
366366
api_key=self.api_key.to_dict(),
367367
generation_kwargs=self.generation_kwargs,
368-
tools=serialized_tools,
368+
tools=serialize_tools_or_toolset(self.tools),
369369
)
370370

371371
@classmethod
@@ -391,7 +391,7 @@ def run(
391391
self,
392392
messages: List[ChatMessage],
393393
generation_kwargs: Optional[Dict[str, Any]] = None,
394-
tools: Optional[List[Tool]] = None,
394+
tools: Optional[Union[List[Tool], Toolset]] = None,
395395
):
396396
"""
397397
Invoke the chat endpoint based on the provided messages and generation parameters.
@@ -401,7 +401,7 @@ def run(
401401
potentially override the parameters passed in the __init__ method.
402402
For more details on the parameters supported by the Cohere API, refer to the
403403
Cohere [documentation](https://docs.cohere.com/reference/chat).
404-
:param tools: A list of tools for which the model can prepare calls. If set, it will override
404+
:param tools: A list of tools or a Toolset for which the model can prepare calls. If set, it will override
405405
the `tools` parameter set during component initialization.
406406
:returns: A dictionary with the following keys:
407407
- `replies`: a list of `ChatMessage` instances representing the generated responses.
@@ -411,6 +411,8 @@ def run(
411411

412412
# Handle tools
413413
tools = tools or self.tools
414+
if isinstance(tools, Toolset):
415+
tools = list(tools)
414416
if tools:
415417
_check_duplicate_tool_names(tools)
416418
generation_kwargs["tools"] = [_format_tool(tool) for tool in tools]

0 commit comments

Comments
 (0)