Skip to content

Commit a9828d7

Browse files
sjrljulian-risch
authored andcommitted
feat: Add Toolset to Agent (#9284)
* Add Toolset to Agent * Add reno
1 parent 5d3ec43 commit a9828d7

5 files changed

Lines changed: 186 additions & 10 deletions

File tree

haystack/components/agents/agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
from copy import deepcopy
7-
from typing import Any, Dict, List, Optional
7+
from typing import Any, Dict, List, Optional, Union
88

99
from haystack import component, default_from_dict, default_to_dict, logging, tracing
1010
from haystack.components.generators.chat.types import ChatGenerator
@@ -16,7 +16,7 @@
1616
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
1717
from haystack.dataclasses.state_utils import merge_lists
1818
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
19-
from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
19+
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
2020
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2121
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
2222

@@ -61,7 +61,7 @@ def __init__(
6161
self,
6262
*,
6363
chat_generator: ChatGenerator,
64-
tools: Optional[List[Tool]] = None,
64+
tools: Optional[Union[List[Tool], Toolset]] = None,
6565
system_prompt: Optional[str] = None,
6666
exit_conditions: Optional[List[str]] = None,
6767
state_schema: Optional[Dict[str, Any]] = None,
@@ -73,7 +73,7 @@ def __init__(
7373
Initialize the agent component.
7474
7575
:param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
76-
:param tools: List of Tool objects available to the agent
76+
:param tools: List of Tool objects or a Toolset that the agent can use.
7777
:param system_prompt: System prompt for the agent.
7878
:param exit_conditions: List of conditions that will cause the agent to return.
7979
Can include "text" if the agent should return when it generates a message without tool calls,
@@ -166,7 +166,7 @@ def to_dict(self) -> Dict[str, Any]:
166166
return default_to_dict(
167167
self,
168168
chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"),
169-
tools=[t.to_dict() for t in self.tools],
169+
tools=serialize_tools_or_toolset(self.tools),
170170
system_prompt=self.system_prompt,
171171
exit_conditions=self.exit_conditions,
172172
# We serialize the original state schema, not the resolved one to reflect the original user input

haystack/components/tools/tool_invoker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,17 @@ def __init__(
184184

185185
# Convert Toolset to list for internal use
186186
if isinstance(tools, Toolset):
187-
tools = list(tools)
187+
converted_tools = list(tools)
188+
else:
189+
converted_tools = tools
188190

189-
_check_duplicate_tool_names(tools)
190-
tool_names = [tool.name for tool in tools]
191+
_check_duplicate_tool_names(converted_tools)
192+
tool_names = [tool.name for tool in converted_tools]
191193
duplicates = {name for name in tool_names if tool_names.count(name) > 1}
192194
if duplicates:
193195
raise ValueError(f"Duplicate tool names found: {duplicates}")
194196

195-
self._tools_with_names = dict(zip(tool_names, tools))
197+
self._tools_with_names = dict(zip(tool_names, converted_tools))
196198
self.raise_on_failure = raise_on_failure
197199
self.convert_result_to_json_string = convert_result_to_json_string
198200

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Agent now supports a List of Tools or a Toolset as input.

test/components/agents/test_agent.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,62 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch):
239239
},
240240
}
241241

242+
def test_to_dict_with_toolset(self, monkeypatch, weather_tool):
243+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
244+
toolset = Toolset(tools=[weather_tool])
245+
agent = Agent(chat_generator=OpenAIChatGenerator(), tools=toolset)
246+
serialized_agent = agent.to_dict()
247+
assert serialized_agent == {
248+
"type": "haystack.components.agents.agent.Agent",
249+
"init_parameters": {
250+
"chat_generator": {
251+
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
252+
"init_parameters": {
253+
"model": "gpt-4o-mini",
254+
"streaming_callback": None,
255+
"api_base_url": None,
256+
"organization": None,
257+
"generation_kwargs": {},
258+
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
259+
"timeout": None,
260+
"max_retries": None,
261+
"tools": None,
262+
"tools_strict": False,
263+
"http_client_kwargs": None,
264+
},
265+
},
266+
"tools": {
267+
"type": "haystack.tools.toolset.Toolset",
268+
"data": {
269+
"tools": [
270+
{
271+
"type": "haystack.tools.tool.Tool",
272+
"data": {
273+
"name": "weather_tool",
274+
"description": "Provides weather information for a given location.",
275+
"parameters": {
276+
"type": "object",
277+
"properties": {"location": {"type": "string"}},
278+
"required": ["location"],
279+
},
280+
"function": "test_agent.weather_function",
281+
"outputs_to_string": None,
282+
"inputs_from_state": None,
283+
"outputs_to_state": None,
284+
},
285+
}
286+
]
287+
},
288+
},
289+
"system_prompt": None,
290+
"exit_conditions": ["text"],
291+
"state_schema": {},
292+
"max_agent_steps": 100,
293+
"raise_on_tool_invocation_failure": False,
294+
"streaming_callback": None,
295+
},
296+
}
297+
242298
def test_from_dict(self, weather_tool, component_tool, monkeypatch):
243299
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
244300
data = {
@@ -318,6 +374,67 @@ def test_from_dict(self, weather_tool, component_tool, monkeypatch):
318374
"messages": {"handler": merge_lists, "type": List[ChatMessage]},
319375
}
320376

377+
def test_from_dict_with_toolset(self, monkeypatch):
378+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
379+
data = {
380+
"type": "haystack.components.agents.agent.Agent",
381+
"init_parameters": {
382+
"chat_generator": {
383+
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
384+
"init_parameters": {
385+
"model": "gpt-4o-mini",
386+
"streaming_callback": None,
387+
"api_base_url": None,
388+
"organization": None,
389+
"generation_kwargs": {},
390+
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
391+
"timeout": None,
392+
"max_retries": None,
393+
"tools": None,
394+
"tools_strict": False,
395+
"http_client_kwargs": None,
396+
},
397+
},
398+
"tools": {
399+
"type": "haystack.tools.toolset.Toolset",
400+
"data": {
401+
"tools": [
402+
{
403+
"type": "haystack.tools.tool.Tool",
404+
"data": {
405+
"name": "weather_tool",
406+
"description": "Provides weather information for a given location.",
407+
"parameters": {
408+
"type": "object",
409+
"properties": {"location": {"type": "string"}},
410+
"required": ["location"],
411+
},
412+
"function": "test_agent.weather_function",
413+
"outputs_to_string": None,
414+
"inputs_from_state": None,
415+
"outputs_to_state": None,
416+
},
417+
}
418+
]
419+
},
420+
},
421+
"system_prompt": None,
422+
"exit_conditions": ["text"],
423+
"state_schema": {},
424+
"max_agent_steps": 100,
425+
"raise_on_tool_invocation_failure": False,
426+
"streaming_callback": None,
427+
},
428+
}
429+
agent = Agent.from_dict(data)
430+
assert isinstance(agent, Agent)
431+
assert isinstance(agent.chat_generator, OpenAIChatGenerator)
432+
assert agent.chat_generator.model == "gpt-4o-mini"
433+
assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
434+
assert isinstance(agent.tools, Toolset)
435+
assert agent.tools[0].function is weather_function
436+
assert agent.exit_conditions == ["text"]
437+
321438
def test_serde(self, weather_tool, component_tool, monkeypatch):
322439
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
323440
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

test/components/tools/test_tool_invoker.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
1010
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
1111
from haystack.dataclasses.state import State
12-
from haystack.tools import ComponentTool, Tool
12+
from haystack.tools import ComponentTool, Tool, Toolset
1313
from haystack.tools.errors import ToolInvocationError
1414

1515

@@ -54,6 +54,34 @@ def faulty_tool_func(location):
5454
)
5555

5656

57+
def add_function(num1: int, num2: int):
58+
return num1 + num2
59+
60+
61+
@pytest.fixture
62+
def tool_set():
63+
return Toolset(
64+
tools=[
65+
Tool(
66+
name="weather_tool",
67+
description="Provides weather information for a given location.",
68+
parameters=weather_parameters,
69+
function=weather_function,
70+
),
71+
Tool(
72+
name="addition_tool",
73+
description="A tool that adds two numbers.",
74+
parameters={
75+
"type": "object",
76+
"properties": {"num1": {"type": "integer"}, "num2": {"type": "integer"}},
77+
"required": ["num1", "num2"],
78+
},
79+
function=add_function,
80+
),
81+
]
82+
)
83+
84+
5785
@pytest.fixture
5886
def invoker(weather_tool):
5987
return ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
@@ -73,6 +101,11 @@ def test_init(self, weather_tool):
73101
assert invoker.raise_on_failure
74102
assert not invoker.convert_result_to_json_string
75103

104+
def test_init_with_toolset(self, tool_set):
105+
tool_invoker = ToolInvoker(tools=tool_set)
106+
assert tool_invoker.tools == tool_set
107+
assert tool_invoker._tools_with_names == {"weather_tool": tool_set.tools[0], "addition_tool": tool_set.tools[1]}
108+
76109
def test_init_fails_wo_tools(self):
77110
with pytest.raises(ValueError):
78111
ToolInvoker(tools=[])
@@ -149,6 +182,26 @@ def test_run(self, invoker):
149182
assert tool_call_result.origin == tool_call
150183
assert not tool_call_result.error
151184

185+
def test_run_with_toolset(self, tool_set):
186+
tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
187+
tool_call = ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3})
188+
message = ChatMessage.from_assistant(tool_calls=[tool_call])
189+
190+
result = tool_invoker.run(messages=[message])
191+
assert "tool_messages" in result
192+
assert len(result["tool_messages"]) == 1
193+
194+
tool_message = result["tool_messages"][0]
195+
assert isinstance(tool_message, ChatMessage)
196+
assert tool_message.is_from(ChatRole.TOOL)
197+
assert tool_message.tool_call_results
198+
199+
tool_call_result = tool_message.tool_call_result
200+
assert isinstance(tool_call_result, ToolCallResult)
201+
assert tool_call_result.result == str(8)
202+
assert tool_call_result.origin == tool_call
203+
assert not tool_call_result.error
204+
152205
def test_run_no_messages(self, invoker):
153206
result = invoker.run(messages=[])
154207
assert result["tool_messages"] == []

0 commit comments

Comments
 (0)