Skip to content

Commit 54c5057

Browse files
authored
feat: (and fix) Add enable_streaming_passthrough to ToolInvoker and add missing params to to_dict (#9498)
* Fixes and tests * Add reno * Change variable name * Add test and fix for passing streaming_callback to a component tool * Add unit test * Remove unused import * Fix reno
1 parent 1d6a9f6 commit 54c5057

3 files changed

Lines changed: 206 additions & 18 deletions

File tree

haystack/components/tools/tool_invoker.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import json
88
from concurrent.futures import ThreadPoolExecutor
99
from functools import partial
10-
from typing import Any, Dict, List, Optional, Union
10+
from typing import Any, Dict, List, Optional, Set, Union
1111

1212
from haystack import component, default_from_dict, default_to_dict, logging
1313
from haystack.components.agents import State
@@ -24,6 +24,7 @@
2424
)
2525
from haystack.tools.errors import ToolInvocationError
2626
from haystack.tracing.utils import _serializable_value
27+
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -167,6 +168,7 @@ def __init__(
167168
convert_result_to_json_string: bool = False,
168169
streaming_callback: Optional[StreamingCallbackT] = None,
169170
*,
171+
enable_streaming_callback_passthrough: bool = False,
170172
async_executor: Optional[ThreadPoolExecutor] = None,
171173
):
172174
"""
@@ -186,6 +188,11 @@ def __init__(
186188
A callback function that will be called to emit tool results.
187189
Note that the result is only emitted once it becomes available — it is not
188190
streamed incrementally in real time.
191+
:param enable_streaming_callback_passthrough:
192+
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
193+
This allows tools to stream their results back to the client.
194+
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
195+
If False, the `streaming_callback` will not be passed to the tool invocation.
189196
:param async_executor:
190197
Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
191198
initialized and used.
@@ -198,6 +205,7 @@ def __init__(
198205
# could be a Toolset instance or a list of Tools
199206
self.tools = tools
200207
self.streaming_callback = streaming_callback
208+
self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
201209

202210
# Convert Toolset to list for internal use
203211
if isinstance(tools, Toolset):
@@ -329,18 +337,12 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
329337
raise conversion_error from e
330338
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)
331339

332-
@staticmethod
333-
def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
340+
def _get_func_params(self, tool: Tool) -> Set:
334341
"""
335-
Combine LLM-provided arguments (llm_args) with state-based arguments.
342+
Returns the function parameters of the tool's invoke method.
336343
337-
Tool arguments take precedence in the following order:
338-
- LLM overrides state if the same param is present in both
339-
- local tool.inputs mappings (if any)
340-
- function signature name matching
344+
This method inspects the tool's function signature to determine which parameters the tool accepts.
341345
"""
342-
final_args = dict(llm_args) # start with LLM-provided
343-
344346
# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
345347
# to find out which parameters the tool accepts.
346348
if isinstance(tool, ComponentTool):
@@ -352,6 +354,20 @@ def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Di
352354
else:
353355
func_params = set(inspect.signature(tool.function).parameters.keys())
354356

357+
return func_params
358+
359+
def _inject_state_args(self, tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
360+
"""
361+
Combine LLM-provided arguments (llm_args) with state-based arguments.
362+
363+
Tool arguments take precedence in the following order:
364+
- LLM overrides state if the same param is present in both
365+
- local tool.inputs mappings (if any)
366+
- function signature name matching
367+
"""
368+
final_args = dict(llm_args) # start with LLM-provided
369+
func_params = self._get_func_params(tool)
370+
355371
# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
356372
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
357373
if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
@@ -417,6 +433,8 @@ def run(
417433
messages: List[ChatMessage],
418434
state: Optional[State] = None,
419435
streaming_callback: Optional[StreamingCallbackT] = None,
436+
*,
437+
enable_streaming_callback_passthrough: Optional[bool] = None,
420438
) -> Dict[str, Any]:
421439
"""
422440
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
@@ -427,6 +445,12 @@ def run(
427445
:param streaming_callback: A callback function that will be called to emit tool results.
428446
Note that the result is only emitted once it becomes available — it is not
429447
streamed incrementally in real time.
448+
:param enable_streaming_callback_passthrough:
449+
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
450+
This allows tools to stream their results back to the client.
451+
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
452+
If False, the `streaming_callback` will not be passed to the tool invocation.
453+
If None, the value from the constructor will be used.
430454
:returns:
431455
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
432456
Each ChatMessage objects wraps the result of a tool invocation.
@@ -443,6 +467,12 @@ def run(
443467
if state is None:
444468
state = State(schema={})
445469

470+
resolved_enable_streaming_passthrough = (
471+
enable_streaming_callback_passthrough
472+
if enable_streaming_callback_passthrough is not None
473+
else self.enable_streaming_callback_passthrough
474+
)
475+
446476
# Only keep messages with tool calls
447477
messages_with_tool_calls = [message for message in messages if message.tool_calls]
448478
streaming_callback = select_streaming_callback(
@@ -468,6 +498,16 @@ def run(
468498
llm_args = tool_call.arguments.copy()
469499
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
470500

501+
# Check whether to inject streaming_callback
502+
if (
503+
resolved_enable_streaming_passthrough
504+
and streaming_callback is not None
505+
and "streaming_callback" not in final_args
506+
):
507+
invoke_params = self._get_func_params(tool_to_invoke)
508+
if "streaming_callback" in invoke_params:
509+
final_args["streaming_callback"] = streaming_callback
510+
471511
# 2) Invoke the tool
472512
try:
473513
tool_result = tool_to_invoke.invoke(**final_args)
@@ -523,6 +563,8 @@ async def run_async(
523563
messages: List[ChatMessage],
524564
state: Optional[State] = None,
525565
streaming_callback: Optional[StreamingCallbackT] = None,
566+
*,
567+
enable_streaming_callback_passthrough: Optional[bool] = None,
526568
) -> Dict[str, Any]:
527569
"""
528570
Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
@@ -533,6 +575,12 @@ async def run_async(
533575
:param streaming_callback: An asynchronous callback function that will be called to emit tool results.
534576
Note that the result is only emitted once it becomes available — it is not
535577
streamed incrementally in real time.
578+
:param enable_streaming_callback_passthrough:
579+
If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
580+
This allows tools to stream their results back to the client.
581+
Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
582+
If False, the `streaming_callback` will not be passed to the tool invocation.
583+
If None, the value from the constructor will be used.
536584
:returns:
537585
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
538586
Each ChatMessage objects wraps the result of a tool invocation.
@@ -549,6 +597,12 @@ async def run_async(
549597
if state is None:
550598
state = State(schema={})
551599

600+
resolved_enable_streaming_passthrough = (
601+
enable_streaming_callback_passthrough
602+
if enable_streaming_callback_passthrough is not None
603+
else self.enable_streaming_callback_passthrough
604+
)
605+
552606
# Only keep messages with tool calls
553607
messages_with_tool_calls = [message for message in messages if message.tool_calls]
554608
streaming_callback = select_streaming_callback(
@@ -574,6 +628,16 @@ async def run_async(
574628
llm_args = tool_call.arguments.copy()
575629
final_args = self._inject_state_args(tool_to_invoke, llm_args, state)
576630

631+
# Check whether to inject streaming_callback
632+
if (
633+
resolved_enable_streaming_passthrough
634+
and streaming_callback is not None
635+
and "streaming_callback" not in final_args
636+
):
637+
invoke_params = self._get_func_params(tool_to_invoke)
638+
if "streaming_callback" in invoke_params:
639+
final_args["streaming_callback"] = streaming_callback
640+
577641
# 2) Invoke the tool asynchronously
578642
try:
579643
tool_result = await asyncio.get_running_loop().run_in_executor(
@@ -632,11 +696,18 @@ def to_dict(self) -> Dict[str, Any]:
632696
:returns:
633697
Dictionary with serialized data.
634698
"""
699+
if self.streaming_callback is not None:
700+
streaming_callback = serialize_callable(self.streaming_callback)
701+
else:
702+
streaming_callback = None
703+
635704
return default_to_dict(
636705
self,
637706
tools=serialize_tools_or_toolset(self.tools),
638707
raise_on_failure=self.raise_on_failure,
639708
convert_result_to_json_string=self.convert_result_to_json_string,
709+
streaming_callback=streaming_callback,
710+
enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
640711
)
641712

642713
@classmethod
@@ -650,4 +721,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
650721
The deserialized component.
651722
"""
652723
deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
724+
if data["init_parameters"].get("streaming_callback") is not None:
725+
data["init_parameters"]["streaming_callback"] = deserialize_callable(
726+
data["init_parameters"]["streaming_callback"]
727+
)
653728
return default_from_dict(cls, data)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
Added the `enable_streaming_callback_passthrough` to the `ToolInovker` init, run and run_async methods. If set to True the ToolInvoker will try and pass the `streaming_callback` function to a tool's invoke method only if the tool's invoke method has `streaming_callback` in its signature.
5+
fixes:
6+
- |
7+
Fixed the `to_dict` and `from_dict` of `ToolInvoker` to properly serialize the `streaming_callback` init parameter.

0 commit comments

Comments
 (0)