forked from deepset-ai/haystack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool_invoker.py
More file actions
860 lines (733 loc) · 38.4 KB
/
tool_invoker.py
File metadata and controls
860 lines (733 loc) · 38.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import contextvars
import inspect
import json
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from haystack.components.agents import State
from haystack.core.component.component import component
from haystack.core.component.sockets import Sockets
from haystack.core.serialization import default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
from haystack.tools import (
ComponentTool,
Tool,
ToolsType,
_check_duplicate_tool_names,
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
warm_up_tools,
)
from haystack.tools.errors import ToolInvocationError
from haystack.tracing.utils import _serializable_value
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
# Module-level logger named after this module. Log calls in this file pass their values as keyword
# arguments whose names match the `{name}` placeholders in the message string.
logger = logging.getLogger(__name__)
class ToolInvokerError(Exception):
    """Common base class for every error raised by the ToolInvoker."""

    def __init__(self, message: str) -> None:
        """
        Store the error text on the exception.

        :param message: Human-readable description of the failure; becomes `str(self)`.
        """
        super().__init__(message)
class ToolNotFoundException(ToolInvokerError):
    """Raised when a requested tool name is not among the available tools."""

    def __init__(self, tool_name: str, available_tools: list[str]) -> None:
        """
        Build the error message from the missing tool name and the known tool names.

        :param tool_name: The name that could not be resolved.
        :param available_tools: Names of the tools that are available; listed comma-separated in the message.
        """
        known_tools = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {known_tools}")
class StringConversionError(ToolInvokerError):
    """Raised when a tool result cannot be converted to a string."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception) -> None:
        """
        Build the error message.

        :param tool_name: Name of the tool whose result failed to convert.
        :param conversion_function: Name of the function used for the conversion.
        :param error: The exception raised by the conversion; its string form is embedded in the message.
        """
        details = f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
        super().__init__(details)
class ResultConversionError(ToolInvokerError):
    """Raised when a tool output cannot be converted to a result."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception) -> None:
        """
        Build the error message.

        :param tool_name: Name of the tool whose output failed to convert.
        :param conversion_function: Name of the function used for the conversion.
        :param error: The exception raised by the conversion; its string form is embedded in the message.
        """
        details = f"Failed to convert tool output from tool {tool_name} using '{conversion_function}'. Error: {error}"
        super().__init__(details)
class ToolOutputMergeError(ToolInvokerError):
    """Raised when a tool's outputs cannot be merged into the State."""

    @classmethod
    def from_exception(cls, tool_name: str, error: Exception) -> "ToolOutputMergeError":
        """
        Alternate constructor that wraps an underlying exception.

        :param tool_name: Name of the tool whose outputs could not be merged.
        :param error: The underlying exception; its string form is embedded in the message.
        :returns: A new instance of `cls` carrying the formatted message.
        """
        return cls(f"Failed to merge tool outputs from tool {tool_name} into State: {error}")
@component
class ToolInvoker:
"""
Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
Also handles reading/writing from a shared `State`.
At initialization, the ToolInvoker component is provided with a list of available tools.
At runtime, the component processes a list of ChatMessage objects containing tool calls
and invokes the corresponding tools.
The results of the tool invocations are returned as a list of ChatMessage objects with tool role.
Usage example:
```python
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.components.tools import ToolInvoker
# Tool definition
def dummy_weather_function(city: str):
return f"The weather in {city} is 20 degrees."
parameters = {"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]}
tool = Tool(name="weather_tool",
description="A tool to get the weather",
function=dummy_weather_function,
parameters=parameters)
# Usually, the ChatMessage with tool_calls is generated by a Language Model
# Here, we create it manually for demonstration purposes
tool_call = ToolCall(
tool_name="weather_tool",
arguments={"city": "Berlin"}
)
message = ChatMessage.from_assistant(tool_calls=[tool_call])
# ToolInvoker initialization and run
invoker = ToolInvoker(tools=[tool])
result = invoker.run(messages=[message])
print(result)
```
```
>> {
>> 'tool_messages': [
>> ChatMessage(
>> _role=<ChatRole.TOOL: 'tool'>,
>> _content=[
>> ToolCallResult(
>> result='"The weather in Berlin is 20 degrees."',
>> origin=ToolCall(
>> tool_name='weather_tool',
>> arguments={'city': 'Berlin'},
>> id=None
>> )
>> )
>> ],
>> _meta={}
>> )
>> ]
>> }
```
Usage example with a Toolset:
```python
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool, Toolset
from haystack.components.tools import ToolInvoker
# Tool definition
def dummy_weather_function(city: str):
return f"The weather in {city} is 20 degrees."
parameters = {"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]}
tool = Tool(name="weather_tool",
description="A tool to get the weather",
function=dummy_weather_function,
parameters=parameters)
# Create a Toolset
toolset = Toolset([tool])
# Usually, the ChatMessage with tool_calls is generated by a Language Model
# Here, we create it manually for demonstration purposes
tool_call = ToolCall(
tool_name="weather_tool",
arguments={"city": "Berlin"}
)
message = ChatMessage.from_assistant(tool_calls=[tool_call])
# ToolInvoker initialization and run with Toolset
invoker = ToolInvoker(tools=toolset)
result = invoker.run(messages=[message])
print(result)
```
"""
def __init__(
    self,
    tools: ToolsType,
    raise_on_failure: bool = True,
    convert_result_to_json_string: bool = False,
    streaming_callback: StreamingCallbackT | None = None,
    *,
    enable_streaming_callback_passthrough: bool = False,
    max_workers: int = 4,
) -> None:
    """
    Set up the ToolInvoker with its tools and behaviour flags.

    :param tools:
        A list of Tool and/or Toolset objects, or a Toolset instance that can resolve tools.
    :param raise_on_failure:
        When True, errors (tool not found, tool invocation errors, tool result conversion errors)
        are raised as exceptions.
        When False, a ChatMessage with `error=True` is produced instead, with the error description
        as its `result`.
    :param convert_result_to_json_string:
        When True, tool results are turned into strings with `json.dumps`.
        When False, they are turned into strings with `str`.
    :param streaming_callback:
        Callback used to emit tool results. A result is emitted once it is available; it is not
        streamed incrementally in real time.
    :param enable_streaming_callback_passthrough:
        When True, the `streaming_callback` is handed to tools that accept a `streaming_callback`
        parameter, so that they can stream their own results.
        When False, the callback is never handed to tools.
    :param max_workers:
        Maximum number of thread pool workers, which also bounds the number of concurrent tool invocations.
    :raises ValueError:
        If no tools are provided or if duplicate tool names are found.
    """
    # Public configuration, stored as given so it can be serialized again later.
    self.tools = tools
    self.raise_on_failure = raise_on_failure
    self.convert_result_to_json_string = convert_result_to_json_string
    self.streaming_callback = streaming_callback
    self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
    self.max_workers = max_workers

    # Internal state: name -> Tool lookup used at invocation time, and the warm-up guard flag.
    self._tools_with_names = self._validate_and_prepare_tools(tools)
    self._is_warmed_up = False
@staticmethod
def _make_context_bound_invoke(tool_to_invoke: Tool, final_args: dict[str, Any]) -> Callable[[], Any]:
    """
    Build a zero-argument callable that invokes the tool inside a copy of the caller's contextvars Context.

    The Context is copied here, in the calling thread, and the tool is later invoked through that copy.
    This keeps the caller's ambient execution context (for example the active tracing span) available
    when the callable runs in a worker thread, where contextvars are not propagated automatically.
    Without it, nested tool calls would lose their parent trace/span.

    :param tool_to_invoke: The tool whose `invoke` method will be called.
    :param final_args: Keyword arguments passed to `invoke`.
    :returns:
        A callable that returns the tool result, or the `ToolInvocationError` instance if one was raised.
        Returning the error instead of raising lets parallel execution collect failures.
    """
    captured_context = contextvars.copy_context()

    def _runner() -> Any:
        try:
            # Context.run forwards keyword arguments to the callable it executes.
            return captured_context.run(tool_to_invoke.invoke, **final_args)
        except ToolInvocationError as invocation_error:
            return invocation_error

    return _runner
