Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
warm_up_tools,
)
from haystack.tools.errors import ToolInvocationError
from haystack.tools.parameters_schema_utils import _unwrap_optional
from haystack.tracing.utils import _serializable_value
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

Expand Down Expand Up @@ -376,11 +377,13 @@ def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to
return ChatMessage.from_tool(tool_result=str(e), origin=tool_call, error=True)

@staticmethod
def _get_func_params(tool: Tool) -> set:
def _get_func_params(tool: Tool) -> dict[str, Any]:
"""
Returns the function parameters of the tool's invoke method.
Returns the function parameters with types of the tool's invoke method.

This method inspects the tool's function signature to determine which parameters the tool accepts.

:param tool: The tool for which to get the function parameters and their types.
"""
# ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
# to find out which parameters the tool accepts.
Expand All @@ -389,9 +392,13 @@ def _get_func_params(tool: Tool) -> set:
assert hasattr(tool._component, "__haystack_input__") and isinstance(
tool._component.__haystack_input__, Sockets
)
func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
func_params = {
name: socket.type for name, socket in tool._component.__haystack_input__._sockets_dict.items()
}
else:
func_params = set(inspect.signature(tool.function).parameters.keys())
func_params = {
name: param.annotation for name, param in inspect.signature(tool.function).parameters.items()
}

return func_params

Expand All @@ -406,7 +413,7 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di
- function signature name matching
"""
final_args = dict(llm_args) # start with LLM-provided
func_params = ToolInvoker._get_func_params(tool)
func_params = ToolInvoker._get_func_params(tool).keys()

# Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
# Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
Expand All @@ -420,6 +427,11 @@ def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> di
if param_name not in final_args and state.has(state_key):
final_args[param_name] = state.get(state_key)

# Inject the live State object for any parameter annotated as State or Optional[State]
for param_name, param_type in ToolInvoker._get_func_params(tool).items():
if _unwrap_optional(param_type) is State:
final_args[param_name] = state

return final_args

@staticmethod
Expand Down Expand Up @@ -528,7 +540,7 @@ def _prepare_tool_call_params(
enable_streaming_passthrough
and streaming_callback is not None
and "streaming_callback" not in final_args
and "streaming_callback" in self._get_func_params(tool_to_invoke)
and "streaming_callback" in self._get_func_params(tool_to_invoke).keys()
):
final_args["streaming_callback"] = streaming_callback

Expand Down
24 changes: 8 additions & 16 deletions haystack/tools/component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from types import NoneType, UnionType
from typing import Any, Union, get_args, get_origin
from typing import Any, get_args, get_origin

from pydantic import Field, TypeAdapter, create_model

from haystack import logging
from haystack.components.agents.state.state import State
from haystack.core.component import Component
from haystack.core.serialization import (
component_from_dict,
Expand All @@ -23,6 +23,7 @@
_contains_callable_type,
_get_component_param_descriptions,
_resolve_type,
_unwrap_optional,
)
from haystack.tools.tool import (
_deserialize_outputs_to_state,
Expand Down Expand Up @@ -328,6 +329,10 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state
if _contains_callable_type(input_type):
continue

# Skip State-typed parameters - ToolInvoker injects them at runtime
if _unwrap_optional(input_type) is State:
continue

description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

# if the parameter has not a default value, Pydantic requires an Ellipsis (...)
Expand All @@ -352,19 +357,6 @@ def _create_tool_parameters_schema(self, component: Component, inputs_from_state

return parameters_schema

def _unwrap_optional(self, _type: type) -> type:
"""
Unwrap Optional types to get the underlying type and whether it was originally optional.

