Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 85 additions & 10 deletions haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.agents import State
Expand All @@ -24,6 +24,7 @@
)
from haystack.tools.errors import ToolInvocationError
from haystack.tracing.utils import _serializable_value
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -167,6 +168,7 @@ def __init__(
convert_result_to_json_string: bool = False,
streaming_callback: Optional[StreamingCallbackT] = None,
*,
enable_streaming_callback_passthrough: bool = False,
async_executor: Optional[ThreadPoolExecutor] = None,
):
"""
Expand All @@ -186,6 +188,11 @@ def __init__(
A callback function that will be called to emit tool results.
Note that the result is only emitted once it becomes available — it is not
streamed incrementally in real time.
:param enable_streaming_callback_passthrough:
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
This allows tools to stream their results back to the client.
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
If False, the `streaming_callback` will not be passed to the tool invocation.
:param async_executor:
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
initialized and used.
Expand All @@ -198,6 +205,7 @@ def __init__(
# could be a Toolset instance or a list of Tools
self.tools = tools
self.streaming_callback = streaming_callback
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough

# Convert Toolset to list for internal use
if isinstance(tools, Toolset):
Expand Down Expand Up @@ -329,18 +337,12 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
raise conversion_error from e
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

@staticmethod
def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
def _get_func_params(self, tool: Tool) -> Set:
"""
Combine LLM-provided arguments (llm_args) with state-based arguments.
Returns the function parameters of the tool's invoke method.

Tool arguments take precedence in the following order:
- LLM overrides state if the same param is present in both
- local tool.inputs mappings (if any)
- function signature name matching
This method inspects the tool's function signature to determine which parameters the tool accepts.
"""
final_args = dict(llm_args) # start with LLM-provided

# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
# to find out which parameters the tool accepts.
if isinstance(tool, ComponentTool):
Expand All @@ -352,6 +354,20 @@ def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Di
else:
func_params = set(inspect.signature(tool.function).parameters.keys())

return func_params

def _inject_state_args(self, tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
"""
Combine LLM-provided arguments (llm_args) with state-based arguments.

Tool arguments take precedence in the following order:
- LLM overrides state if the same param is present in both
- local tool.inputs mappings (if any)
- function signature name matching
"""
final_args = dict(llm_args) # start with LLM-provided
func_params = self._get_func_params(tool)

# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
Expand Down Expand Up @@ -417,6 +433,8 @@ def run(
messages: List[ChatMessage],
state: Optional[State] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
*,
enable_streaming_callback_passthrough: Optional[bool] = None,
) -> Dict[str, Any]:
"""
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
Expand All @@ -427,6 +445,12 @@ def run(
:param streaming_callback: A callback function that will be called to emit tool results.
Note that the result is only emitted once it becomes available — it is not
streamed incrementally in real time.
:param enable_streaming_callback_passthrough:
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
This allows tools to stream their results back to the client.
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
If False, the `streaming_callback` will not be passed to the tool invocation.
If None, the value from the constructor will be used.
:returns:
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
Each ChatMessage objects wraps the result of a tool invocation.
Expand All @@ -443,6 +467,12 @@ def run(
if state is None:
state = State(schema={})

resolved_enable_streaming_passthrough = (
enable_streaming_callback_passthrough
if enable_streaming_callback_passthrough is not None
else self.enable_streaming_callback_passthrough
)

# Only keep messages with tool calls
messages_with_tool_calls = [message for message in messages if message.tool_calls]
streaming_callback = select_streaming_callback(
Expand All @@ -468,6 +498,16 @@ def run(
llm_args = tool_call.arguments.copy()
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)

# Check whether to inject streaming_callback
if (
resolved_enable_streaming_passthrough
and streaming_callback is not None
and "streaming_callback" not in final_args
):
invoke_params = self._get_func_params(tool_to_invoke)
if "streaming_callback" in invoke_params:
final_args["streaming_callback"] = streaming_callback

# 2) Invoke the tool
try:
tool_result = tool_to_invoke.invoke(**final_args)
Expand Down Expand Up @@ -520,6 +560,8 @@ async def run_async(
messages: List[ChatMessage],
state: Optional[State] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
*,
enable_streaming_callback_passthrough: Optional[bool] = None,
) -> Dict[str, Any]:
"""
Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
Expand All @@ -530,6 +572,12 @@ async def run_async(
:param streaming_callback: An asynchronous callback function that will be called to emit tool results.
Note that the result is only emitted once it becomes available — it is not
streamed incrementally in real time.
:param enable_streaming_callback_passthrough:
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
This allows tools to stream their results back to the client.
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
If False, the `streaming_callback` will not be passed to the tool invocation.
If None, the value from the constructor will be used.
:returns:
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
Each ChatMessage objects wraps the result of a tool invocation.
Expand All @@ -546,6 +594,12 @@ async def run_async(
if state is None:
state = State(schema={})

resolved_enable_streaming_passthrough = (
enable_streaming_callback_passthrough
if enable_streaming_callback_passthrough is not None
else self.enable_streaming_callback_passthrough
)

# Only keep messages with tool calls
messages_with_tool_calls = [message for message in messages if message.tool_calls]
streaming_callback = select_streaming_callback(
Expand All @@ -571,6 +625,16 @@ async def run_async(
llm_args = tool_call.arguments.copy()
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)

# Check whether to inject streaming_callback
if (
resolved_enable_streaming_passthrough
and streaming_callback is not None
and "streaming_callback" not in final_args
):
invoke_params = self._get_func_params(tool_to_invoke)
if "streaming_callback" in invoke_params:
final_args["streaming_callback"] = streaming_callback

# 2) Invoke the tool asynchronously
try:
tool_result = await asyncio.get_running_loop().run_in_executor(
Expand Down Expand Up @@ -626,11 +690,18 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
if self.streaming_callback is not None:
streaming_callback = serialize_callable(self.streaming_callback)
else:
streaming_callback = None

return default_to_dict(
self,
tools=serialize_tools_or_toolset(self.tools),
raise_on_failure=self.raise_on_failure,
convert_result_to_json_string=self.convert_result_to_json_string,
streaming_callback=streaming_callback,
enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
)

@classmethod
Expand All @@ -644,4 +715,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
The deserialized component.
"""
deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
if data["init_parameters"].get("streaming_callback") is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(
data["init_parameters"]["streaming_callback"]
)
return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Added the `enable_streaming_callback_passthrough` to the `ToolInovker` init, run and run_async methods. If set to True the ToolInvoker will try and pass the `streaming_callback` function to a tool's invoke method only if the tool's invoke method has `streaming_callback` in its signature.
fixes:
- |
Fixed the `to_dict` and `from_dict` of `ToolInvoker` to properly serialize the `streaming_callback` init parameter.
Loading
Loading