@staticmethod
def _validate_and_prepare_tools(tools: ToolsType) -> dict[str, Tool]:
    """
    Validates and prepares tools for use by the ToolInvoker.

    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
    :returns: A dictionary mapping tool names to Tool instances, in the order the tools were flattened.
    :raises ValueError: If no tools are provided or if duplicate tool names are found.
    """
    if not tools:
        raise ValueError("ToolInvoker requires at least one tool.")

    converted_tools = flatten_tools_or_toolsets(tools)
    # Presumably this helper already rejects duplicates; the explicit check below is kept as a safeguard.
    _check_duplicate_tool_names(converted_tools)

    # Single pass: build the lookup and collect duplicate names together, instead of calling
    # `list.count` once per tool (which is quadratic in the number of tools).
    tools_by_name: dict[str, Tool] = {}
    duplicates: set[str] = set()
    for tool in converted_tools:
        if tool.name in tools_by_name:
            duplicates.add(tool.name)
        else:
            tools_by_name[tool.name] = tool

    if duplicates:
        raise ValueError(f"Duplicate tool names found: {duplicates}")
    return tools_by_name
def _default_output_to_string_handler(self, result: Any) -> str:
    """
    Turn a tool result into a string when no custom handler is configured.

    :param result: The tool result to convert.
    :returns:
        The JSON encoding of the result if `convert_result_to_json_string` is set (falling back to
        `str(result)` when JSON encoding fails), otherwise `str()` of the serializable form.
    """
    # `_serializable_value` is applied first so that objects exposing `to_dict()` are represented as
    # dicts rather than through their `__repr__`; this matters for both the JSON and the `str` path.
    serializable = _serializable_value(value=result, use_placeholders=False)

    if not self.convert_result_to_json_string:
        return str(serializable)

    try:
        # ensure_ascii is disabled so special characters such as emojis are not escaped.
        return json.dumps(serializable, ensure_ascii=False)
    except Exception as error:
        # Not JSON serializable: warn and fall back to the plain string form of the raw result.
        logger.warning(
            "Tool result is not JSON serializable. Falling back to str conversion. "
            "Result: {result}\nError: {err}",
            result=result,
            err=error,
        )
        return str(result)
def _process_output(self, config: dict[str, Any], result: Any, tool_call: ToolCall) -> Any:
    """
    Processes a tool result based on the provided configuration.

    :param config: Configuration dictionary that may contain "source", "handler", and "raw_result" keys.
    :param result: The tool result to process.
    :param tool_call: The ToolCall object for error reporting.
    :returns:
        The value returned by the handler. If no handler is configured, the raw value is returned when
        "raw_result" is truthy, otherwise the value converted by the default string handler.
    :raises ResultConversionError: If the handler fails and "raw_result" is truthy.
    :raises StringConversionError: If the handler fails and "raw_result" is falsy.
    """
    source_key = config.get("source")
    # If a source key is provided and the result is a dict, only that entry is processed.
    value = result.get(source_key) if source_key is not None and isinstance(result, dict) else result

    handler = config.get("handler")
    raw_result = config.get("raw_result", False)
    if handler is None:
        if raw_result:
            return value
        handler = self._default_output_to_string_handler

    try:
        return handler(value)
    except Exception as e:
        # Not every callable has `__name__` (e.g. functools.partial objects or callable instances).
        # Fall back to repr() so the conversion error is reported instead of an AttributeError.
        handler_name = getattr(handler, "__name__", None) or repr(handler)
        error_cls = ResultConversionError if raw_result else StringConversionError
        raise error_cls(tool_call.tool_name, handler_name, e) from e
def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
    """
    Wrap the result of a tool invocation in a tool-role ChatMessage.

    :param result:
        The tool result.
    :param tool_call:
        The ToolCall that triggered the invocation; used as the message origin.
    :param tool_to_invoke:
        The Tool that was invoked; its `outputs_to_string` configuration drives the conversion.
    :returns:
        A ChatMessage containing the converted tool result, or an error message (`error=True`) if the
        conversion failed and `raise_on_failure` is False.
    :raises StringConversionError:
        If the conversion to string of the tool output fails and `raise_on_failure` is True.
    :raises ResultConversionError:
        If the conversion to result of the tool output fails and `raise_on_failure` is True.
    """
    outputs_config = tool_to_invoke.outputs_to_string or {}

    # An empty config, or one carrying any of these keys at the root, describes a single output.
    single_output_keys = ("source", "handler", "raw_result")
    is_single_output = not outputs_config or any(key in outputs_config for key in single_output_keys)

    try:
        if is_single_output:
            tool_result = self._process_output(outputs_config, result, tool_call)
        else:
            # Multiple outputs: raw_result is not supported, every entry is converted to a string
            # and the resulting dict is stringified as a whole.
            converted_outputs = {
                output_key: self._process_output({**config, "raw_result": False}, result, tool_call)
                for output_key, config in outputs_config.items()
            }
            tool_result = self._default_output_to_string_handler(converted_outputs)
        return ChatMessage.from_tool(tool_result=tool_result, origin=tool_call)
    except (StringConversionError, ResultConversionError) as e:
        if self.raise_on_failure:
            raise e
        logger.exception("{error_exception}", error_exception=e)
        return ChatMessage.from_tool(tool_result=str(e), origin=tool_call, error=True)
@staticmethod
def _get_func_params(tool: Tool) -> set:
    """
    Collect the parameter names accepted by the tool.

    :param tool: The tool to inspect.
    :returns:
        For a ComponentTool, the names of the wrapped component's input sockets; for any other tool,
        the parameter names in the signature of `tool.function`.
    """
    if not isinstance(tool, ComponentTool):
        return set(inspect.signature(tool.function).parameters)

    # ComponentTool wraps the function with one that accepts kwargs, so its signature is not
    # informative; the component's input sockets list the accepted parameters instead.
    # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
    assert hasattr(tool._component, "__haystack_input__") and isinstance(
        tool._component.__haystack_input__, Sockets
    )
    return set(tool._component.__haystack_input__._sockets_dict.keys())
@staticmethod
def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> dict[str, Any]:
    """
    Merge the arguments supplied by the LLM with values taken from the State.

    Rules applied:
    - An argument provided by the LLM is never overwritten by a state value.
    - If the tool has a dict `inputs_from_state` (state key -> tool parameter name), that mapping
      decides which state keys fill which parameters.
    - Otherwise every parameter name of the tool is looked up in the state under the same name.

    :param tool: The tool that will be invoked.
    :param llm_args: Arguments provided by the LLM; not modified.
    :param state: The State to read missing arguments from.
    :returns: A new dictionary with the combined arguments.
    """
    func_params = ToolInvoker._get_func_params(tool)

    # e.g. tool.inputs_from_state = {"state_key": "tool_param_name"}
    explicit_mapping = getattr(tool, "inputs_from_state", None)
    if isinstance(explicit_mapping, dict):
        param_mappings = explicit_mapping
    else:
        param_mappings = {name: name for name in func_params}

    final_args = dict(llm_args)
    for state_key, param_name in param_mappings.items():
        if param_name in final_args or not state.has(state_key):
            continue
        final_args[param_name] = state.get(state_key)
    return final_args
@staticmethod
def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
    """
    Write the parts of a tool result selected by `tool.outputs_to_state` into the State.

    Nothing is written when:
    - `result` is not a dictionary, or
    - the tool has no `outputs_to_state` attribute, or that attribute is not a dictionary.

    Otherwise, for each `state_key -> config` entry of `outputs_to_state`:
    - the value is `result[config["source"]]` if a truthy "source" is configured, else the whole `result`;
    - entries whose value is `None` (for example, the tool did not produce that key) are skipped;
    - the value is stored with `state.set`, passing `config.get("handler")` as `handler_override`.

    :param tool: Tool instance containing the optional `outputs_to_state` mapping.
    :param result: The output from tool execution. Can be a dictionary, or any other type.
    :param state: The State object that receives the values.
    :returns: None. This method only calls `state.set`; it does not produce a message or a merged result.
    """
    # Only dictionary results can be mapped into state
    if not isinstance(result, dict):
        return

    # Without a dict `outputs_to_state` mapping there is nothing to store
    outputs_to_state = getattr(tool, "outputs_to_state", None)
    if not isinstance(outputs_to_state, dict):
        return

    for state_key, config in outputs_to_state.items():
        # Use the configured source key if there is one, otherwise the entire result
        source_key = config.get("source")
        output_value = result.get(source_key) if source_key else result

        # Skip state update when the tool didn't produce this output key
        if output_value is None:
            continue

        state.set(state_key, output_value, handler_override=config.get("handler"))
@staticmethod
def _create_tool_result_streaming_chunk(tool_messages: list[ChatMessage], tool_call: ToolCall) -> StreamingChunk:
    """
    Build the streaming chunk announcing the most recent tool result.

    :param tool_messages: Tool messages produced so far; the last one is the result being announced.
    :param tool_call: The tool call that produced the last message.
    :returns: A StreamingChunk with empty content, indexed by the position of the last tool message.
    """
    last_position = len(tool_messages) - 1
    latest_result = tool_messages[-1].tool_call_results[0]
    chunk_meta = {"tool_result": latest_result.result, "tool_call": tool_call}
    return StreamingChunk(
        content="",
        index=last_position,
        tool_call_result=latest_result,
        start=True,
        meta=chunk_meta,
    )
def _prepare_tool_call_params(
    self,
    *,
    messages_with_tool_calls: list[ChatMessage],
    state: State,
    streaming_callback: StreamingCallbackT | None,
    enable_streaming_passthrough: bool,
    tools_with_names: dict[str, Tool],
) -> tuple[list[ToolCall], list[dict[str, Any]], list[ChatMessage]]:
    """
    Resolve every tool call to a tool plus its final arguments, collecting errors for unknown tools.

    :param messages_with_tool_calls: Messages containing tool calls to process.
    :param state: The current state, used to fill in arguments the LLM did not provide.
    :param streaming_callback: Optional streaming callback that may be injected into the arguments.
    :param enable_streaming_passthrough: Whether the streaming callback may be passed to tools.
    :param tools_with_names: Mapping of tool name to Tool used to resolve the calls.
    :returns:
        Tuple of (tool_calls, tool_call_params, error_messages). `tool_calls` and `tool_call_params`
        are aligned by index and only cover calls whose tool was found; `error_messages` holds one
        error ChatMessage per call whose tool was not found.
    :raises ToolNotFoundException: If a tool is not found and `raise_on_failure` is True.
    """
    runnable_calls = []
    invocation_params = []
    failure_messages = []

    # Walk all tool calls in message order.
    all_tool_calls = (call for message in messages_with_tool_calls for call in message.tool_calls)
    for tool_call in all_tool_calls:
        requested_name = tool_call.tool_name

        # Unknown tool: raise, or record an error message and move on to the next call.
        if requested_name not in tools_with_names:
            not_found = ToolNotFoundException(requested_name, list(tools_with_names.keys()))
            if self.raise_on_failure:
                raise not_found
            logger.error("{error_exception}", error_exception=not_found)
            failure_messages.append(
                ChatMessage.from_tool(tool_result=str(not_found), origin=tool_call, error=True)
            )
            continue

        tool_to_invoke = tools_with_names[requested_name]

        # Combine LLM-provided arguments (copied, so the ToolCall is untouched) with state inputs.
        final_args = self._inject_state_args(tool_to_invoke, tool_call.arguments.copy(), state)

        # The callback is only injected if passthrough is on, a callback exists, the arguments do
        # not already contain one, and the tool accepts a `streaming_callback` parameter.
        should_pass_callback = (
            enable_streaming_passthrough
            and streaming_callback is not None
            and "streaming_callback" not in final_args
            and "streaming_callback" in self._get_func_params(tool_to_invoke)
        )
        if should_pass_callback:
            final_args["streaming_callback"] = streaming_callback

        invocation_params.append({"tool_to_invoke": tool_to_invoke, "final_args": final_args})
        runnable_calls.append(tool_call)

    return runnable_calls, invocation_params, failure_messages
def warm_up(self) -> None:
    """
    Warm up the tools registered in this tool invoker.

    Calling this method again after a completed warm-up is a no-op: a flag records that the
    warm-up already ran.
    """
    if self._is_warmed_up:
        return

    warm_up_tools(self.tools)
    # Warm-up may have changed the tools, so the name -> Tool lookup is rebuilt and re-validated.
    self._tools_with_names = self._validate_and_prepare_tools(self.tools)
    self._is_warmed_up = True
@component.output_types(tool_messages=list[ChatMessage], state=State)
def run(
    self,
    messages: list[ChatMessage],
    state: State | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    *,
    enable_streaming_callback_passthrough: bool | None = None,
    tools: ToolsType | None = None,
) -> dict[str, Any]:
    """
    Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

    Tool calls are submitted to a thread pool with `max_workers` workers and their results are
    processed in the order of the tool calls.

    :param messages:
        A list of ChatMessage objects. Messages without tool calls are ignored.
    :param state: The runtime state that should be used by the tools.
        If None, a State with an empty schema is created.
    :param streaming_callback: A callback function that will be called to emit tool results.
        Note that the result is only emitted once it becomes available — it is not
        streamed incrementally in real time.
    :param enable_streaming_callback_passthrough:
        If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
        This allows tools to stream their results back to the client.
        Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
        If False, the `streaming_callback` will not be passed to the tool invocation.
        If None, the value from the constructor will be used.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        If set, it will override the `tools` parameter provided during initialization.
    :returns:
        A dictionary with the keys:
        - `tool_messages`: a list of ChatMessage objects with tool role. Each ChatMessage object wraps
          the result (or the error) of a tool invocation.
        - `state`: the State used for this run.
    :raises ToolNotFoundException:
        If the tool is not found in the list of available tools and `raise_on_failure` is True.
    :raises ToolInvocationError:
        If the tool invocation fails and `raise_on_failure` is True.
    :raises ToolOutputMergeError:
        If merging tool outputs into state fails and `raise_on_failure` is True.
        Errors raised while converting the tool result into a message (`StringConversionError`,
        `ResultConversionError`) occur inside the same `try` block and are therefore also re-raised
        wrapped in this exception type.
    """
    if not self._is_warmed_up:
        self.warm_up()

    tools_with_names = self._tools_with_names
    if tools is not None:
        # NOTE(review): runtime tools are validated here, but this method does not call
        # `warm_up_tools` on them; only the constructor tools are warmed up via `warm_up()`.
        tools_with_names = self._validate_and_prepare_tools(tools)
        logger.debug(
            "For this invocation, overriding constructor tools with: {tools}",
            tools=", ".join(tools_with_names.keys()),
        )

    if state is None:
        state = State(schema={})

    # A runtime value of None means "use the constructor setting"
    resolved_enable_streaming_passthrough = (
        enable_streaming_callback_passthrough
        if enable_streaming_callback_passthrough is not None
        else self.enable_streaming_callback_passthrough
    )

    # Only keep messages with tool calls
    messages_with_tool_calls = [message for message in messages if message.tool_calls]

    # Choose between the init-time and the run-time callback; a synchronous callback is required here
    streaming_callback = select_streaming_callback(
        init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
    )

    if not messages_with_tool_calls:
        return {"tool_messages": [], "state": state}

    # 1) Collect all tool calls and their parameters for parallel execution
    tool_messages = []
    tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
        messages_with_tool_calls=messages_with_tool_calls,
        state=state,
        streaming_callback=streaming_callback,
        enable_streaming_passthrough=resolved_enable_streaming_passthrough,
        tools_with_names=tools_with_names,
    )
    # Error messages for unknown tools come first in the output, before any tool results
    tool_messages.extend(error_messages)

    if not tool_call_params:
        return {"tool_messages": tool_messages, "state": state}

    # 2) Execute valid tool calls in parallel
    with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
        futures = []
        for params in tool_call_params:
            # The callable captures the current contextvars Context before it is handed to a worker thread
            callable_ = self._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
            futures.append(executor.submit(callable_))

        # 3) Gather and process results: handle errors and merge outputs into state
        # Results are consumed in submission order, so messages keep the order of the tool calls
        for future, tool_call in zip(futures, tool_calls, strict=True):
            # ToolInvocationError is returned as a value by the callable; any other exception
            # raised by the tool propagates from `future.result()`
            result = future.result()

            if isinstance(result, ToolInvocationError):
                # a) This is an error, create error Tool message
                if self.raise_on_failure:
                    raise result
                logger.error("{error_exception}", error_exception=result)
                tool_messages.append(ChatMessage.from_tool(tool_result=str(result), origin=tool_call, error=True))
            else:
                # b) In case of success, merge outputs into state
                try:
                    tool_to_invoke = tools_with_names[tool_call.tool_name]
                    self._merge_tool_outputs(tool=tool_to_invoke, result=result, state=state)
                    tool_messages.append(
                        self._prepare_tool_result_message(
                            result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                        )
                    )
                except Exception as e:
                    # This also catches conversion errors raised by `_prepare_tool_result_message`
                    error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                    if self.raise_on_failure:
                        raise error from e
                    logger.exception("{error_exception}", error_exception=error)
                    tool_messages.append(
                        ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                    )

            # c) Handle streaming callback
            # One chunk per processed tool call, built from the message that was just appended
            if streaming_callback is not None:
                streaming_callback(
                    self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                )

    # We stream one more chunk that contains a finish_reason if tool_messages were generated
    if len(tool_messages) > 0 and streaming_callback is not None:
        streaming_callback(
            StreamingChunk(
                content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
            )
        )

    return {"tool_messages": tool_messages, "state": state}