:returns:
The underlying type if `t` is `Optional[X]`, otherwise returns `t` unchanged.
"""
if get_origin(_type) is Union or get_origin(_type) is UnionType:
non_none = [a for a in get_args(_type) if a is not NoneType]
if len(non_none) == 1:
return non_none[0]
return _type

def _convert_param(self, param_value: Any, param_type: type) -> Any:
"""
Converts a single parameter value to the expected type.
Expand All @@ -376,7 +368,7 @@ def _convert_param(self, param_value: Any, param_type: type) -> Any:
The converted parameter value.
"""
# We unwrap optional types so we can support types like messages: list[ChatMessage] | None
unwrapped_param_type = self._unwrap_optional(param_type)
unwrapped_param_type = _unwrap_optional(param_type)

# We support calling from_dict on target types that have it, even if they are wrapped in a list.
# This allows us to support lists of dataclasses as well as single dataclass inputs.
Expand Down
8 changes: 7 additions & 1 deletion haystack/tools/from_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from pydantic import create_model

from haystack.components.agents.state.state import State

from .errors import SchemaGenerationError
from .parameters_schema_utils import _contains_callable_type
from .parameters_schema_utils import _contains_callable_type, _unwrap_optional
from .tool import Tool


Expand Down Expand Up @@ -139,6 +141,10 @@ def get_weather(
if inputs_from_state and param_name in inputs_from_state.values():
continue

# Skip State-typed parameters (including Optional[State]) - ToolInvoker injects them at runtime
if _unwrap_optional(param.annotation) is State:
continue

if param.annotation is param.empty:
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")

Expand Down
16 changes: 16 additions & 0 deletions haystack/tools/parameters_schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Callable as ABCCallable
from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from types import NoneType
from typing import Any, Union, get_args, get_origin

from docstring_parser import parse
Expand All @@ -20,6 +21,21 @@
logger = logging.getLogger(__name__)


def _unwrap_optional(type_hint: Any) -> Any:
"""
Unwrap Optional types (i.e. ``X | None`` or ``Optional[X]``) to get the inner type.

:param type_hint: The type hint to unwrap.
:returns: The inner type if ``type_hint`` is ``Optional[X]``, otherwise ``type_hint`` unchanged.
"""
origin = get_origin(type_hint)
if origin is Union or origin is types.UnionType:
non_none = [a for a in get_args(type_hint) if a is not NoneType]
if len(non_none) == 1:
return non_none[0]
return type_hint


def _contains_callable_type(type_hint: Any) -> bool:
"""
Check if a type hint contains a Callable type, including within Union types.
Expand Down
48 changes: 48 additions & 0 deletions releasenotes/notes/forward-state-to-tools-87a2f96a39a495a6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
features:
- |
Tools and components can now declare a ``State`` (or ``State | None``) parameter in their
signature to receive the live agent ``State`` object at invocation time — no extra wiring
needed.

For function-based tools created with ``@tool`` or ``create_tool_from_function``, add a
``state`` parameter annotated as ``State``:

.. code:: python

from haystack.components.agents import State
from haystack.tools import tool

@tool
def my_tool(query: str, state: State) -> str:
"""Search using context from agent state."""
history = state.get("history")
...

For component-based tools created with ``ComponentTool``, declare a ``State`` input socket
on the component's ``run`` method:

.. code:: python

from haystack import component
from haystack.components.agents import State
from haystack.tools import ComponentTool

@component
class MyComponent:
@component.output_types(result=str)
def run(self, query: str, state: State) -> dict:
history = state.get("history")
...

tool = ComponentTool(component=MyComponent())

In both cases ``ToolInvoker`` automatically injects the runtime ``State`` object before
calling the tool, and ``State``/``Optional[State]`` parameters are excluded from the
LLM-facing schema so the model is not asked to supply them.

This is an alternative to the existing ``inputs_from_state`` and ``outputs_to_state``
options on ``Tool`` and ``ComponentTool``, which map individual state keys to specific
tool parameters and outputs declaratively. Injecting the full ``State`` object is more
flexible and useful when a tool needs to read from or write to multiple keys, but it
couples the tool implementation directly to ``State``.
79 changes: 79 additions & 0 deletions test/components/tools/test_tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,36 @@ def test_inject_state_args_param_in_state_and_llm(self, invoker):
args = invoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
assert args == {"location": "Paris"}

def test_inject_state_args_injects_state_object_for_state_annotated_param(self, invoker):
def function_with_state(city: str, state: State) -> str:
return f"Weather in {city}"

