Skip to content

Commit e1ae2f4

Browse files
authored
feat: Forward state to tools (#11064)
* adding injection of State or State | Optional with tests * refactor to add new util function to reduce duplicate code * Add reno * update state docs
1 parent a545c6a commit e1ae2f4

File tree

9 files changed

+330
-24
lines changed

9 files changed

+330
-24
lines changed

docs-website/docs/concepts/agents/state.mdx

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,91 @@ print(f"Processed {final_count} documents")
438438
print(final_docs)
439439
```
440440

441+
### Injecting State Directly into Tools
442+
443+
As an alternative to `inputs_from_state` and `outputs_to_state`, a tool can declare a parameter annotated as `State` to receive the live `State` object at invocation time.
444+
This lets the tool read from and write to any number of state keys without declaring mappings upfront.
445+
446+
For function-based tools, add a `State` parameter to the function and use the `@tool` decorator:
447+
448+
```python
449+
from typing import Annotated
450+
451+
from haystack.components.agents import Agent, State
452+
from haystack.components.generators.chat import OpenAIChatGenerator
453+
from haystack.dataclasses import ChatMessage, Document
454+
from haystack.tools import tool
455+
456+
457+
@tool
458+
def retrieve_and_store(
459+
query: Annotated[str, "The search query"],
460+
state: State,
461+
) -> str:
462+
"""Retrieve documents and store them directly in state."""
463+
documents = [Document(content=f"Result for '{query}'")]
464+
state.set("documents", documents)
465+
user_name = state.get("user_name", "unknown")
466+
return f"Retrieved {len(documents)} document(s) for {user_name}"
467+
468+
469+
agent = Agent(
470+
chat_generator=OpenAIChatGenerator(model="gpt-5-nano"),
471+
tools=[retrieve_and_store],
472+
state_schema={"documents": {"type": list[Document]}, "user_name": {"type": str}},
473+
)
474+
475+
result = agent.run(
476+
messages=[ChatMessage.from_user("Find documents about Python")],
477+
user_name="Alice",
478+
)
479+
```
480+
481+
For component-based tools, declare a `State` input socket on the component's `run` method and
482+
wrap it with `ComponentTool`:
483+
484+
```python
485+
from haystack import component
486+
from haystack.components.agents import Agent, State
487+
from haystack.components.generators.chat import OpenAIChatGenerator
488+
from haystack.dataclasses import ChatMessage, Document
489+
from haystack.tools import ComponentTool
490+
491+
492+
@component
493+
class DocumentRetriever:
494+
"""Retrieve documents and store them in state."""
495+
496+
@component.output_types(reply=str)
497+
def run(self, query: str, state: State) -> dict:
498+
"""
499+
Retrieve documents based on query and store them in state."
500+
501+
:param query: The search query
502+
"""
503+
documents = [Document(content=f"Result for '{query}'")]
504+
state.set("documents", documents)
505+
return {"reply": f"Retrieved {len(documents)} document(s)"}
506+
507+
508+
retriever_tool = ComponentTool(
509+
component=DocumentRetriever(),
510+
name="retrieve",
511+
description="Retrieve documents and store them in state",
512+
)
513+
514+
agent = Agent(
515+
chat_generator=OpenAIChatGenerator(model="gpt-5-nano"),
516+
tools=[retriever_tool],
517+
state_schema={"documents": {"type": list[Document]}},
518+
)
519+
520+
result = agent.run(messages=[ChatMessage.from_user("Find documents about Python")])
521+
```
522+
523+
`ToolInvoker` automatically injects the runtime `State` object and excludes the `State` parameter from the LLM-facing schema, so the model is never asked to supply it.
524+
Both `State` and `State | None` annotations are supported.
525+
441526
## Complete Example
442527

443528
This example shows a multi-tool agent workflow where tools share data through State:

haystack/components/tools/tool_invoker.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
warm_up_tools,
2929
)
3030
from haystack.tools.errors import ToolInvocationError
31+
from haystack.tools.parameters_schema_utils import _unwrap_optional
3132
from haystack.tracing.utils import _serializable_value
3233
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
3334

@@ -376,11 +377,13 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
376377
return ChatMessage.from_tool(tool_result=str(e), origin=tool_call, error=True)
377378

378379
@staticmethod
379-
def _get_func_params(tool: Tool) -> set:
380+
def _get_func_params(tool: Tool) -> dict[str, Any]:
380381
"""
381-
Returns the function parameters of the tool's invoke method.
382+
Returns the function parameters with types of the tool's invoke method.
382383
383384
This method inspects the tool's function signature to determine which parameters the tool accepts.
385+
386+
:param tool: The tool for which to get the function parameters and their types.
384387
"""
385388
# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
386389
# to find out which parameters the tool accepts.
@@ -389,9 +392,13 @@ def _get_func_params(tool: Tool) -> set:
389392
assert hasattr(tool._component, "__haystack_input__") and isinstance(
390393
tool._component.__haystack_input__, Sockets
391394
)
392-
func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
395+
func_params = {
396+
name: socket.type for name, socket in tool._component.__haystack_input__._sockets_dict.items()
397+
}
393398
else:
394-
func_params = set(inspect.signature(tool.function).parameters.keys())
399+
func_params = {
400+
name: param.annotation for name, param in inspect.signature(tool.function).parameters.items()
401+
}
395402

396403
return func_params
397404

@@ -406,7 +413,7 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di
406413
- function signature name matching
407414
"""
408415
final_args = dict(llm_args) # start with LLM-provided
409-
func_params = ToolInvoker._get_func_params(tool)
416+
func_params = ToolInvoker._get_func_params(tool).keys()
410417

411418
# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
412419
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
@@ -420,6 +427,11 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di
420427
if param_name not in final_args and state.has(state_key):
421428
final_args[param_name] = state.get(state_key)
422429

430+
# Inject the live State object for any parameter annotated as State or Optional[State]
431+
for param_name, param_type in ToolInvoker._get_func_params(tool).items():
432+
if _unwrap_optional(param_type) is State:
433+
final_args[param_name] = state
434+
423435
return final_args
424436

425437
@staticmethod
@@ -528,7 +540,7 @@ def _prepare_tool_call_params(
528540
enable_streaming_passthrough
529541
and streaming_callback is not None
530542
and "streaming_callback" not in final_args
531-
and "streaming_callback" in self._get_func_params(tool_to_invoke)
543+
and "streaming_callback" in self._get_func_params(tool_to_invoke).keys()
532544
):
533545
final_args["streaming_callback"] = streaming_callback
534546

haystack/tools/component_tool.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from collections.abc import Callable
6-
from types import NoneType, UnionType
7-
from typing import Any, Union, get_args, get_origin
6+
from typing import Any, get_args, get_origin
87

98
from pydantic import Field, TypeAdapter, create_model
109

1110
from haystack import logging
11+
from haystack.components.agents.state.state import State
1212
from haystack.core.component import Component
1313
from haystack.core.serialization import (
1414
component_from_dict,
@@ -23,6 +23,7 @@
2323
_contains_callable_type,
2424
_get_component_param_descriptions,
2525
_resolve_type,
26+
_unwrap_optional,
2627
)
2728
from haystack.tools.tool import (
2829
_deserialize_outputs_to_state,
@@ -328,6 +329,10 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state
328329
if _contains_callable_type(input_type):
329330
continue
330331

332+
# Skip State-typed parameters - ToolInvoker injects them at runtime
333+
if _unwrap_optional(input_type) is State:
334+
continue
335+
331336
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
332337

333338
# if the parameter has not a default value, Pydantic requires an Ellipsis (...)
@@ -352,19 +357,6 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state
352357

353358
return parameters_schema
354359

355-
def _unwrap_optional(self, _type: type) -> type:
356-
"""
357-
Unwrap Optional types to get the underlying type and whether it was originally optional.
358-
359-
:returns:
360-
The underlying type if `t` is `Optional[X]`, otherwise returns `t` unchanged.
361-
"""
362-
if get_origin(_type) is Union or get_origin(_type) is UnionType:
363-
non_none = [a for a in get_args(_type) if a is not NoneType]
364-
if len(non_none) == 1:
365-
return non_none[0]
366-
return _type
367-
368360
def _convert_param(self, param_value: Any, param_type: type) -> Any:
369361
"""
370362
Converts a single parameter value to the expected type.
@@ -376,7 +368,7 @@ def _convert_param(self, param_value: Any, param_type: type) -> Any:
376368
The converted parameter value.
377369
"""
378370
# We unwrap optional types so we can support types like messages: list[ChatMessage] | None
379-
unwrapped_param_type = self._unwrap_optional(param_type)
371+
unwrapped_param_type = _unwrap_optional(param_type)
380372

381373
# We support calling from_dict on target types that have it, even if they are wrapped in a list.
382374
# This allows us to support lists of dataclasses as well as single dataclass inputs.

haystack/tools/from_function.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
from pydantic import create_model
1010

11+
from haystack.components.agents.state.state import State
12+
1113
from .errors import SchemaGenerationError
12-
from .parameters_schema_utils import _contains_callable_type
14+
from .parameters_schema_utils import _contains_callable_type, _unwrap_optional
1315
from .tool import Tool
1416

1517

@@ -139,6 +141,10 @@ def get_weather(
139141
if inputs_from_state and param_name in inputs_from_state.values():
140142
continue
141143

144+
# Skip State-typed parameters (including Optional[State]) - ToolInvoker injects them at runtime
145+
if _unwrap_optional(param.annotation) is State:
146+
continue
147+
142148
if param.annotation is param.empty:
143149
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")
144150

haystack/tools/parameters_schema_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Callable as ABCCallable
99
from dataclasses import MISSING, fields, is_dataclass
1010
from inspect import getdoc
11+
from types import NoneType
1112
from typing import Any, Union, get_args, get_origin
1213

1314
from docstring_parser import parse
@@ -20,6 +21,21 @@
2021
logger = logging.getLogger(__name__)
2122

2223

24+
def _unwrap_optional(type_hint: Any) -> Any:
25+
"""
26+
Unwrap Optional types (i.e. ``X | None`` or ``Optional[X]``) to get the inner type.
27+
28+
:param type_hint: The type hint to unwrap.
29+
:returns: The inner type if ``type_hint`` is ``Optional[X]``, otherwise ``type_hint`` unchanged.
30+
"""
31+
origin = get_origin(type_hint)
32+
if origin is Union or origin is types.UnionType:
33+
non_none = [a for a in get_args(type_hint) if a is not NoneType]
34+
if len(non_none) == 1:
35+
return non_none[0]
36+
return type_hint
37+
38+
2339
def _contains_callable_type(type_hint: Any) -> bool:
2440
"""
2541
Check if a type hint contains a Callable type, including within Union types.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
---
2+
features:
3+
- |
4+
Tools and components can now declare a ``State`` (or ``State | None``) parameter in their
5+
signature to receive the live agent ``State`` object at invocation time — no extra wiring
6+
needed.
7+
8+
For function-based tools created with ``@tool`` or ``create_tool_from_function``, add a
9+
``state`` parameter annotated as ``State``:
10+
11+
.. code:: python
12+
13+
from haystack.components.agents import State
14+
from haystack.tools import tool
15+
16+
@tool
17+
def my_tool(query: str, state: State) -> str:
18+
"""Search using context from agent state."""
19+
history = state.get("history")
20+
...
21+
22+
For component-based tools created with ``ComponentTool``, declare a ``State`` input socket
23+
on the component's ``run`` method:
24+
25+
.. code:: python
26+
27+
from haystack import component
28+
from haystack.components.agents import State
29+
from haystack.tools import ComponentTool
30+
31+
@component
32+
class MyComponent:
33+
@component.output_types(result=str)
34+
def run(self, query: str, state: State) -> dict:
35+
history = state.get("history")
36+
...
37+
38+
tool = ComponentTool(component=MyComponent())
39+
40+
In both cases ``ToolInvoker`` automatically injects the runtime ``State`` object before
41+
calling the tool, and ``State``/``Optional[State]`` parameters are excluded from the
42+
LLM-facing schema so the model is not asked to supply them.
43+
44+
This is an alternative to the existing ``inputs_from_state`` and ``outputs_to_state``
45+
options on ``Tool`` and ``ComponentTool``, which map individual state keys to specific
46+
tool parameters and outputs declaratively. Injecting the full ``State`` object is more
47+
flexible and useful when a tool needs to read from or write to multiple keys, but it
48+
couples the tool implementation directly to ``State``.

0 commit comments

Comments
 (0)