Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def __init__(
exit_conditions: Optional[List[str]] = None,
state_schema: Optional[Dict[str, Any]] = None,
max_agent_steps: int = 100,
raise_on_tool_invocation_failure: bool = False,
streaming_callback: Optional[StreamingCallbackT] = None,
raise_on_tool_invocation_failure: bool = False,
tool_invoker_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Initialize the agent component.
Expand All @@ -82,10 +83,11 @@ def __init__(
:param state_schema: The schema for the runtime state used by the tools.
:param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
If the agent exceeds this number of steps, it will stop and return the current state.
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
:raises ValueError: If the exit_conditions are not valid.
"""
Expand Down Expand Up @@ -135,9 +137,15 @@ def __init__(
component.set_input_type(self, name=param, type=config["type"], default=None)
component.set_output_types(self, **output_types)

self.tool_invoker_kwargs = tool_invoker_kwargs
self._tool_invoker = None
if self.tools:
self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
resolved_tool_invoker_kwargs = {
"tools": self.tools,
"raise_on_failure": self.raise_on_tool_invocation_failure,
**(tool_invoker_kwargs or {}),
}
self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
else:
logger.warning(
"No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
Expand Down Expand Up @@ -175,8 +183,9 @@ def to_dict(self) -> Dict[str, Any]:
# We serialize the original state schema, not the resolved one to reflect the original user input
state_schema=_schema_to_dict(self._state_schema),
max_agent_steps=self.max_agent_steps,
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
streaming_callback=streaming_callback,
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
tool_invoker_kwargs=self.tool_invoker_kwargs,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Added the `tool_invoker_kwargs` parameter to `Agent` so that additional keyword arguments, such as `max_workers` and `enable_streaming_callback_passthrough`, can be passed through to the underlying `ToolInvoker`.
10 changes: 9 additions & 1 deletion test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch):
tools=[weather_tool, component_tool],
exit_conditions=["text", "weather_tool"],
state_schema={"foo": {"type": str}},
tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
)
serialized_agent = agent.to_dict()
assert serialized_agent == {
Expand Down Expand Up @@ -236,8 +237,9 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch):
"exit_conditions": ["text", "weather_tool"],
"state_schema": {"foo": {"type": "str"}},
"max_agent_steps": 100,
"raise_on_tool_invocation_failure": False,
"streaming_callback": None,
"raise_on_tool_invocation_failure": False,
"tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
},
}

Expand Down Expand Up @@ -294,6 +296,7 @@ def test_to_dict_with_toolset(self, monkeypatch, weather_tool):
"max_agent_steps": 100,
"raise_on_tool_invocation_failure": False,
"streaming_callback": None,
"tool_invoker_kwargs": None,
},
}

Expand Down Expand Up @@ -361,6 +364,7 @@ def test_from_dict(self, monkeypatch):
"max_agent_steps": 100,
"raise_on_tool_invocation_failure": False,
"streaming_callback": None,
"tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
},
}
agent = Agent.from_dict(data)
Expand All @@ -375,6 +379,9 @@ def test_from_dict(self, monkeypatch):
"foo": {"type": str},
"messages": {"handler": merge_lists, "type": List[ChatMessage]},
}
assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
assert agent._tool_invoker.max_workers == 5
assert agent._tool_invoker.enable_streaming_callback_passthrough is True

def test_from_dict_with_toolset(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
Expand Down Expand Up @@ -426,6 +433,7 @@ def test_from_dict_with_toolset(self, monkeypatch):
"max_agent_steps": 100,
"raise_on_tool_invocation_failure": False,
"streaming_callback": None,
"tool_invoker_kwargs": None,
},
}
agent = Agent.from_dict(data)
Expand Down
Loading