state_tool = Tool(
name="state_tool",
description="A tool that receives the live State object.",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=function_with_state,
)
state = State(schema={"city": {"type": str}}, data={"city": "Berlin"})
args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state)
assert args["city"] == "Paris"
assert args["state"] is state

def test_inject_state_args_injects_state_object_for_optional_state_annotated_param(self, invoker):
def function_with_optional_state(city: str, state: State | None = None) -> str:
return f"Weather in {city}"

state_tool = Tool(
name="state_tool",
description="A tool that receives an optional State object.",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=function_with_optional_state,
)
state = State(schema={})
args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state)
assert args["city"] == "Paris"
assert args["state"] is state


class TestToolInvokerSerde:
def test_to_dict(self, invoker, weather_tool):
Expand Down Expand Up @@ -754,6 +784,55 @@ async def streaming_callback(chunk: StreamingChunk) -> None:
assert "tool_messages" in result_2
assert len(result_2["tool_messages"]) == 3

def test_run_injects_state_object_into_tool(self):
received_state = {}

def function_with_state(city: str, state: State) -> str:
received_state["state"] = state
return f"Weather in {city}: sunny"

state_tool = Tool(
name="state_tool",
description="A tool that receives the live State object.",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=function_with_state,
)
invoker = ToolInvoker(tools=[state_tool])
state = State(schema={"city": {"type": str}})

tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message], state=state)

assert len(result["tool_messages"]) == 1
assert not result["tool_messages"][0].tool_call_results[0].error
assert received_state["state"] is state

@pytest.mark.asyncio
async def test_run_async_injects_state_object_into_tool(self):
received_state = {}

def function_with_state(city: str, state: State) -> str:
received_state["state"] = state
return f"Weather in {city}: sunny"

state_tool = Tool(
name="state_tool",
description="A tool that receives the live State object.",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=function_with_state,
)
invoker = ToolInvoker(tools=[state_tool])
state = State(schema={"city": {"type": str}})

tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = await invoker.run_async(messages=[message], state=state)

assert len(result["tool_messages"]) == 1
assert not result["tool_messages"][0].tool_call_results[0].error
assert received_state["state"] is state


class TestToolInvokerErrorHandling:
def test_tool_not_found_error(self, invoker):
Expand Down
32 changes: 31 additions & 1 deletion test/tools/test_component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from openai.types.chat.chat_completion import Choice

from haystack import Pipeline, SuperComponent, component
from haystack.components.agents import Agent
from haystack.components.agents import Agent, State
from haystack.components.builders import PromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.tools import ToolInvoker
Expand Down Expand Up @@ -537,6 +537,36 @@ def test_from_component_with_callable_params_skipped(self, monkeypatch):
assert "streaming_callback" not in param_names
assert "messages" in param_names

def test_from_component_with_state_param_excluded_from_schema(self):
@component
class ComponentWithState:
"""A component that takes State as a direct input."""

@component.output_types(result=str)
def run(self, query: str, state: State) -> dict:
return {"result": query}

tool = ComponentTool(component=ComponentWithState(), name="state_comp", description="test")

param_names = list(tool.parameters.get("properties", {}).keys())
assert "state" not in param_names
assert "query" in param_names

def test_from_component_with_optional_state_param_excluded_from_schema(self):
@component
class ComponentWithOptionalState:
"""A component that takes Optional[State] as an input (e.g. ToolInvoker style)."""

@component.output_types(result=str)
def run(self, query: str, state: State | None = None) -> dict:
return {"result": query}

tool = ComponentTool(component=ComponentWithOptionalState(), name="opt_state_comp", description="test")

param_names = list(tool.parameters.get("properties", {}).keys())
assert "state" not in param_names
assert "query" in param_names

def test_component_invoker_with_agent(self):
"""Tests that Agent as a ComponentTool can be invoked when calling it with a list of dicts"""
agent = Agent(chat_generator=FakeChatGenerator(messages=[ChatMessage.from_assistant("Answer")]))
Expand Down
Loading
Loading