@component.output_types(tool_messages=list[ChatMessage], state=State)
async def run_async(
    self,
    messages: list[ChatMessage],
    state: State | None = None,
    streaming_callback: StreamingCallbackT | None = None,
    *,
    enable_streaming_callback_passthrough: bool | None = None,
    tools: ToolsType | None = None,
) -> dict[str, Any]:
    """
    Asynchronously processes ChatMessage objects containing tool calls.

    Multiple tool calls are performed concurrently: each call is run in a thread pool with
    `max_workers` workers through the running event loop, and all results are awaited together.

    :param messages:
        A list of ChatMessage objects. Messages without tool calls are ignored.
    :param state: The runtime state that should be used by the tools.
        If None, a State with an empty schema is created.
    :param streaming_callback: An asynchronous callback function that will be called to emit tool results.
        Note that the result is only emitted once it becomes available — it is not
        streamed incrementally in real time.
    :param enable_streaming_callback_passthrough:
        If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
        This allows tools to stream their results back to the client.
        Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
        If False, the `streaming_callback` will not be passed to the tool invocation.
        If None, the value from the constructor will be used.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        If set, it will override the `tools` parameter provided during initialization.
    :returns:
        A dictionary with the keys:
        - `tool_messages`: a list of ChatMessage objects with tool role. Each ChatMessage object wraps
          the result (or the error) of a tool invocation.
        - `state`: the State used for this run.
    :raises ToolNotFoundException:
        If the tool is not found in the list of available tools and `raise_on_failure` is True.
    :raises ToolInvocationError:
        If the tool invocation fails and `raise_on_failure` is True.
    :raises ToolOutputMergeError:
        If merging tool outputs into state fails and `raise_on_failure` is True.
        Errors raised while converting the tool result into a message (`StringConversionError`,
        `ResultConversionError`) occur inside the same `try` block and are therefore also re-raised
        wrapped in this exception type.
    """
    if not self._is_warmed_up:
        self.warm_up()

    tools_with_names = self._tools_with_names
    if tools is not None:
        # NOTE(review): runtime tools are validated here, but this method does not call
        # `warm_up_tools` on them; only the constructor tools are warmed up via `warm_up()`.
        tools_with_names = self._validate_and_prepare_tools(tools)
        logger.debug(
            "For this invocation, overriding constructor tools with: {tools}",
            tools=", ".join(tools_with_names.keys()),
        )

    if state is None:
        state = State(schema={})

    # A runtime value of None means "use the constructor setting"
    resolved_enable_streaming_passthrough = (
        enable_streaming_callback_passthrough
        if enable_streaming_callback_passthrough is not None
        else self.enable_streaming_callback_passthrough
    )

    # Only keep messages with tool calls
    messages_with_tool_calls = [message for message in messages if message.tool_calls]

    # Choose between the init-time and the run-time callback; an asynchronous callback is required here
    streaming_callback = select_streaming_callback(
        init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
    )

    if not messages_with_tool_calls:
        return {"tool_messages": [], "state": state}

    # 1) Collect all tool calls and their parameters for parallel execution
    tool_messages = []
    tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
        messages_with_tool_calls=messages_with_tool_calls,
        state=state,
        streaming_callback=streaming_callback,
        enable_streaming_passthrough=resolved_enable_streaming_passthrough,
        tools_with_names=tools_with_names,
    )
    # Error messages for unknown tools come first in the output, before any tool results
    tool_messages.extend(error_messages)

    if not tool_call_params:
        return {"tool_messages": tool_messages, "state": state}

    # 2) Execute valid tool calls in parallel
    tool_call_tasks = []
    with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
        for params in tool_call_params:
            loop = asyncio.get_running_loop()
            # The callable captures the current contextvars Context before it is handed to a worker thread
            callable_ = ToolInvoker._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
            tool_call_tasks.append(loop.run_in_executor(executor, callable_))

        # 3) Gather and process results: handle errors and merge outputs into state
        # `gather` returns results in the order of the awaitables, so they line up with `tool_calls`.
        # ToolInvocationError is returned as a value by the callable; any other exception raised by
        # a tool propagates from this await.
        tool_results = await asyncio.gather(*tool_call_tasks)

        for tool_result, tool_call in zip(tool_results, tool_calls, strict=True):
            # a) This is an error, create error Tool message
            if isinstance(tool_result, ToolInvocationError):
                if self.raise_on_failure:
                    raise tool_result
                logger.error("{error_exception}", error_exception=tool_result)
                tool_messages.append(
                    ChatMessage.from_tool(tool_result=str(tool_result), origin=tool_call, error=True)
                )
            else:
                # b) In case of success, merge outputs into state
                try:
                    tool_to_invoke = tools_with_names[tool_call.tool_name]
                    self._merge_tool_outputs(tool=tool_to_invoke, result=tool_result, state=state)
                    tool_messages.append(
                        self._prepare_tool_result_message(
                            result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                        )
                    )
                except Exception as e:
                    # This also catches conversion errors raised by `_prepare_tool_result_message`
                    error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                    if self.raise_on_failure:
                        raise error from e
                    logger.exception("{error_exception}", error_exception=error)
                    tool_messages.append(
                        ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                    )

            # c) Handle streaming callback
            # One chunk per processed tool call, built from the message that was just appended
            if streaming_callback is not None:
                await streaming_callback(
                    self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                )

    # 4) We stream one more chunk that contains a finish_reason if tool_messages were generated
    if len(tool_messages) > 0 and streaming_callback is not None:
        await streaming_callback(
            StreamingChunk(
                content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
            )
        )

    return {"tool_messages": tool_messages, "state": state}
