diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 701595a959..b9d618b3ee 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -67,8 +67,9 @@ def __init__( exit_conditions: Optional[List[str]] = None, state_schema: Optional[Dict[str, Any]] = None, max_agent_steps: int = 100, - raise_on_tool_invocation_failure: bool = False, streaming_callback: Optional[StreamingCallbackT] = None, + raise_on_tool_invocation_failure: bool = False, + tool_invoker_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Initialize the agent component. @@ -82,10 +83,11 @@ def __init__( :param state_schema: The schema for the runtime state used by the tools. :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. If the agent exceeds this number of steps, it will stop and return the current state. - :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? - If set to False, the exception will be turned into a chat message and passed to the LLM. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? + If set to False, the exception will be turned into a chat message and passed to the LLM. + :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker. :raises TypeError: If the chat_generator does not support tools parameter in its run method. :raises ValueError: If the exit_conditions are not valid. """ @@ -135,9 +137,15 @@ def __init__( component.set_input_type(self, name=param, type=config["type"], default=None) component.set_output_types(self, **output_types) + self.tool_invoker_kwargs = tool_invoker_kwargs self._tool_invoker = None if self.tools: - self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure) + resolved_tool_invoker_kwargs = { + "tools": self.tools, + "raise_on_failure": self.raise_on_tool_invocation_failure, + **(tool_invoker_kwargs or {}), + } + self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs) else: logger.warning( "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text " @@ -175,8 +183,9 @@ def to_dict(self) -> Dict[str, Any]: # We serialize the original state schema, not the resolved one to reflect the original user input state_schema=_schema_to_dict(self._state_schema), max_agent_steps=self.max_agent_steps, - raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, streaming_callback=streaming_callback, + raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, + tool_invoker_kwargs=self.tool_invoker_kwargs, ) @classmethod diff --git a/releasenotes/notes/add-tool-invoker-kwargs-to-agent-a4357dd39b0fd030.yaml b/releasenotes/notes/add-tool-invoker-kwargs-to-agent-a4357dd39b0fd030.yaml new file mode 100644 index 0000000000..c809b96143 --- /dev/null +++ b/releasenotes/notes/add-tool-invoker-kwargs-to-agent-a4357dd39b0fd030.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Added the `tool_invoker_kwargs` param to Agent so additional kwargs can be passed to the ToolInvoker like `max_workers` and `enable_streaming_callback_passthrough`. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 1b7b54e4a4..496f857f04 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -174,6 +174,7 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch): tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"], state_schema={"foo": {"type": str}}, + tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True}, ) serialized_agent = agent.to_dict() assert serialized_agent == { @@ -236,8 +237,9 @@ def test_to_dict(self, weather_tool, component_tool, monkeypatch): "exit_conditions": ["text", "weather_tool"], "state_schema": {"foo": {"type": "str"}}, "max_agent_steps": 100, - "raise_on_tool_invocation_failure": False, "streaming_callback": None, + "raise_on_tool_invocation_failure": False, + "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True}, }, } @@ -294,6 +296,7 @@ def test_to_dict_with_toolset(self, monkeypatch, weather_tool): "max_agent_steps": 100, "raise_on_tool_invocation_failure": False, "streaming_callback": None, + "tool_invoker_kwargs": None, }, } @@ -361,6 +364,7 @@ def test_from_dict(self, monkeypatch): "max_agent_steps": 100, "raise_on_tool_invocation_failure": False, "streaming_callback": None, + "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True}, }, } agent = Agent.from_dict(data) @@ -375,6 +379,9 @@ def test_from_dict(self, monkeypatch): "foo": {"type": str}, "messages": {"handler": merge_lists, "type": List[ChatMessage]}, } + assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True} + assert agent._tool_invoker.max_workers == 5 + assert agent._tool_invoker.enable_streaming_callback_passthrough is True def test_from_dict_with_toolset(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") @@ -426,6 +433,7 @@ def test_from_dict_with_toolset(self, monkeypatch): "max_agent_steps": 100, "raise_on_tool_invocation_failure": False, "streaming_callback": None, + "tool_invoker_kwargs": None, }, } agent = Agent.from_dict(data)