77import json
88from concurrent .futures import ThreadPoolExecutor
99from functools import partial
10- from typing import Any , Dict , List , Optional , Union
10+ from typing import Any , Dict , List , Optional , Set , Union
1111
1212from haystack import component , default_from_dict , default_to_dict , logging
1313from haystack .components .agents import State
2424)
2525from haystack .tools .errors import ToolInvocationError
2626from haystack .tracing .utils import _serializable_value
27+ from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
2728
2829logger = logging .getLogger (__name__ )
2930
@@ -167,6 +168,7 @@ def __init__(
167168 convert_result_to_json_string : bool = False ,
168169 streaming_callback : Optional [StreamingCallbackT ] = None ,
169170 * ,
171+ enable_streaming_callback_passthrough : bool = False ,
170172 async_executor : Optional [ThreadPoolExecutor ] = None ,
171173 ):
172174 """
@@ -186,6 +188,11 @@ def __init__(
186188 A callback function that will be called to emit tool results.
187189 Note that the result is only emitted once it becomes available — it is not
188190 streamed incrementally in real time.
191+ :param enable_streaming_callback_passthrough:
192+ If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
193+ This allows tools to stream their results back to the client.
194+ Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
195+ If False, the `streaming_callback` will not be passed to the tool invocation.
189196 :param async_executor:
190197 Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
191198 initialized and used.
@@ -198,6 +205,7 @@ def __init__(
198205 # could be a Toolset instance or a list of Tools
199206 self .tools = tools
200207 self .streaming_callback = streaming_callback
208+ self .enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
201209
202210 # Convert Toolset to list for internal use
203211 if isinstance (tools , Toolset ):
@@ -329,18 +337,12 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
329337 raise conversion_error from e
330338 return ChatMessage .from_tool (tool_result = tool_result_str , error = error , origin = tool_call )
331339
332- @staticmethod
333- def _inject_state_args (tool : Tool , llm_args : Dict [str , Any ], state : State ) -> Dict [str , Any ]:
340+ def _get_func_params (self , tool : Tool ) -> Set :
334341 """
335- Combine LLM-provided arguments (llm_args) with state-based arguments .
342+ Returns the function parameters of the tool's invoke method .
336343
337- Tool arguments take precedence in the following order:
338- - LLM overrides state if the same param is present in both
339- - local tool.inputs mappings (if any)
340- - function signature name matching
344+ This method inspects the tool's function signature to determine which parameters the tool accepts.
341345 """
342- final_args = dict (llm_args ) # start with LLM-provided
343-
344346 # ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
345347 # to find out which parameters the tool accepts.
346348 if isinstance (tool , ComponentTool ):
@@ -352,6 +354,20 @@ def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Di
352354 else :
353355 func_params = set (inspect .signature (tool .function ).parameters .keys ())
354356
357+ return func_params
358+
359+ def _inject_state_args (self , tool : Tool , llm_args : Dict [str , Any ], state : State ) -> Dict [str , Any ]:
360+ """
361+ Combine LLM-provided arguments (llm_args) with state-based arguments.
362+
363+ Tool arguments take precedence in the following order:
364+ - LLM overrides state if the same param is present in both
365+ - local tool.inputs mappings (if any)
366+ - function signature name matching
367+ """
368+ final_args = dict (llm_args ) # start with LLM-provided
369+ func_params = self ._get_func_params (tool )
370+
355371 # Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
356372 # Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
357373 if hasattr (tool , "inputs_from_state" ) and isinstance (tool .inputs_from_state , dict ):
@@ -417,6 +433,8 @@ def run(
417433 messages : List [ChatMessage ],
418434 state : Optional [State ] = None ,
419435 streaming_callback : Optional [StreamingCallbackT ] = None ,
436+ * ,
437+ enable_streaming_callback_passthrough : Optional [bool ] = None ,
420438 ) -> Dict [str , Any ]:
421439 """
422440 Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.
@@ -427,6 +445,12 @@ def run(
427445 :param streaming_callback: A callback function that will be called to emit tool results.
428446 Note that the result is only emitted once it becomes available — it is not
429447 streamed incrementally in real time.
448+ :param enable_streaming_callback_passthrough:
449+ If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
450+ This allows tools to stream their results back to the client.
451+ Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
452+ If False, the `streaming_callback` will not be passed to the tool invocation.
453+ If None, the value from the constructor will be used.
430454 :returns:
431455 A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
432456 Each ChatMessage objects wraps the result of a tool invocation.
@@ -443,6 +467,12 @@ def run(
443467 if state is None :
444468 state = State (schema = {})
445469
470+ resolved_enable_streaming_passthrough = (
471+ enable_streaming_callback_passthrough
472+ if enable_streaming_callback_passthrough is not None
473+ else self .enable_streaming_callback_passthrough
474+ )
475+
446476 # Only keep messages with tool calls
447477 messages_with_tool_calls = [message for message in messages if message .tool_calls ]
448478 streaming_callback = select_streaming_callback (
@@ -468,6 +498,16 @@ def run(
468498 llm_args = tool_call .arguments .copy ()
469499 final_args = self ._inject_state_args (tool_to_invoke , llm_args , state )
470500
501+ # Check whether to inject streaming_callback
502+ if (
503+ resolved_enable_streaming_passthrough
504+ and streaming_callback is not None
505+ and "streaming_callback" not in final_args
506+ ):
507+ invoke_params = self ._get_func_params (tool_to_invoke )
508+ if "streaming_callback" in invoke_params :
509+ final_args ["streaming_callback" ] = streaming_callback
510+
471511 # 2) Invoke the tool
472512 try :
473513 tool_result = tool_to_invoke .invoke (** final_args )
@@ -523,6 +563,8 @@ async def run_async(
523563 messages : List [ChatMessage ],
524564 state : Optional [State ] = None ,
525565 streaming_callback : Optional [StreamingCallbackT ] = None ,
566+ * ,
567+ enable_streaming_callback_passthrough : Optional [bool ] = None ,
526568 ) -> Dict [str , Any ]:
527569 """
528570 Asynchronously processes ChatMessage objects containing tool calls and invokes the corresponding tools.
@@ -533,6 +575,12 @@ async def run_async(
533575 :param streaming_callback: An asynchronous callback function that will be called to emit tool results.
534576 Note that the result is only emitted once it becomes available — it is not
535577 streamed incrementally in real time.
578+ :param enable_streaming_callback_passthrough:
579+ If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
580+ This allows tools to stream their results back to the client.
581+ Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
582+ If False, the `streaming_callback` will not be passed to the tool invocation.
583+ If None, the value from the constructor will be used.
536584 :returns:
537585 A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
538586 Each ChatMessage objects wraps the result of a tool invocation.
@@ -549,6 +597,12 @@ async def run_async(
549597 if state is None :
550598 state = State (schema = {})
551599
600+ resolved_enable_streaming_passthrough = (
601+ enable_streaming_callback_passthrough
602+ if enable_streaming_callback_passthrough is not None
603+ else self .enable_streaming_callback_passthrough
604+ )
605+
552606 # Only keep messages with tool calls
553607 messages_with_tool_calls = [message for message in messages if message .tool_calls ]
554608 streaming_callback = select_streaming_callback (
@@ -574,6 +628,16 @@ async def run_async(
574628 llm_args = tool_call .arguments .copy ()
575629 final_args = self ._inject_state_args (tool_to_invoke , llm_args , state )
576630
631+ # Check whether to inject streaming_callback
632+ if (
633+ resolved_enable_streaming_passthrough
634+ and streaming_callback is not None
635+ and "streaming_callback" not in final_args
636+ ):
637+ invoke_params = self ._get_func_params (tool_to_invoke )
638+ if "streaming_callback" in invoke_params :
639+ final_args ["streaming_callback" ] = streaming_callback
640+
577641 # 2) Invoke the tool asynchronously
578642 try :
579643 tool_result = await asyncio .get_running_loop ().run_in_executor (
@@ -632,11 +696,18 @@ def to_dict(self) -> Dict[str, Any]:
632696 :returns:
633697 Dictionary with serialized data.
634698 """
699+ if self .streaming_callback is not None :
700+ streaming_callback = serialize_callable (self .streaming_callback )
701+ else :
702+ streaming_callback = None
703+
635704 return default_to_dict (
636705 self ,
637706 tools = serialize_tools_or_toolset (self .tools ),
638707 raise_on_failure = self .raise_on_failure ,
639708 convert_result_to_json_string = self .convert_result_to_json_string ,
709+ streaming_callback = streaming_callback ,
710+ enable_streaming_callback_passthrough = self .enable_streaming_callback_passthrough ,
640711 )
641712
642713 @classmethod
@@ -650,4 +721,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
650721 The deserialized component.
651722 """
652723 deserialize_tools_or_toolset_inplace (data ["init_parameters" ], key = "tools" )
724+ if data ["init_parameters" ].get ("streaming_callback" ) is not None :
725+ data ["init_parameters" ]["streaming_callback" ] = deserialize_callable (
726+ data ["init_parameters" ]["streaming_callback" ]
727+ )
653728 return default_from_dict (cls , data )
0 commit comments