def to_dict(self) -> dict[str, Any]:
    """
    Serialize this component to a dictionary.

    The streaming callback, if any, is stored in its serialized callable form; the tools are stored
    in their serialized form.

    :returns:
        Dictionary with serialized data.
    """
    serialized_callback = (
        None if self.streaming_callback is None else serialize_callable(self.streaming_callback)
    )
    return default_to_dict(
        self,
        tools=serialize_tools_or_toolset(self.tools),
        raise_on_failure=self.raise_on_failure,
        convert_result_to_json_string=self.convert_result_to_json_string,
        streaming_callback=serialized_callback,
        enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
        max_workers=self.max_workers,
    )
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ToolInvoker":
    """
    Build a ToolInvoker from its dictionary representation.

    The `init_parameters` entry of `data` is updated in place: the tools are deserialized, and the
    streaming callback, when present, is replaced by the deserialized callable.

    :param data:
        The dictionary to deserialize from.
    :returns:
        The deserialized component.
    """
    init_params = data["init_parameters"]
    deserialize_tools_or_toolset_inplace(init_params, key="tools")

    serialized_callback = init_params.get("streaming_callback")
    if serialized_callback is not None:
        init_params["streaming_callback"] = deserialize_callable(serialized_callback)

    return default_from_dict(cls, data)