From f3769e5acb0eeb237de0f238d1fe0cac21daa073 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 13:27:12 +0200 Subject: [PATCH 01/17] Enable parallel tool execution in ToolInvoker --- haystack/components/tools/tool_invoker.py | 73 ++++++++++++++++++----- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 7c7dd9faed..03b921d77d 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -558,7 +558,7 @@ def run( return {"tool_messages": tool_messages, "state": state} @component.output_types(tool_messages=List[ChatMessage], state=State) - async def run_async( + async def run_async( # noqa: PLR0915 self, messages: List[ChatMessage], state: Optional[State] = None, @@ -569,6 +569,7 @@ async def run_async( """ Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools. + Tool invocations are performed concurrently for better performance. :param messages: A list of ChatMessage objects. :param state: The runtime state that should be used by the tools. @@ -594,6 +595,7 @@ async def run_async( :raises ToolOutputMergeError: If merging tool outputs into state fails and `raise_on_failure` is True. """ + if state is None: state = State(schema={}) @@ -609,17 +611,22 @@ async def run_async( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - tool_messages = [] + # Collect all tool calls and prepare them for concurrent execution + tool_call_tasks = [] + tool_call_metadata = [] # Keep track of original tool_call and message info + for message in messages_with_tool_calls: for tool_call in message.tool_calls: tool_name = tool_call.tool_name - # Check if the tool is available, otherwise return an error message + # Check if the tool is available, otherwise create error message immediately if tool_name not in self._tools_with_names: error_message = self._handle_error( ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) + tool_call_metadata.append( + {"tool_call": tool_call, "error_message": error_message, "is_error": True} + ) continue tool_to_invoke = self._tools_with_names[tool_name] @@ -638,24 +645,59 @@ async def run_async( if "streaming_callback" in invoke_params: final_args["streaming_callback"] = streaming_callback - # 2) Invoke the tool asynchronously - try: - tool_result = await asyncio.get_running_loop().run_in_executor( - self.executor, partial(tool_to_invoke.invoke, **final_args) - ) + # Create async task for tool invocation + task = asyncio.get_running_loop().run_in_executor( + self.executor, partial(tool_to_invoke.invoke, **final_args) + ) + tool_call_tasks.append(task) + tool_call_metadata.append({"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "is_error": False}) - except ToolInvocationError as e: - error_message = self._handle_error(e) + # Execute all tool calls concurrently + if tool_call_tasks: + tool_results = await asyncio.gather(*tool_call_tasks, return_exceptions=True) + else: + tool_results = [] + + # Process results and handle errors + tool_messages = [] + task_index = 0 # Index for tracking non-error tasks + + for metadata in tool_call_metadata: + if metadata["is_error"]: + # Handle pre-validation errors (tool not found) + tool_messages.append( + ChatMessage.from_tool( + tool_result=metadata["error_message"], origin=metadata["tool_call"], error=True + ) + ) + else: + # Handle tool execution results + tool_call = metadata["tool_call"] + tool_to_invoke = metadata["tool_to_invoke"] + result = tool_results[task_index] + task_index += 1 + + # Check if the result is an exception + if isinstance(result, Exception): + if isinstance(result, ToolInvocationError): + error_message = self._handle_error(result) + else: + # Wrap other exceptions as ToolInvocationError + error_message = self._handle_error( + ToolInvocationError(f"Tool '{tool_call.tool_name}' execution failed: {result}") + ) tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) continue # 3) Merge outputs into state try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) + self._merge_tool_outputs(tool_to_invoke, result, state) except Exception as e: try: error_message = self._handle_error( - ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") + ToolOutputMergeError( + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + ) ) tool_messages.append( ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) @@ -667,11 +709,10 @@ async def run_async( # 4) Prepare the tool result ChatMessage message tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke - ) + self._prepare_tool_result_message(result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke) ) + # Handle streaming callback if streaming_callback is not None: await streaming_callback( StreamingChunk( From 7ff28ed5939cc77762660b7736bde927d5363aee Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 14:44:25 +0200 Subject: [PATCH 02/17] Update handling of errors --- haystack/components/tools/tool_invoker.py | 40 ++++++----------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 03b921d77d..f43bd9a783 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -569,7 +569,7 @@ async def run_async( # noqa: PLR0915 """ Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools. - Tool invocations are performed concurrently for better performance. + Multiple tool calls are performed concurrently. :param messages: A list of ChatMessage objects. :param state: The runtime state that should be used by the tools. @@ -611,22 +611,21 @@ async def run_async( # noqa: PLR0915 init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - # Collect all tool calls and prepare them for concurrent execution + # Collect valid tool calls for concurrent execution tool_call_tasks = [] - tool_call_metadata = [] # Keep track of original tool_call and message info + valid_tool_calls = [] # Only store valid tool calls and their tools + tool_messages = [] # Start building results immediately for message in messages_with_tool_calls: for tool_call in message.tool_calls: tool_name = tool_call.tool_name - # Check if the tool is available, otherwise create error message immediately + # Handle invalid tools immediately if tool_name not in self._tools_with_names: error_message = self._handle_error( ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) ) - tool_call_metadata.append( - {"tool_call": tool_call, "error_message": error_message, "is_error": True} - ) + tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) continue tool_to_invoke = self._tools_with_names[tool_name] @@ -650,33 +649,14 @@ async def run_async( # noqa: PLR0915 self.executor, partial(tool_to_invoke.invoke, **final_args) ) tool_call_tasks.append(task) - tool_call_metadata.append({"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "is_error": False}) + valid_tool_calls.append((tool_call, tool_to_invoke)) - # Execute all tool calls concurrently + # Execute all valid tool calls concurrently if tool_call_tasks: tool_results = await asyncio.gather(*tool_call_tasks, return_exceptions=True) - else: - tool_results = [] - - # Process results and handle errors - tool_messages = [] - task_index = 0 # Index for tracking non-error tasks - - for metadata in tool_call_metadata: - if metadata["is_error"]: - # Handle pre-validation errors (tool not found) - tool_messages.append( - ChatMessage.from_tool( - tool_result=metadata["error_message"], origin=metadata["tool_call"], error=True - ) - ) - else: - # Handle tool execution results - tool_call = metadata["tool_call"] - tool_to_invoke = metadata["tool_to_invoke"] - result = tool_results[task_index] - task_index += 1 + # Process results + for (tool_call, tool_to_invoke), result in zip(valid_tool_calls, tool_results): # Check if the result is an exception if isinstance(result, Exception): if isinstance(result, ToolInvocationError): From 5385f28da1eb684d4f5e05e01aee9fae423a6661 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 14:50:48 +0200 Subject: [PATCH 03/17] Small fixes --- haystack/components/tools/tool_invoker.py | 24 ++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index f43bd9a783..b0bcac49c6 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -613,14 +613,14 @@ async def run_async( # noqa: PLR0915 # Collect valid tool calls for concurrent execution tool_call_tasks = [] - valid_tool_calls = [] # Only store valid tool calls and their tools - tool_messages = [] # Start building results immediately + valid_tool_calls = [] + tool_messages = [] for message in messages_with_tool_calls: for tool_call in message.tool_calls: tool_name = tool_call.tool_name - # Handle invalid tools immediately + # Check if the tool is available, otherwise return an error message if tool_name not in self._tools_with_names: error_message = self._handle_error( ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) @@ -656,22 +656,22 @@ async def run_async( # noqa: PLR0915 tool_results = await asyncio.gather(*tool_call_tasks, return_exceptions=True) # Process results - for (tool_call, tool_to_invoke), result in zip(valid_tool_calls, tool_results): - # Check if the result is an exception - if isinstance(result, Exception): - if isinstance(result, ToolInvocationError): - error_message = self._handle_error(result) + for (tool_call, tool_to_invoke), tool_result in zip(valid_tool_calls, tool_results): + # Check if the tool_result is an exception + if isinstance(tool_result, Exception): + if isinstance(tool_result, ToolInvocationError): + error_message = self._handle_error(tool_result) else: # Wrap other exceptions as ToolInvocationError error_message = self._handle_error( - ToolInvocationError(f"Tool '{tool_call.tool_name}' execution failed: {result}") + ToolInvocationError(f"Tool '{tool_name}' execution failed: {tool_result}") ) tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) continue # 3) Merge outputs into state try: - self._merge_tool_outputs(tool_to_invoke, result, state) + self._merge_tool_outputs(tool_to_invoke, tool_result, state) except Exception as e: try: error_message = self._handle_error( @@ -689,7 +689,9 @@ async def run_async( # noqa: PLR0915 # 4) Prepare the tool result ChatMessage message tool_messages.append( - self._prepare_tool_result_message(result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke) + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + ) ) # Handle streaming callback From 0b3fef9de0a7e19e8f03fc933c7eda8d83a4d353 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 14:52:09 +0200 Subject: [PATCH 04/17] Small fixes --- haystack/components/tools/tool_invoker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index b0bcac49c6..874492a649 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -675,9 +675,7 @@ async def run_async( # noqa: PLR0915 except Exception as e: try: error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" - ) + ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") ) tool_messages.append( ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) From 6270ba31cfb7e3adf4ea3d5bb0b96d9fdb395a18 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 15:13:25 +0200 Subject: [PATCH 05/17] Adapt number of executors --- haystack/components/tools/tool_invoker.py | 171 ++++++++++++---------- 1 file changed, 94 insertions(+), 77 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 874492a649..f4fca4caef 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -567,7 +567,7 @@ async def run_async( # noqa: PLR0915 enable_streaming_callback_passthrough: Optional[bool] = None, ) -> Dict[str, Any]: """ - Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools. + Asynchronously processes ChatMessage objects containing tool calls. Multiple tool calls are performed concurrently. :param messages: @@ -611,98 +611,115 @@ async def run_async( # noqa: PLR0915 init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - # Collect valid tool calls for concurrent execution - tool_call_tasks = [] - valid_tool_calls = [] - tool_messages = [] + # Count total calls so we can size the pool + total_calls = sum(len(message.tool_calls) for message in messages_with_tool_calls) + max_workers = total_calls or 1 - for message in messages_with_tool_calls: - for tool_call in message.tool_calls: - tool_name = tool_call.tool_name + # Use a local executor sized to exactly the number of calls + loop = asyncio.get_running_loop() - # Check if the tool is available, otherwise return an error message - if tool_name not in self._tools_with_names: - error_message = self._handle_error( - ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) - ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue - - tool_to_invoke = self._tools_with_names[tool_name] - - # 1) Combine user + state inputs - llm_args = tool_call.arguments.copy() - final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - - # Check whether to inject streaming_callback - if ( - resolved_enable_streaming_passthrough - and streaming_callback is not None - and "streaming_callback" not in final_args - ): - invoke_params = self._get_func_params(tool_to_invoke) - if "streaming_callback" in invoke_params: - final_args["streaming_callback"] = streaming_callback + async def invoke_tool_safely( + executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any] + ) -> Any: + """Safely invoke a tool with proper exception handling.""" + try: + return await loop.run_in_executor(executor, partial(tool_to_invoke.invoke, **final_args)) + except ToolInvocationError as e: + return e - # Create async task for tool invocation - task = asyncio.get_running_loop().run_in_executor( - self.executor, partial(tool_to_invoke.invoke, **final_args) - ) - tool_call_tasks.append(task) - valid_tool_calls.append((tool_call, tool_to_invoke)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + tool_call_tasks = [] + valid_tool_calls = [] + tool_messages = [] - # Execute all valid tool calls concurrently - if tool_call_tasks: - tool_results = await asyncio.gather(*tool_call_tasks, return_exceptions=True) + for message in messages_with_tool_calls: + for tool_call in message.tool_calls: + tool_name = tool_call.tool_name - # Process results - for (tool_call, tool_to_invoke), tool_result in zip(valid_tool_calls, tool_results): - # Check if the tool_result is an exception - if isinstance(tool_result, Exception): - if isinstance(tool_result, ToolInvocationError): - error_message = self._handle_error(tool_result) - else: - # Wrap other exceptions as ToolInvocationError + # Check if the tool is available, otherwise return an error message + if tool_name not in self._tools_with_names: error_message = self._handle_error( - ToolInvocationError(f"Tool '{tool_name}' execution failed: {tool_result}") + ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue - - # 3) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) ) + continue + + tool_to_invoke = self._tools_with_names[tool_name] + + # 1) Combine user + state inputs + llm_args = tool_call.arguments.copy() + final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + + # Check whether to inject streaming_callback + if ( + resolved_enable_streaming_passthrough + and streaming_callback is not None + and "streaming_callback" not in final_args + ): + invoke_params = self._get_func_params(tool_to_invoke) + if "streaming_callback" in invoke_params: + final_args["streaming_callback"] = streaming_callback + + # Dispatch each call into our local executor + task = invoke_tool_safely(executor, tool_to_invoke, final_args) + tool_call_tasks.append(task) + valid_tool_calls.append((tool_call, tool_to_invoke)) + + # Execute all valid tool calls concurrently + if tool_call_tasks: + tool_results = await asyncio.gather(*tool_call_tasks) # No return_exceptions since we handle in wrapper + + # Process results + for (tool_call, tool_to_invoke), tool_result in zip(valid_tool_calls, tool_results): + # Check if the tool_result is a ToolInvocationError (caught by our wrapper) + if isinstance(tool_result, ToolInvocationError): + error_message = self._handle_error(tool_result) tool_messages.append( ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) ) continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e - # 4) Prepare the tool result ChatMessage message - tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + # 3) Merge outputs into state + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + ) + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e + + # 4) Prepare the tool result ChatMessage message + tool_messages.append( + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + ) ) - ) - # Handle streaming callback - if streaming_callback is not None: - await streaming_callback( - StreamingChunk( - content="", - index=len(tool_messages) - 1, - tool_call_result=tool_messages[-1].tool_call_results[0], - start=True, - meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, + # Handle streaming callback + if streaming_callback is not None: + await streaming_callback( + StreamingChunk( + content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, + meta={ + "tool_result": tool_messages[-1].tool_call_results[0].result, + "tool_call": tool_call, + }, + ) ) - ) # We stream one more chunk that contains a finish_reason if tool_messages were generated if len(tool_messages) > 0 and streaming_callback is not None: From 169883dbaeb6e9b6d198e74f1917969913ea9925 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 18 Jun 2025 15:36:13 +0200 Subject: [PATCH 06/17] Add release notes --- .../notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml diff --git a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml new file mode 100644 index 0000000000..9b717549c3 --- /dev/null +++ b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + `ToolInvoker` now executes `tool_calls` in parallel within the run_async method. From 043cdf4462f8338bc7f42b77dd22e4eae29551f6 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Mon, 23 Jun 2025 00:40:43 +0200 Subject: [PATCH 07/17] Add parallel tool calling to sync run --- haystack/components/tools/tool_invoker.py | 136 ++++++++++++++-------- 1 file changed, 86 insertions(+), 50 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index f4fca4caef..02766a0a7d 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -169,6 +169,7 @@ def __init__( streaming_callback: Optional[StreamingCallbackT] = None, *, enable_streaming_callback_passthrough: bool = False, + max_workers: int = 4, async_executor: Optional[ThreadPoolExecutor] = None, ): """ @@ -193,6 +194,8 @@ def __init__( This allows tools to stream their results back to the client. Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature. If False, the `streaming_callback` will not be passed to the tool invocation. + :param max_workers: + The maximum number of workers to use in the thread pool executor. :param async_executor: Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be initialized and used. @@ -206,6 +209,7 @@ def __init__( self.tools = tools self.streaming_callback = streaming_callback self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough + self.max_workers = max_workers # Convert Toolset to list for internal use if isinstance(tools, Toolset): @@ -480,6 +484,9 @@ def run( ) tool_messages = [] + + # Collect all tool calls and their parameters for parallel execution + tool_call_params = [] for message in messages_with_tool_calls: for tool_call in message.tool_calls: tool_name = tool_call.tool_name @@ -498,7 +505,7 @@ def run( llm_args = tool_call.arguments.copy() final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - # Check whether to inject streaming_callback + # 2) Check whether to inject streaming_callback if ( resolved_enable_streaming_passthrough and streaming_callback is not None @@ -508,48 +515,65 @@ def run( if "streaming_callback" in invoke_params: final_args["streaming_callback"] = streaming_callback - # 2) Invoke the tool - try: - tool_result = tool_to_invoke.invoke(**final_args) + tool_call_params.append( + {"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "final_args": final_args} + ) - except ToolInvocationError as e: - error_message = self._handle_error(e) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue + # 3) Execute valid tool calls in parallel + if tool_call_params: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [] + for params in tool_call_params: + future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type] + futures.append(future) + + # Process results as they complete + for future in futures: + result = future.result() + if isinstance(result, ChatMessage): + tool_messages.append(result) + else: + # Handle state merging and prepare tool result message + tool_call, tool_to_invoke, tool_result = result + + # 4) Merge outputs into state + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}" + ) + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e - # 3) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}") - ) + # 5) Prepare the tool result ChatMessage message tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke + ) ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e - - # 4) Prepare the tool result ChatMessage message - tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke - ) - ) - if streaming_callback is not None: - streaming_callback( - StreamingChunk( - content="", - index=len(tool_messages) - 1, - tool_call_result=tool_messages[-1].tool_call_results[0], - start=True, - meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, - ) - ) + if streaming_callback is not None: + streaming_callback( + StreamingChunk( + content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, + meta={ + "tool_result": tool_messages[-1].tool_call_results[0].result, + "tool_call": tool_call, + }, + ) + ) # We stream one more chunk that contains a finish_reason if tool_messages were generated if len(tool_messages) > 0 and streaming_callback is not None: @@ -557,6 +581,22 @@ def run( return {"tool_messages": tool_messages, "state": state} + def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, final_args: Dict[str, Any]): + """ + Execute a single tool call. This method is designed to be run in a thread pool. + + :param tool_call: The ToolCall object containing the tool name and arguments. + :param tool_to_invoke: The Tool object that should be invoked. + :param final_args: The final arguments to pass to the tool. + :returns: Either a ChatMessage with error or a tuple of (tool_call, tool_to_invoke, tool_result) + """ + try: + tool_result = tool_to_invoke.invoke(**final_args) + return (tool_call, tool_to_invoke, tool_result) + except ToolInvocationError as e: + error_message = self._handle_error(e) + return ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + @component.output_types(tool_messages=List[ChatMessage], state=State) async def run_async( # noqa: PLR0915 self, @@ -611,10 +651,6 @@ async def run_async( # noqa: PLR0915 init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - # Count total calls so we can size the pool - total_calls = sum(len(message.tool_calls) for message in messages_with_tool_calls) - max_workers = total_calls or 1 - # Use a local executor sized to exactly the number of calls loop = asyncio.get_running_loop() @@ -627,7 +663,7 @@ async def invoke_tool_safely( except ToolInvocationError as e: return e - with ThreadPoolExecutor(max_workers=max_workers) as executor: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: tool_call_tasks = [] valid_tool_calls = [] tool_messages = [] @@ -652,7 +688,7 @@ async def invoke_tool_safely( llm_args = tool_call.arguments.copy() final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - # Check whether to inject streaming_callback + # 2) Check whether to inject streaming_callback if ( resolved_enable_streaming_passthrough and streaming_callback is not None @@ -662,17 +698,17 @@ async def invoke_tool_safely( if "streaming_callback" in invoke_params: final_args["streaming_callback"] = streaming_callback - # Dispatch each call into our local executor + # 3) Dispatch each call into our local executor task = invoke_tool_safely(executor, tool_to_invoke, final_args) tool_call_tasks.append(task) valid_tool_calls.append((tool_call, tool_to_invoke)) - # Execute all valid tool calls concurrently + # 4) Execute all valid tool calls concurrently if tool_call_tasks: tool_results = await asyncio.gather(*tool_call_tasks) # No return_exceptions since we handle in wrapper # Process results - for (tool_call, tool_to_invoke), tool_result in zip(valid_tool_calls, tool_results): + for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)): # Check if the tool_result is a ToolInvocationError (caught by our wrapper) if isinstance(tool_result, ToolInvocationError): error_message = self._handle_error(tool_result) @@ -681,7 +717,7 @@ async def invoke_tool_safely( ) continue - # 3) Merge outputs into state + # 5) Merge outputs into state try: self._merge_tool_outputs(tool_to_invoke, tool_result, state) except Exception as e: @@ -699,7 +735,7 @@ async def invoke_tool_safely( # Re-raise with proper error chain raise propagated_e from e - # 4) Prepare the tool result ChatMessage message + # 6) Prepare the tool result ChatMessage message tool_messages.append( self._prepare_tool_result_message( result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke @@ -711,7 +747,7 @@ async def invoke_tool_safely( await streaming_callback( StreamingChunk( content="", - index=len(tool_messages) - 1, + index=i, tool_call_result=tool_messages[-1].tool_call_results[0], start=True, meta={ From 4cfac72b284d635343b0f45bd5df3179921ebaf9 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 11:36:19 +0200 Subject: [PATCH 08/17] Deprecate async_executor --- haystack/components/tools/tool_invoker.py | 24 +++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 02766a0a7d..819f454232 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -5,6 +5,7 @@ import asyncio import inspect import json +import warnings from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any, Dict, List, Optional, Set, Union @@ -197,8 +198,12 @@ def __init__( :param max_workers: The maximum number of workers to use in the thread pool executor. :param async_executor: - Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be - initialized and used. + Optional `ThreadPoolExecutor` to use for asynchronous calls. + Note: As of Haystack 2.15.0, you no longer need to explicitly pass + `async_executor`. Instead, you can provide the `max_workers` parameter, + and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations. + Support for `async_executor` will be removed in Haystack 2.16.0. + Please migrate to using `max_workers` instead. :raises ValueError: If no tools are provided or if duplicate tool names are found. """ @@ -227,8 +232,19 @@ def __init__( self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string self._owns_executor = async_executor is None + if self._owns_executor: + warnings.warn( + "'async_executor' is deprecated in favor of the 'max_workers' parameter. " + "ToolInvoker now creates its own executor by default using 'max_workers'. " + "Support for 'async_executor' will be removed in Haystack 2.16.0. " + "Please update your usage to pass 'max_workers' instead.", + DeprecationWarning, + ) + self.executor = ( - ThreadPoolExecutor(thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=1) + ThreadPoolExecutor( + thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers + ) if async_executor is None else async_executor ) @@ -663,7 +679,7 @@ async def invoke_tool_safely( except ToolInvocationError as e: return e - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with self.executor as executor: tool_call_tasks = [] valid_tool_calls = [] tool_messages = [] From 5cdb28d3fe2d8b209b266c0d9fb4461f499b740b Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 12:03:45 +0200 Subject: [PATCH 09/17] Deprecate async_executor --- haystack/components/tools/tool_invoker.py | 35 ++++++++++++----------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 819f454232..5206b182eb 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -5,6 +5,7 @@ import asyncio import inspect import json +import threading import warnings from concurrent.futures import ThreadPoolExecutor from functools import partial @@ -231,11 +232,12 @@ def __init__( self._tools_with_names = dict(zip(tool_names, converted_tools)) self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string + self._state_lock = threading.Lock() self._owns_executor = async_executor is None if self._owns_executor: warnings.warn( "'async_executor' is deprecated in favor of the 'max_workers' parameter. " - "ToolInvoker now creates its own executor by default using 'max_workers'. " + "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. " "Support for 'async_executor' will be removed in Haystack 2.16.0. " "Please update your usage to pass 'max_workers' instead.", DeprecationWarning, @@ -559,7 +561,7 @@ def run( try: error_message = self._handle_error( ToolOutputMergeError( - f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}" + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" ) ) tool_messages.append( @@ -734,22 +736,23 @@ async def invoke_tool_safely( continue # 5) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: + with self._state_lock: try: - error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + ) ) - ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e # 6) Prepare the tool result ChatMessage message tool_messages.append( From dd359c1a50d38e109c727241b778688691838334 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 12:07:32 +0200 Subject: [PATCH 10/17] Add thread lock --- haystack/components/tools/tool_invoker.py | 30 ++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 5206b182eb..c05b612609 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -555,22 +555,24 @@ def run( tool_call, tool_to_invoke, tool_result = result # 4) Merge outputs into state - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: + with self._state_lock: try: - error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from" + f"tool {tool_call.tool_name} into State: {e}" + ) ) - ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e # 5) Prepare the tool result ChatMessage message tool_messages.append( From ff72fe4cad8404fb5719471e99d82192e5bac7ed Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 12:20:08 +0200 Subject: [PATCH 11/17] extract methods --- haystack/components/tools/tool_invoker.py | 136 +++++++++++----------- 1 file changed, 70 insertions(+), 66 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index c05b612609..3565db7caf 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -449,6 +449,61 @@ def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None: # Merge other outputs into the state state.set(state_key, output_value, handler_override=handler) + def _prepare_tool_call_params( + self, + messages_with_tool_calls: List[ChatMessage], + state: State, + streaming_callback: Optional[StreamingCallbackT], + enable_streaming_passthrough: bool, + ) -> tuple[List[Dict[str, Any]], List[ChatMessage]]: + """ + Prepare tool call parameters for execution and collect any error messages. + + :param messages_with_tool_calls: Messages containing tool calls to process + :param state: The current state for argument injection + :param streaming_callback: Optional streaming callback to inject + :param enable_streaming_passthrough: Whether to pass streaming callback to tools + :returns: Tuple of (tool_call_params, error_messages) + """ + tool_call_params = [] + error_messages = [] + + for message in messages_with_tool_calls: + for tool_call in message.tool_calls: + tool_name = tool_call.tool_name + + # Check if the tool is available, otherwise return an error message + if tool_name not in self._tools_with_names: + error_message = self._handle_error( + ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) + ) + error_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + + tool_to_invoke = self._tools_with_names[tool_name] + + # 1) Combine user + state inputs + llm_args = tool_call.arguments.copy() + final_args = self._inject_state_args(tool_to_invoke, llm_args, state) + + # 2) Check whether to inject streaming_callback + if ( + enable_streaming_passthrough + and streaming_callback is not None + and "streaming_callback" not in final_args + ): + invoke_params = self._get_func_params(tool_to_invoke) + if "streaming_callback" in invoke_params: + final_args["streaming_callback"] = streaming_callback + + tool_call_params.append( + {"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "final_args": final_args} + ) + + return tool_call_params, error_messages + @component.output_types(tool_messages=List[ChatMessage], state=State) def run( self, @@ -504,38 +559,10 @@ def run( tool_messages = [] # Collect all tool calls and their parameters for parallel execution - tool_call_params = [] - for message in messages_with_tool_calls: - for tool_call in message.tool_calls: - tool_name = tool_call.tool_name - - # Check if the tool is available, otherwise return an error message - if tool_name not in self._tools_with_names: - error_message = self._handle_error( - ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) - ) - tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)) - continue - - tool_to_invoke = self._tools_with_names[tool_name] - - # 1) Combine user + state inputs - llm_args = tool_call.arguments.copy() - final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - - # 2) Check whether to inject streaming_callback - if ( - resolved_enable_streaming_passthrough - and streaming_callback is not None - and "streaming_callback" not in final_args - ): - invoke_params = self._get_func_params(tool_to_invoke) - if "streaming_callback" in invoke_params: - final_args["streaming_callback"] = streaming_callback - - tool_call_params.append( - {"tool_call": tool_call, "tool_to_invoke": tool_to_invoke, "final_args": final_args} - ) + tool_call_params, error_messages = self._prepare_tool_call_params( + messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough + ) + tool_messages.extend(error_messages) # 3) Execute valid tool calls in parallel if tool_call_params: @@ -618,7 +645,7 @@ def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, f return ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) @component.output_types(tool_messages=List[ChatMessage], state=State) - async def run_async( # noqa: PLR0915 + async def run_async( self, messages: List[ChatMessage], state: Optional[State] = None, @@ -688,40 +715,17 @@ async def invoke_tool_safely( valid_tool_calls = [] tool_messages = [] - for message in messages_with_tool_calls: - for tool_call in message.tool_calls: - tool_name = tool_call.tool_name - - # Check if the tool is available, otherwise return an error message - if tool_name not in self._tools_with_names: - error_message = self._handle_error( - ToolNotFoundException(tool_name, list(self._tools_with_names.keys())) - ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue + # Prepare tool call parameters for execution + tool_call_params, error_messages = self._prepare_tool_call_params( + messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough + ) + tool_messages.extend(error_messages) - tool_to_invoke = self._tools_with_names[tool_name] - - # 1) Combine user + state inputs - llm_args = tool_call.arguments.copy() - final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - - # 2) Check whether to inject streaming_callback - if ( - resolved_enable_streaming_passthrough - and streaming_callback is not None - and "streaming_callback" not in final_args - ): - invoke_params = self._get_func_params(tool_to_invoke) - if "streaming_callback" in invoke_params: - final_args["streaming_callback"] = streaming_callback - - # 3) Dispatch each call into our local executor - task = invoke_tool_safely(executor, tool_to_invoke, final_args) - tool_call_tasks.append(task) - valid_tool_calls.append((tool_call, tool_to_invoke)) + # Create async tasks for valid tool calls + for params in tool_call_params: + task = invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"]) + tool_call_tasks.append(task) + valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"])) # 4) Execute all valid tool calls concurrently if tool_call_tasks: From fc6873165f26888b76bca99b9b4c036d8ccfa32a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 12:23:56 +0200 Subject: [PATCH 12/17] Update release notes --- .../notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml index 9b717549c3..7cba7adaf8 100644 --- a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml +++ b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml @@ -1,4 +1,6 @@ --- features: - | - `ToolInvoker` now executes `tool_calls` in parallel within the run_async method. + `ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode. + `async_executor` parameter in `ToolInvoker` is deprecated in favor of `max_workers` parameter and will be removed in Haystack 2.16.0. + You can use `max_workers` parameter to control the number of threads used for parallel tool calling. From ddbd46e2b2ef93c78d5122fb7828d8e891100360 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 12:37:03 +0200 Subject: [PATCH 13/17] Update release notes --- .../notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml index 7cba7adaf8..1ca33c77d1 100644 --- a/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml +++ b/releasenotes/notes/enable-parallel-tool-calling-96c6589f116c7e7b.yaml @@ -2,5 +2,8 @@ features: - | `ToolInvoker` now executes `tool_calls` in parallel for both sync and async mode. + +deprecations: + - | `async_executor` parameter in `ToolInvoker` is deprecated in favor of `max_workers` parameter and will be removed in Haystack 2.16.0. You can use `max_workers` parameter to control the number of threads used for parallel tool calling. From 02413c91598ae41f8039bdb808f70b57ad0f036e Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Jun 2025 15:10:20 +0200 Subject: [PATCH 14/17] Updates --- haystack/components/tools/tool_invoker.py | 164 +++++++++++----------- 1 file changed, 83 insertions(+), 81 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 3565db7caf..682311cc65 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -484,11 +484,11 @@ def _prepare_tool_call_params( tool_to_invoke = self._tools_with_names[tool_name] - # 1) Combine user + state inputs + # Combine user + state inputs llm_args = tool_call.arguments.copy() final_args = self._inject_state_args(tool_to_invoke, llm_args, state) - # 2) Check whether to inject streaming_callback + # Check whether to inject streaming_callback if ( enable_streaming_passthrough and streaming_callback is not None @@ -558,13 +558,13 @@ def run( tool_messages = [] - # Collect all tool calls and their parameters for parallel execution + # 1) Collect all tool calls and their parameters for parallel execution tool_call_params, error_messages = self._prepare_tool_call_params( messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough ) tool_messages.extend(error_messages) - # 3) Execute valid tool calls in parallel + # 2) Execute valid tool calls in parallel if tool_call_params: with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = [] @@ -572,7 +572,7 @@ def run( future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type] futures.append(future) - # Process results as they complete + # 3) Process results as they complete for future in futures: result = future.result() if isinstance(result, ChatMessage): @@ -608,6 +608,7 @@ def run( ) ) + # 6) Handle streaming callback if streaming_callback is not None: streaming_callback( StreamingChunk( @@ -644,6 +645,15 @@ def _execute_single_tool_call(self, tool_call: ToolCall, tool_to_invoke: Tool, f error_message = self._handle_error(e) return ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + @staticmethod + async def invoke_tool_safely(executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any]) -> Any: + """Safely invoke a tool with proper exception handling.""" + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor(executor, partial(tool_to_invoke.invoke, **final_args)) + except ToolInvocationError as e: + return e + @component.output_types(tool_messages=List[ChatMessage], state=State) async def run_async( self, @@ -698,90 +708,82 @@ async def run_async( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - # Use a local executor sized to exactly the number of calls - loop = asyncio.get_running_loop() + tool_messages = [] - async def invoke_tool_safely( - executor: ThreadPoolExecutor, tool_to_invoke: Tool, final_args: Dict[str, Any] - ) -> Any: - """Safely invoke a tool with proper exception handling.""" - try: - return await loop.run_in_executor(executor, partial(tool_to_invoke.invoke, **final_args)) - except ToolInvocationError as e: - return e - - with self.executor as executor: - tool_call_tasks = [] - valid_tool_calls = [] - tool_messages = [] - - # Prepare tool call parameters for execution - tool_call_params, error_messages = self._prepare_tool_call_params( - messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough - ) - tool_messages.extend(error_messages) - - # Create async tasks for valid tool calls - for params in tool_call_params: - task = invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"]) - tool_call_tasks.append(task) - valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"])) - - # 4) Execute all valid tool calls concurrently - if tool_call_tasks: - tool_results = await asyncio.gather(*tool_call_tasks) # No return_exceptions since we handle in wrapper - - # Process results - for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)): - # Check if the tool_result is a ToolInvocationError (caught by our wrapper) - if isinstance(tool_result, ToolInvocationError): - error_message = self._handle_error(tool_result) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue + # 1) Prepare tool call parameters for execution + tool_call_params, error_messages = self._prepare_tool_call_params( + messages_with_tool_calls, state, streaming_callback, resolved_enable_streaming_passthrough + ) + tool_messages.extend(error_messages) + + # 2) Execute valid tool calls in parallel + if tool_call_params: + with self.executor as executor: + tool_call_tasks = [] + valid_tool_calls = [] - # 5) Merge outputs into state - with self._state_lock: - try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: + # 3) Create async tasks for valid tool calls + for params in tool_call_params: + task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"]) + tool_call_tasks.append(task) + valid_tool_calls.append((params["tool_call"], params["tool_to_invoke"])) + + if tool_call_tasks: + # 4) Gather results from all tool calls + tool_results = await asyncio.gather(*tool_call_tasks) + + # 5) Process results + for i, ((tool_call, tool_to_invoke), tool_result) in enumerate(zip(valid_tool_calls, tool_results)): + # Check if the tool_result is a ToolInvocationError (caught by our wrapper) + if isinstance(tool_result, ToolInvocationError): + error_message = self._handle_error(tool_result) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + + # 6) Merge outputs into state + with self._state_lock: try: - error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}" + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: + try: + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs from" + f"tool {tool_call.tool_name} into State: {e}" + ) ) - ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e - - # 6) Prepare the tool result ChatMessage message - tool_messages.append( - self._prepare_tool_result_message( - result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke - ) - ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e - # Handle streaming callback - if streaming_callback is not None: - await streaming_callback( - StreamingChunk( - content="", - index=i, - tool_call_result=tool_messages[-1].tool_call_results[0], - start=True, - meta={ - "tool_result": tool_messages[-1].tool_call_results[0].result, - "tool_call": tool_call, - }, + # 7) Prepare the tool result ChatMessage message + tool_messages.append( + self._prepare_tool_result_message( + result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke ) ) + # 8) Handle streaming callback + if streaming_callback is not None: + await streaming_callback( + StreamingChunk( + content="", + index=i, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, + meta={ + "tool_result": tool_messages[-1].tool_call_results[0].result, + "tool_call": tool_call, + }, + ) + ) + # We stream one more chunk that contains a finish_reason if tool_messages were generated if len(tool_messages) > 0 and streaming_callback is not None: await streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"})) From d28c395b6a952e62321e97b0bcb9303ae2e9ec39 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 25 Jun 2025 11:35:03 +0200 Subject: [PATCH 15/17] Add new tests --- haystack/components/tools/tool_invoker.py | 30 ++++---- test/components/tools/test_tool_invoker.py | 79 ++++++++++++++++++++++ 2 files changed, 93 insertions(+), 16 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 682311cc65..bf1e526a4f 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -582,24 +582,22 @@ def run( tool_call, tool_to_invoke, tool_result = result # 4) Merge outputs into state - with self._state_lock: + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from" - f"tool {tool_call.tool_name} into State: {e}" - ) + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}" ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) - ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e # 5) Prepare the tool result ChatMessage message tool_messages.append( diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index a21f5ae31a..30a83705f9 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -6,6 +6,7 @@ import pytest import json import datetime +import time from haystack import Pipeline from haystack.components.builders.prompt_builder import PromptBuilder @@ -624,6 +625,84 @@ def test_enable_streaming_callback_passthrough_runtime(self, monkeypatch): ) mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")]) + def test_parallel_tool_calling_with_state_updates(self): + """Test that parallel tool execution with state updates works correctly with the state lock.""" + # Create a shared counter variable to simulate a state value that gets updated + execution_log = [] + + def function_1(): + # Simulate some work that takes time + time.sleep(0.1) + execution_log.append("tool_1_executed") + return {"counter": 1, "tool_name": "tool_1"} + + def function_2(): + # Simulate some work that takes time + time.sleep(0.1) + execution_log.append("tool_2_executed") + return {"counter": 2, "tool_name": "tool_2"} + + def function_3(): + # Simulate some work that takes time + time.sleep(0.1) + execution_log.append("tool_3_executed") + return {"counter": 3, "tool_name": "tool_3"} + + # Create tools that all update the same state key + tool_1 = Tool( + name="state_tool_1", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_1, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_2 = Tool( + name="state_tool_2", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_2, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_3 = Tool( + name="state_tool_3", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_3, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + # Create ToolInvoker with all three tools + invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) + + # Create initial state + state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) + + # Create tool calls that will be executed in parallel + tool_calls = [ + ToolCall(tool_name="state_tool_1", arguments={}), + ToolCall(tool_name="state_tool_2", arguments={}), + ToolCall(tool_name="state_tool_3", arguments={}), + ] + message = ChatMessage.from_assistant(tool_calls=tool_calls) + + # Execute the tools + result = invoker.run(messages=[message], state=state) + + # Verify that all three tools were executed + assert len(execution_log) == 3 + assert "tool_1_executed" in execution_log + assert "tool_2_executed" in execution_log + assert "tool_3_executed" in execution_log + + # Verify that the state was updated correctly + # Due to parallel execution, we can't predict which tool will be the last to update + assert state.has("counter") + assert state.has("last_tool") + assert state.get("counter") in [1, 2, 3] # Should be one of the tool values + assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names + class TestMergeToolOutputs: def test_merge_tool_outputs_result_not_a_dict(self, weather_tool): From 4c0ac834df9d7d59104a795ffeb5d98d488b5365 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 25 Jun 2025 11:40:05 +0200 Subject: [PATCH 16/17] Add test for async --- haystack/components/tools/tool_invoker.py | 32 ++++----- test/components/tools/test_tool_invoker.py | 79 +++++++++++++++++++--- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index bf1e526a4f..c5835b67e5 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -5,7 +5,6 @@ import asyncio import inspect import json -import threading import warnings from concurrent.futures import ThreadPoolExecutor from functools import partial @@ -232,7 +231,6 @@ def __init__( self._tools_with_names = dict(zip(tool_names, converted_tools)) self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string - self._state_lock = threading.Lock() self._owns_executor = async_executor is None if self._owns_executor: warnings.warn( @@ -741,24 +739,22 @@ async def run_async( continue # 6) Merge outputs into state - with self._state_lock: + try: + self._merge_tool_outputs(tool_to_invoke, tool_result, state) + except Exception as e: try: - self._merge_tool_outputs(tool_to_invoke, tool_result, state) - except Exception as e: - try: - error_message = self._handle_error( - ToolOutputMergeError( - f"Failed to merge tool outputs from" - f"tool {tool_call.tool_name} into State: {e}" - ) - ) - tool_messages.append( - ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + error_message = self._handle_error( + ToolOutputMergeError( + f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}" ) - continue - except ToolOutputMergeError as propagated_e: - # Re-raise with proper error chain - raise propagated_e from e + ) + tool_messages.append( + ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True) + ) + continue + except ToolOutputMergeError as propagated_e: + # Re-raise with proper error chain + raise propagated_e from e # 7) Prepare the tool result ChatMessage message tool_messages.append( diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index 30a83705f9..90e1f81c5d 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -631,19 +631,16 @@ def test_parallel_tool_calling_with_state_updates(self): execution_log = [] def function_1(): - # Simulate some work that takes time time.sleep(0.1) execution_log.append("tool_1_executed") return {"counter": 1, "tool_name": "tool_1"} def function_2(): - # Simulate some work that takes time time.sleep(0.1) execution_log.append("tool_2_executed") return {"counter": 2, "tool_name": "tool_2"} def function_3(): - # Simulate some work that takes time time.sleep(0.1) execution_log.append("tool_3_executed") return {"counter": 3, "tool_name": "tool_3"} @@ -676,18 +673,13 @@ def function_3(): # Create ToolInvoker with all three tools invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) - # Create initial state state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) - - # Create tool calls that will be executed in parallel tool_calls = [ ToolCall(tool_name="state_tool_1", arguments={}), ToolCall(tool_name="state_tool_2", arguments={}), ToolCall(tool_name="state_tool_3", arguments={}), ] message = ChatMessage.from_assistant(tool_calls=tool_calls) - - # Execute the tools result = invoker.run(messages=[message], state=state) # Verify that all three tools were executed @@ -703,6 +695,77 @@ def function_3(): assert state.get("counter") in [1, 2, 3] # Should be one of the tool values assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names + @pytest.mark.asyncio + async def test_async_parallel_tool_calling_with_state_updates(self): + """Test that parallel tool execution with state updates works correctly with the state lock.""" + # Create a shared counter variable to simulate a state value that gets updated + execution_log = [] + + def function_1(): + time.sleep(0.1) + execution_log.append("tool_1_executed") + return {"counter": 1, "tool_name": "tool_1"} + + def function_2(): + time.sleep(0.1) + execution_log.append("tool_2_executed") + return {"counter": 2, "tool_name": "tool_2"} + + def function_3(): + time.sleep(0.1) + execution_log.append("tool_3_executed") + return {"counter": 3, "tool_name": "tool_3"} + + # Create tools that all update the same state key + tool_1 = Tool( + name="state_tool_1", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_1, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_2 = Tool( + name="state_tool_2", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_2, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + tool_3 = Tool( + name="state_tool_3", + description="A tool that updates state counter", + parameters={"type": "object", "properties": {}}, + function=function_3, + outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, + ) + + # Create ToolInvoker with all three tools + invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) + + state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) + tool_calls = [ + ToolCall(tool_name="state_tool_1", arguments={}), + ToolCall(tool_name="state_tool_2", arguments={}), + ToolCall(tool_name="state_tool_3", arguments={}), + ] + message = ChatMessage.from_assistant(tool_calls=tool_calls) + result = await invoker.run_async(messages=[message], state=state) + + # Verify that all three tools were executed + assert len(execution_log) == 3 + assert "tool_1_executed" in execution_log + assert "tool_2_executed" in execution_log + assert "tool_3_executed" in execution_log + + # Verify that the state was updated correctly + # Due to parallel execution, we can't predict which tool will be the last to update + assert state.has("counter") + assert state.has("last_tool") + assert state.get("counter") in [1, 2, 3] # Should be one of the tool values + assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names + class TestMergeToolOutputs: def test_merge_tool_outputs_result_not_a_dict(self, weather_tool): From 0673edc94f57577beaa0eca0c946fe821058ef4e Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 25 Jun 2025 11:43:02 +0200 Subject: [PATCH 17/17] PR comments --- haystack/components/tools/tool_invoker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index c5835b67e5..d953e4f027 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -570,7 +570,7 @@ def run( future = executor.submit(self._execute_single_tool_call, **params) # type: ignore[arg-type] futures.append(future) - # 3) Process results as they complete + # 3) Process results in the order they are submitted for future in futures: result = future.result() if isinstance(result, ChatMessage):