diff --git a/docs-website/docs/concepts/agents/state.mdx b/docs-website/docs/concepts/agents/state.mdx index f0ece37187..ce62df9c5d 100644 --- a/docs-website/docs/concepts/agents/state.mdx +++ b/docs-website/docs/concepts/agents/state.mdx @@ -438,6 +438,91 @@ print(f"Processed {final_count} documents") print(final_docs) ``` +### Injecting State Directly into Tools + +As an alternative to `inputs_from_state` and `outputs_to_state`, a tool can declare a parameter annotated as `State` to receive the live `State` object at invocation time. +This lets the tool read from and write to any number of state keys without declaring mappings upfront. + +For function-based tools, add a `State` parameter to the function and use the `@tool` decorator: + +```python +from typing import Annotated + +from haystack.components.agents import Agent, State +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage, Document +from haystack.tools import tool + + +@tool +def retrieve_and_store( + query: Annotated[str, "The search query"], + state: State, +) -> str: + """Retrieve documents and store them directly in state.""" + documents = [Document(content=f"Result for '{query}'")] + state.set("documents", documents) + user_name = state.get("user_name", "unknown") + return f"Retrieved {len(documents)} document(s) for {user_name}" + + +agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-5-nano"), + tools=[retrieve_and_store], + state_schema={"documents": {"type": list[Document]}, "user_name": {"type": str}}, +) + +result = agent.run( + messages=[ChatMessage.from_user("Find documents about Python")], + user_name="Alice", +) +``` + +For component-based tools, declare a `State` input socket on the component's `run` method and +wrap it with `ComponentTool`: + +```python +from haystack import component +from haystack.components.agents import Agent, State +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage, Document +from haystack.tools import ComponentTool + + +@component +class DocumentRetriever: + """Retrieve documents and store them in state.""" + + @component.output_types(reply=str) + def run(self, query: str, state: State) -> dict: + """ + Retrieve documents based on query and store them in state." + + :param query: The search query + """ + documents = [Document(content=f"Result for '{query}'")] + state.set("documents", documents) + return {"reply": f"Retrieved {len(documents)} document(s)"} + + +retriever_tool = ComponentTool( + component=DocumentRetriever(), + name="retrieve", + description="Retrieve documents and store them in state", +) + +agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-5-nano"), + tools=[retriever_tool], + state_schema={"documents": {"type": list[Document]}}, +) + +result = agent.run(messages=[ChatMessage.from_user("Find documents about Python")]) +``` + +`ToolInvoker` automatically injects the runtime `State` object and excludes the `State` parameter from the LLM-facing schema, so the model is never asked to supply it. +Both `State` and `State | None` annotations are supported. + ## Complete Example This example shows a multi-tool agent workflow where tools share data through State: diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 30cacc9bb3..522f9059ca 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -28,6 +28,7 @@ warm_up_tools, ) from haystack.tools.errors import ToolInvocationError +from haystack.tools.parameters_schema_utils import _unwrap_optional from haystack.tracing.utils import _serializable_value from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -376,11 +377,13 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to return ChatMessage.from_tool(tool_result=str(e), origin=tool_call, error=True) @staticmethod - def _get_func_params(tool: Tool) -> set: + def _get_func_params(tool: Tool) -> dict[str, Any]: """ - Returns the function parameters of the tool's invoke method. + Returns the function parameters with types of the tool's invoke method. This method inspects the tool's function signature to determine which parameters the tool accepts. + + :param tool: The tool for which to get the function parameters and their types. """ # ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets # to find out which parameters the tool accepts. @@ -389,9 +392,13 @@ def _get_func_params(tool: Tool) -> set: assert hasattr(tool._component, "__haystack_input__") and isinstance( tool._component.__haystack_input__, Sockets ) - func_params = set(tool._component.__haystack_input__._sockets_dict.keys()) + func_params = { + name: socket.type for name, socket in tool._component.__haystack_input__._sockets_dict.items() + } else: - func_params = set(inspect.signature(tool.function).parameters.keys()) + func_params = { + name: param.annotation for name, param in inspect.signature(tool.function).parameters.items() + } return func_params @@ -406,7 +413,7 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di - function signature name matching """ final_args = dict(llm_args) # start with LLM-provided - func_params = ToolInvoker._get_func_params(tool) + func_params = ToolInvoker._get_func_params(tool).keys() # Determine the source of parameter mappings (explicit tool inputs or direct function parameters) # Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"} @@ -420,6 +427,11 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di if param_name not in final_args and state.has(state_key): final_args[param_name] = state.get(state_key) + # Inject the live State object for any parameter annotated as State or Optional[State] + for param_name, param_type in ToolInvoker._get_func_params(tool).items(): + if _unwrap_optional(param_type) is State: + final_args[param_name] = state + return final_args @staticmethod @@ -528,7 +540,7 @@ def _prepare_tool_call_params( enable_streaming_passthrough and streaming_callback is not None and "streaming_callback" not in final_args - and "streaming_callback" in self._get_func_params(tool_to_invoke) + and "streaming_callback" in self._get_func_params(tool_to_invoke).keys() ): final_args["streaming_callback"] = streaming_callback diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py index 4f1a9e6645..9441f3d5c5 100644 --- a/haystack/tools/component_tool.py +++ b/haystack/tools/component_tool.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Callable -from types import NoneType, UnionType -from typing import Any, Union, get_args, get_origin +from typing import Any, get_args, get_origin from pydantic import Field, TypeAdapter, create_model from haystack import logging +from haystack.components.agents.state.state import State from haystack.core.component import Component from haystack.core.serialization import ( component_from_dict, @@ -23,6 +23,7 @@ _contains_callable_type, _get_component_param_descriptions, _resolve_type, + _unwrap_optional, ) from haystack.tools.tool import ( _deserialize_outputs_to_state, @@ -328,6 +329,10 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state if _contains_callable_type(input_type): continue + # Skip State-typed parameters - ToolInvoker injects them at runtime + if _unwrap_optional(input_type) is State: + continue + description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") # if the parameter has not a default value, Pydantic requires an Ellipsis (...) @@ -352,19 +357,6 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state return parameters_schema - def _unwrap_optional(self, _type: type) -> type: - """ - Unwrap Optional types to get the underlying type and whether it was originally optional. - - :returns: - The underlying type if `t` is `Optional[X]`, otherwise returns `t` unchanged. - """ - if get_origin(_type) is Union or get_origin(_type) is UnionType: - non_none = [a for a in get_args(_type) if a is not NoneType] - if len(non_none) == 1: - return non_none[0] - return _type - def _convert_param(self, param_value: Any, param_type: type) -> Any: """ Converts a single parameter value to the expected type. @@ -376,7 +368,7 @@ def _convert_param(self, param_value: Any, param_type: type) -> Any: The converted parameter value. """ # We unwrap optional types so we can support types like messages: list[ChatMessage] | None - unwrapped_param_type = self._unwrap_optional(param_type) + unwrapped_param_type = _unwrap_optional(param_type) # We support calling from_dict on target types that have it, even if they are wrapped in a list. # This allows us to support lists of dataclasses as well as single dataclass inputs. diff --git a/haystack/tools/from_function.py b/haystack/tools/from_function.py index ec21aa45a4..a16135234a 100644 --- a/haystack/tools/from_function.py +++ b/haystack/tools/from_function.py @@ -8,8 +8,10 @@ from pydantic import create_model +from haystack.components.agents.state.state import State + from .errors import SchemaGenerationError -from .parameters_schema_utils import _contains_callable_type +from .parameters_schema_utils import _contains_callable_type, _unwrap_optional from .tool import Tool @@ -139,6 +141,10 @@ def get_weather( if inputs_from_state and param_name in inputs_from_state.values(): continue + # Skip State-typed parameters (including Optional[State]) - ToolInvoker injects them at runtime + if _unwrap_optional(param.annotation) is State: + continue + if param.annotation is param.empty: raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") diff --git a/haystack/tools/parameters_schema_utils.py b/haystack/tools/parameters_schema_utils.py index 78ab6c5322..07608b5ad1 100644 --- a/haystack/tools/parameters_schema_utils.py +++ b/haystack/tools/parameters_schema_utils.py @@ -8,6 +8,7 @@ from collections.abc import Callable as ABCCallable from dataclasses import MISSING, fields, is_dataclass from inspect import getdoc +from types import NoneType from typing import Any, Union, get_args, get_origin from docstring_parser import parse @@ -20,6 +21,21 @@ logger = logging.getLogger(__name__) +def _unwrap_optional(type_hint: Any) -> Any: + """ + Unwrap Optional types (i.e. ``X | None`` or ``Optional[X]``) to get the inner type. + + :param type_hint: The type hint to unwrap. + :returns: The inner type if ``type_hint`` is ``Optional[X]``, otherwise ``type_hint`` unchanged. + """ + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + non_none = [a for a in get_args(type_hint) if a is not NoneType] + if len(non_none) == 1: + return non_none[0] + return type_hint + + def _contains_callable_type(type_hint: Any) -> bool: """ Check if a type hint contains a Callable type, including within Union types. diff --git a/releasenotes/notes/forward-state-to-tools-87a2f96a39a495a6.yaml b/releasenotes/notes/forward-state-to-tools-87a2f96a39a495a6.yaml new file mode 100644 index 0000000000..a9e098b44d --- /dev/null +++ b/releasenotes/notes/forward-state-to-tools-87a2f96a39a495a6.yaml @@ -0,0 +1,48 @@ +--- +features: + - | + Tools and components can now declare a ``State`` (or ``State | None``) parameter in their + signature to receive the live agent ``State`` object at invocation time — no extra wiring + needed. + + For function-based tools created with ``@tool`` or ``create_tool_from_function``, add a + ``state`` parameter annotated as ``State``: + + .. code:: python + + from haystack.components.agents import State + from haystack.tools import tool + + @tool + def my_tool(query: str, state: State) -> str: + """Search using context from agent state.""" + history = state.get("history") + ... + + For component-based tools created with ``ComponentTool``, declare a ``State`` input socket + on the component's ``run`` method: + + .. code:: python + + from haystack import component + from haystack.components.agents import State + from haystack.tools import ComponentTool + + @component + class MyComponent: + @component.output_types(result=str) + def run(self, query: str, state: State) -> dict: + history = state.get("history") + ... + + tool = ComponentTool(component=MyComponent()) + + In both cases ``ToolInvoker`` automatically injects the runtime ``State`` object before + calling the tool, and ``State``/``Optional[State]`` parameters are excluded from the + LLM-facing schema so the model is not asked to supply them. + + This is an alternative to the existing ``inputs_from_state`` and ``outputs_to_state`` + options on ``Tool`` and ``ComponentTool``, which map individual state keys to specific + tool parameters and outputs declaratively. Injecting the full ``State`` object is more + flexible and useful when a tool needs to read from or write to multiple keys, but it + couples the tool implementation directly to ``State``. diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py index d311101660..f08519c1b4 100644 --- a/test/components/tools/test_tool_invoker.py +++ b/test/components/tools/test_tool_invoker.py @@ -216,6 +216,36 @@ def test_inject_state_args_param_in_state_and_llm(self, invoker): args = invoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state) assert args == {"location": "Paris"} + def test_inject_state_args_injects_state_object_for_state_annotated_param(self, invoker): + def function_with_state(city: str, state: State) -> str: + return f"Weather in {city}" + + state_tool = Tool( + name="state_tool", + description="A tool that receives the live State object.", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=function_with_state, + ) + state = State(schema={"city": {"type": str}}, data={"city": "Berlin"}) + args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state) + assert args["city"] == "Paris" + assert args["state"] is state + + def test_inject_state_args_injects_state_object_for_optional_state_annotated_param(self, invoker): + def function_with_optional_state(city: str, state: State | None = None) -> str: + return f"Weather in {city}" + + state_tool = Tool( + name="state_tool", + description="A tool that receives an optional State object.", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=function_with_optional_state, + ) + state = State(schema={}) + args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state) + assert args["city"] == "Paris" + assert args["state"] is state + class TestToolInvokerSerde: def test_to_dict(self, invoker, weather_tool): @@ -754,6 +784,55 @@ async def streaming_callback(chunk: StreamingChunk) -> None: assert "tool_messages" in result_2 assert len(result_2["tool_messages"]) == 3 + def test_run_injects_state_object_into_tool(self): + received_state = {} + + def function_with_state(city: str, state: State) -> str: + received_state["state"] = state + return f"Weather in {city}: sunny" + + state_tool = Tool( + name="state_tool", + description="A tool that receives the live State object.", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=function_with_state, + ) + invoker = ToolInvoker(tools=[state_tool]) + state = State(schema={"city": {"type": str}}) + + tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + result = invoker.run(messages=[message], state=state) + + assert len(result["tool_messages"]) == 1 + assert not result["tool_messages"][0].tool_call_results[0].error + assert received_state["state"] is state + + @pytest.mark.asyncio + async def test_run_async_injects_state_object_into_tool(self): + received_state = {} + + def function_with_state(city: str, state: State) -> str: + received_state["state"] = state + return f"Weather in {city}: sunny" + + state_tool = Tool( + name="state_tool", + description="A tool that receives the live State object.", + parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + function=function_with_state, + ) + invoker = ToolInvoker(tools=[state_tool]) + state = State(schema={"city": {"type": str}}) + + tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + result = await invoker.run_async(messages=[message], state=state) + + assert len(result["tool_messages"]) == 1 + assert not result["tool_messages"][0].tool_call_results[0].error + assert received_state["state"] is state + class TestToolInvokerErrorHandling: def test_tool_not_found_error(self, invoker): diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index 712864038c..79c0f2989b 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -13,7 +13,7 @@ from openai.types.chat.chat_completion import Choice from haystack import Pipeline, SuperComponent, component -from haystack.components.agents import Agent +from haystack.components.agents import Agent, State from haystack.components.builders import PromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.tools import ToolInvoker @@ -537,6 +537,36 @@ def test_from_component_with_callable_params_skipped(self, monkeypatch): assert "streaming_callback" not in param_names assert "messages" in param_names + def test_from_component_with_state_param_excluded_from_schema(self): + @component + class ComponentWithState: + """A component that takes State as a direct input.""" + + @component.output_types(result=str) + def run(self, query: str, state: State) -> dict: + return {"result": query} + + tool = ComponentTool(component=ComponentWithState(), name="state_comp", description="test") + + param_names = list(tool.parameters.get("properties", {}).keys()) + assert "state" not in param_names + assert "query" in param_names + + def test_from_component_with_optional_state_param_excluded_from_schema(self): + @component + class ComponentWithOptionalState: + """A component that takes Optional[State] as an input (e.g. ToolInvoker style).""" + + @component.output_types(result=str) + def run(self, query: str, state: State | None = None) -> dict: + return {"result": query} + + tool = ComponentTool(component=ComponentWithOptionalState(), name="opt_state_comp", description="test") + + param_names = list(tool.parameters.get("properties", {}).keys()) + assert "state" not in param_names + assert "query" in param_names + def test_component_invoker_with_agent(self): """Tests that Agent as a ComponentTool can be invoked when calling it with a list of dicts""" agent = Agent(chat_generator=FakeChatGenerator(messages=[ChatMessage.from_assistant("Answer")])) diff --git a/test/tools/test_from_function.py b/test/tools/test_from_function.py index 066747380e..68b8c008c7 100644 --- a/test/tools/test_from_function.py +++ b/test/tools/test_from_function.py @@ -7,6 +7,7 @@ import pytest +from haystack.components.agents.state import State from haystack.tools.errors import SchemaGenerationError from haystack.tools.from_function import _remove_title_from_schema, create_tool_from_function, tool from haystack.tools.tool import Tool @@ -114,6 +115,43 @@ def function_with_callback(query: str, callback: Callable[[str], None] | None = assert "query" in param_names +def test_from_function_state_param_excluded_from_schema(): + def function_with_state(city: str, state: State) -> str: + """Get weather for a city, with access to agent state.""" + return f"Weather in {city}: sunny" + + tool = create_tool_from_function(function=function_with_state) + + assert tool.name == "function_with_state" + param_names = list(tool.parameters.get("properties", {}).keys()) + assert "state" not in param_names + assert "city" in param_names + assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + + +def test_tool_decorator_state_param_excluded_from_schema(): + @tool + def function_with_state(city: str, state: State) -> str: + """Get weather for a city, with access to agent state.""" + return f"Weather in {city}: sunny" + + param_names = list(function_with_state.parameters.get("properties", {}).keys()) + assert "state" not in param_names + assert "city" in param_names + + +def test_from_function_optional_state_param_excluded_from_schema(): + def function_with_optional_state(city: str, state: State | None = None) -> str: + """Get weather for a city, optionally using agent state.""" + return f"Weather in {city}: sunny" + + tool = create_tool_from_function(function=function_with_optional_state) + + param_names = list(tool.parameters.get("properties", {}).keys()) + assert "state" not in param_names + assert "city" in param_names + + def test_tool_decorator(): @tool def get_weather(city: str) -> str: