From 74b73e41255a34c27ea7d8d6bf10b6ec369e6373 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 9 Jan 2026 13:57:01 +0100 Subject: [PATCH 1/9] Add state-based configuration support to MCPToolset --- .../tools/mcp/mcp_toolset.py | 162 +++++++++++- integrations/mcp/tests/test_mcp_toolset.py | 245 ++++++++++++++++-- 2 files changed, 391 insertions(+), 16 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 529e7e487c..d90aa378d1 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -11,6 +11,7 @@ from haystack import logging from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.tools import Tool, Toolset +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from .mcp_tool import ( AsyncExecutor, @@ -27,6 +28,84 @@ logger = logging.getLogger(__name__) +def _serialize_state_config(config: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]] | None: + """ + Serialize a state configuration dictionary, converting any callable handlers to their string representation. + + Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}}) + and outputs_to_string (tool_name -> {source, handler}). + + :param config: The state configuration dictionary to serialize + :returns: The serialized configuration dictionary, or None if empty + """ + if not config: + return None + + serialized = {} + for tool_name, tool_config in config.items(): + if not tool_config: + continue + + # Check if this is outputs_to_string format (flat with optional source/handler) + # or outputs_to_state format (nested with state keys) + if "source" in tool_config or "handler" in tool_config: + # outputs_to_string format: {source?, handler?} + serialized_tool_config = tool_config.copy() + if "handler" in tool_config and callable(tool_config["handler"]): + serialized_tool_config["handler"] = serialize_callable(tool_config["handler"]) + serialized[tool_name] = serialized_tool_config + else: + # outputs_to_state format: {state_key -> {source?, handler?}} + serialized_tool_config = {} + for state_key, state_config in tool_config.items(): + serialized_state_config = state_config.copy() + if "handler" in state_config and callable(state_config["handler"]): + serialized_state_config["handler"] = serialize_callable(state_config["handler"]) + serialized_tool_config[state_key] = serialized_state_config + serialized[tool_name] = serialized_tool_config + + return serialized if serialized else None + + +def _deserialize_state_config(config: dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]]: + """ + Deserialize a state configuration dictionary, converting any serialized handlers back to callables. + + Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}}) + and outputs_to_string (tool_name -> {source, handler}). + + :param config: The state configuration dictionary to deserialize + :returns: The deserialized configuration dictionary + """ + if not config: + return {} + + deserialized = {} + for tool_name, tool_config in config.items(): + if not tool_config: + continue + + # Check if this is outputs_to_string format (flat with optional source/handler) + # or outputs_to_state format (nested with state keys) + if "source" in tool_config or "handler" in tool_config: + # outputs_to_string format: {source?, handler?} + deserialized_tool_config = tool_config.copy() + if "handler" in tool_config and isinstance(tool_config["handler"], str): + deserialized_tool_config["handler"] = deserialize_callable(tool_config["handler"]) + deserialized[tool_name] = deserialized_tool_config + else: + # outputs_to_state format: {state_key -> {source?, handler?}} + deserialized_tool_config = {} + for state_key, state_config in tool_config.items(): + deserialized_state_config = state_config.copy() + if "handler" in state_config and isinstance(state_config["handler"], str): + deserialized_state_config["handler"] = deserialize_callable(state_config["handler"]) + deserialized_tool_config[state_key] = deserialized_state_config + deserialized[tool_name] = deserialized_tool_config + + return deserialized + + class MCPToolset(Toolset): """ A Toolset that connects to an MCP (Model Context Protocol) server and provides @@ -99,6 +178,30 @@ class MCPToolset(Toolset): # Use the toolset as shown in the pipeline example above ``` + Example with state configuration for Agent integration: + ```python + from haystack_integrations.tools.mcp import MCPToolset, StdioServerInfo + + # Create the toolset with per-tool state configuration + # This enables tools to read from and write to the Agent's State + toolset = MCPToolset( + server_info=StdioServerInfo(command="uvx", args=["mcp-server-git"]), + tool_names=["git_status", "git_diff", "git_log"], + + # Map state keys to tool parameters for each tool + inputs_from_state={ + "git_status": {"repository": "repo_path"}, + "git_diff": {"repository": "repo_path"}, + "git_log": {"repository": "repo_path"}, + }, + # Map tool outputs to state keys + outputs_to_state={ + "git_status": {"status_result": {"source": "status"}}, + "git_diff": {"diff_result": {}}, + }, + ) + ``` + Example using SSE (deprecated): ```python from haystack_integrations.tools.mcp import MCPToolset, SSEServerInfo @@ -121,6 +224,9 @@ def __init__( connection_timeout: float = 30.0, invocation_timeout: float = 30.0, eager_connect: bool = False, + inputs_from_state: dict[str, dict[str, str]] | None = None, + outputs_to_state: dict[str, dict[str, dict[str, Any]]] | None = None, + outputs_to_string: dict[str, dict[str, Any]] | None = None, ): """ Initialize the MCP toolset. @@ -132,6 +238,15 @@ def __init__( :param invocation_timeout: Default timeout in seconds for tool invocations :param eager_connect: If True, connect to server and load tools during initialization. If False (default), defer connection to warm_up. + :param inputs_from_state: Optional dictionary mapping tool names to their inputs_from_state config. + Each config maps state keys to tool parameter names. + Example: `{"git_status": {"repository": "repo_path"}}` + :param outputs_to_state: Optional dictionary mapping tool names to their outputs_to_state config. + Each config defines how tool outputs map to state keys with optional handlers. + Example: `{"git_status": {"status_result": {"source": "status"}}}` + :param outputs_to_string: Optional dictionary mapping tool names to their outputs_to_string config. + Each config defines how tool outputs are converted to strings. + Example: `{"git_diff": {"source": "diff", "handler": format_diff}}` :raises MCPToolNotFoundError: If any of the specified tool names are not found on the server """ # Store configuration @@ -140,6 +255,9 @@ def __init__( self.connection_timeout = connection_timeout self.invocation_timeout = invocation_timeout self.eager_connect = eager_connect + self.inputs_from_state = inputs_from_state or {} + self.outputs_to_state = outputs_to_state or {} + self.outputs_to_string = outputs_to_string or {} self._warmup_called = False if not eager_connect: @@ -226,9 +344,15 @@ def invoke_tool(**kwargs: Any) -> Any: description=tool_info.description or "", parameters=tool_info.inputSchema, function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout), + inputs_from_state=self.inputs_from_state.get(tool_info.name), + outputs_to_state=self.outputs_to_state.get(tool_info.name), + outputs_to_string=self.outputs_to_string.get(tool_info.name), ) haystack_tools.append(tool) + # Validate state configs reference known tools + self._validate_state_configs({tool.name for tool in haystack_tools}) + return haystack_tools except Exception as e: # We need to close because we could connect properly, retrieve tools yet @@ -292,6 +416,31 @@ def invoke_tool(**kwargs: Any) -> Any: raise MCPConnectionError(message=message, server_info=self.server_info, operation="initialize") from e + def _validate_state_configs(self, available_tool_names: set[str]) -> None: + """ + Validate that state configuration tool names exist in the toolset. + + Logs a warning for any tool names in the state configs that don't match + available tools in the toolset. + + :param available_tool_names: Set of tool names that are available in the toolset + """ + configs: list[tuple[str, dict[str, Any]]] = [ + ("inputs_from_state", self.inputs_from_state), + ("outputs_to_state", self.outputs_to_state), + ("outputs_to_string", self.outputs_to_string), + ] + for config_name, config in configs: + if config: + unknown_tools = set(config.keys()) - available_tool_names + if unknown_tools: + logger.warning( + "{config_name} references unknown tools: {unknown_tools}. Available tools: {available_tools}", + config_name=config_name, + unknown_tools=unknown_tools, + available_tools=available_tool_names, + ) + def to_dict(self) -> dict[str, Any]: """ Serialize the MCPToolset to a dictionary. @@ -306,6 +455,9 @@ def to_dict(self) -> dict[str, Any]: "connection_timeout": self.connection_timeout, "invocation_timeout": self.invocation_timeout, "eager_connect": self.eager_connect, + "inputs_from_state": self.inputs_from_state if self.inputs_from_state else None, + "outputs_to_state": _serialize_state_config(self.outputs_to_state), + "outputs_to_string": _serialize_state_config(self.outputs_to_string), }, } @@ -324,13 +476,21 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset": server_info_class = import_class_by_name(server_info_dict["type"]) server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict) + # Deserialize state configuration parameters + inputs_from_state = inner_data.get("inputs_from_state") or {} + outputs_to_state = _deserialize_state_config(inner_data.get("outputs_to_state")) + outputs_to_string = _deserialize_state_config(inner_data.get("outputs_to_string")) + # Create a new MCPToolset instance return cls( server_info=server_info, tool_names=inner_data.get("tool_names"), connection_timeout=inner_data.get("connection_timeout", 30.0), invocation_timeout=inner_data.get("invocation_timeout", 30.0), - eager_connect=inner_data.get("eager_connect", True), + eager_connect=inner_data.get("eager_connect", False), + inputs_from_state=inputs_from_state if inputs_from_state else None, + outputs_to_state=outputs_to_state if outputs_to_state else None, + outputs_to_string=outputs_to_string if outputs_to_string else None, ) def close(self): diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index 625df58a71..a043cf8259 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -7,7 +7,6 @@ import time from unittest.mock import patch -import haystack import pytest import pytest_asyncio from haystack import logging @@ -24,6 +23,10 @@ SSEServerInfo, StreamableHttpServerInfo, ) +from haystack_integrations.tools.mcp.mcp_toolset import ( + _deserialize_state_config, + _serialize_state_config, +) # Import in-memory transport and fixtures from .mcp_memory_transport import InMemoryServerInfo @@ -78,6 +81,38 @@ async def calculator_toolset_with_tool_filter(mcp_tool_cleanup): return mcp_tool_cleanup(toolset) +def format_result(result): + """Sample handler function for testing.""" + return f"FORMATTED: {result}" + + +@pytest_asyncio.fixture +async def calculator_toolset_with_state_config(mcp_tool_cleanup): + """Fixture that provides an MCPToolset with state configuration.""" + + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["add", "subtract"], + connection_timeout=45, + invocation_timeout=60, + eager_connect=True, + inputs_from_state={ + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + }, + outputs_to_state={ + "add": {"sum_result": {"source": "content"}}, + "subtract": {"diff_result": {}}, + }, + outputs_to_string={ + "add": {"source": "content", "handler": format_result}, + }, + ) + + return mcp_tool_cleanup(toolset) + + @pytest.mark.asyncio class TestMCPToolset: """Tests for the MCPToolset class.""" @@ -233,6 +268,91 @@ async def test_toolset_tool_not_found(self): eager_connect=True, ) + async def test_toolset_with_state_config(self, calculator_toolset_with_state_config): + """Test that MCPToolset correctly passes state configuration to tools.""" + toolset = calculator_toolset_with_state_config + + # Verify toolset has state configs stored + assert toolset.inputs_from_state == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert "add" in toolset.outputs_to_state + assert "subtract" in toolset.outputs_to_state + assert "add" in toolset.outputs_to_string + + # Verify tools have correct state configurations + add_tool = next(tool for tool in toolset.tools if tool.name == "add") + subtract_tool = next(tool for tool in toolset.tools if tool.name == "subtract") + + assert add_tool.inputs_from_state == {"first_number": "a"} + assert subtract_tool.inputs_from_state == {"first_number": "a", "second_number": "b"} + assert add_tool.outputs_to_state == {"sum_result": {"source": "content"}} + assert subtract_tool.outputs_to_state == {"diff_result": {}} + assert add_tool.outputs_to_string is not None + assert subtract_tool.outputs_to_string is None + + async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config): + """Test serialization and deserialization of MCPToolset with state configuration.""" + toolset = calculator_toolset_with_state_config + + toolset_dict = toolset.to_dict() + + # Verify state configs are serialized + assert toolset_dict["data"]["inputs_from_state"] == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert toolset_dict["data"]["outputs_to_state"] is not None + assert toolset_dict["data"]["outputs_to_string"] is not None + # Handler should be serialized as a string + assert isinstance(toolset_dict["data"]["outputs_to_string"]["add"]["handler"], str) + + # Test deserialization + with patch("haystack_integrations.tools.mcp.mcp_toolset.MCPToolset.__init__", return_value=None) as mock_init: + MCPToolset.from_dict(toolset_dict) + + mock_init.assert_called_once() + _, kwargs = mock_init.call_args + assert kwargs["inputs_from_state"] == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert "add" in kwargs["outputs_to_state"] + assert "add" in kwargs["outputs_to_string"] + # Handler should be deserialized back to a callable + assert callable(kwargs["outputs_to_string"]["add"]["handler"]) + + async def test_toolset_state_config_unknown_tool_warning(self, caplog): + """Test that a warning is logged when state config references unknown tools.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + + with caplog.at_level("WARNING"): + toolset = MCPToolset( + server_info=server_info, + tool_names=["add"], # Only include add + connection_timeout=10, + invocation_timeout=10, + eager_connect=True, + inputs_from_state={ + "add": {"first_number": "a"}, + "unknown_tool": {"some_key": "some_param"}, # This tool doesn't exist + }, + ) + + # The warning should be logged + assert any("unknown_tool" in record.message for record in caplog.records) + toolset.close() + + async def test_toolset_no_state_config(self, calculator_toolset): + """Test that tools have no state config when none is provided.""" + toolset = calculator_toolset + + for tool in toolset.tools: + assert tool.inputs_from_state is None + assert tool.outputs_to_state is None + assert tool.outputs_to_string is None + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration async def test_pipeline_warmup_with_mcp_toolset(self): @@ -464,19 +584,18 @@ def subtract(a: int, b: int) -> int: if os.path.exists(server_script_path): os.remove(server_script_path) - def test_pipeline_deserialization_fails_without_github_token(self, monkeypatch): + def test_pipeline_deserialization_succeeds_with_lazy_connection(self, monkeypatch): """ - Test that pipeline deserialization + MCPToolset initialization fails when GitHub - token is not resolved during deserialization. + Test that pipeline deserialization succeeds with lazy connection (eager_connect=False). - The issue: - - Setup: Agent pipeline template with MCPToolset with a token from env var (PERSONAL_ACCESS_TOKEN_GITHUB) - - MCPToolset tries to connect immediately during __init__ after validation - - Secrets get resolved during validation, after MCPToolset is initialized - - Connection fails because token can't be resolved in __init__ - - Pipeline deserialization fails with DeserializationError + With lazy connection (the default), MCPToolset defers connection until warm_up() is called. + This allows pipelines to be deserialized even when the server is not available or + credentials are not yet resolved. - This test demonstrates why we need warmup for MCPToolset on first use rather than during deserialization. + This test demonstrates that: + - Pipeline deserialization succeeds even with an invalid token + - MCPToolset creates a placeholder tool during initialization + - Actual connection happens later during warm_up() """ pipeline_yaml = """ components: @@ -528,7 +647,103 @@ def test_pipeline_deserialization_fails_without_github_token(self, monkeypatch): connections: [] """ monkeypatch.setenv("PERSONAL_ACCESS_TOKEN_GITHUB", "SOME_OBVIOUSLY_INVALID_TOKEN") - # Attempt to deserialize the pipeline - this will fail because MCPToolset - # tries to connect immediately and the token isn't available - with pytest.raises(haystack.core.errors.DeserializationError): - Pipeline.loads(pipeline_yaml) + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + + # Deserialization should succeed because eager_connect defaults to False + # With lazy connection, MCPToolset creates a placeholder tool and doesn't try to connect + pipeline = Pipeline.loads(pipeline_yaml) + + # Verify the pipeline was created successfully + assert pipeline is not None + agent = pipeline.get_component("agent") + assert agent is not None + + # The key point is that deserialization succeeded even with an invalid token + # because the connection is deferred until warm_up() is called + + +class TestStateConfigHelpers: + """Tests for the state configuration serialization helper functions.""" + + def test_serialize_outputs_to_string_with_handler(self): + """Test serializing outputs_to_string config with a handler function.""" + config = { + "add": {"source": "content", "handler": format_result}, + "subtract": {"source": "diff"}, + } + + serialized = _serialize_state_config(config) + + assert serialized is not None + assert "add" in serialized + assert "subtract" in serialized + assert isinstance(serialized["add"]["handler"], str) # Handler serialized to string + assert serialized["subtract"]["source"] == "diff" + assert "handler" not in serialized["subtract"] # No handler for subtract + + def test_serialize_outputs_to_state_with_handler(self): + """Test serializing outputs_to_state config with a handler function.""" + config = { + "add": { + "sum_result": {"source": "content", "handler": format_result}, + "raw_result": {}, + }, + } + + serialized = _serialize_state_config(config) + + assert serialized is not None + assert "add" in serialized + assert isinstance(serialized["add"]["sum_result"]["handler"], str) + assert serialized["add"]["raw_result"] == {} + + def test_serialize_empty_config(self): + """Test that empty config returns None.""" + assert _serialize_state_config({}) is None + assert _serialize_state_config(None) is None + + def test_deserialize_outputs_to_string_with_handler(self): + """Test deserializing outputs_to_string config with a handler function.""" + # First serialize to get the correct handler path + original = {"add": {"source": "content", "handler": format_result}} + serialized = _serialize_state_config(original) + + # Now deserialize + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert callable(deserialized["add"]["handler"]) + assert deserialized["add"]["source"] == "content" + + def test_deserialize_outputs_to_state_with_handler(self): + """Test deserializing outputs_to_state config with a handler function.""" + # First serialize to get the correct handler path + original = {"add": {"sum_result": {"source": "content", "handler": format_result}}} + serialized = _serialize_state_config(original) + + # Now deserialize + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert callable(deserialized["add"]["sum_result"]["handler"]) + + def test_deserialize_empty_config(self): + """Test that empty config returns empty dict.""" + assert _deserialize_state_config({}) == {} + assert _deserialize_state_config(None) == {} + + def test_roundtrip_serialization(self): + """Test that serialization and deserialization are inverse operations.""" + original = { + "add": {"source": "content", "handler": format_result}, + "subtract": {"source": "diff"}, + } + + serialized = _serialize_state_config(original) + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert "subtract" in deserialized + assert deserialized["add"]["source"] == "content" + assert callable(deserialized["add"]["handler"]) + assert deserialized["subtract"]["source"] == "diff" From 8f02af24cf90b92c369d13bf0ee50175e706df4b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 9 Jan 2026 14:34:23 +0100 Subject: [PATCH 2/9] Some final touches --- .../tools/mcp/mcp_toolset.py | 15 +++++++-- integrations/mcp/tests/test_mcp_toolset.py | 32 ++++++++++++++++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index d90aa378d1..160277d7a1 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -240,14 +240,23 @@ def __init__( If False (default), defer connection to warm_up. :param inputs_from_state: Optional dictionary mapping tool names to their inputs_from_state config. Each config maps state keys to tool parameter names. + Tool names should match available tools from the server; a warning is logged for + unknown tools. Note: With Haystack >= 2.22.0, parameter names are validated; + ValueError is raised for invalid parameters. With earlier versions, invalid + parameters fail at runtime. Example: `{"git_status": {"repository": "repo_path"}}` :param outputs_to_state: Optional dictionary mapping tool names to their outputs_to_state config. Each config defines how tool outputs map to state keys with optional handlers. + Tool names should match available tools from the server; a warning is logged for + unknown tools. Example: `{"git_status": {"status_result": {"source": "status"}}}` :param outputs_to_string: Optional dictionary mapping tool names to their outputs_to_string config. Each config defines how tool outputs are converted to strings. + Tool names should match available tools from the server; a warning is logged for + unknown tools. Example: `{"git_diff": {"source": "diff", "handler": format_diff}}` :raises MCPToolNotFoundError: If any of the specified tool names are not found on the server + :raises ValueError: If parameter names in inputs_from_state are invalid (Haystack >= 2.22.0 only) """ # Store configuration self.server_info = server_info @@ -356,11 +365,11 @@ def invoke_tool(**kwargs: Any) -> Any: return haystack_tools except Exception as e: # We need to close because we could connect properly, retrieve tools yet - # fail because of an MCPToolNotFoundError + # fail because of validation errors self.close() - if isinstance(e, MCPToolNotFoundError): - raise # re-raise MCPToolNotFoundError as is to show original message + if isinstance(e, (MCPToolNotFoundError, ValueError)): + raise # re-raise validation errors as is to show original message # Create informative error message for SSE connection errors # Common error handling for HTTP-based transports diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index a043cf8259..7b1eb004e3 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -324,7 +324,12 @@ async def test_toolset_state_config_serde(self, calculator_toolset_with_state_co assert callable(kwargs["outputs_to_string"]["add"]["handler"]) async def test_toolset_state_config_unknown_tool_warning(self, caplog): - """Test that a warning is logged when state config references unknown tools.""" + """Test that a warning is logged when state config references unknown tools. + + Note: This test validates unknown tool names at the MCPToolset level. + For parameter validation (unknown parameter names), see test_toolset_state_config_invalid_parameter_raises_error + which requires Haystack >= 2.22.0. + """ server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) with caplog.at_level("WARNING"): @@ -344,6 +349,31 @@ async def test_toolset_state_config_unknown_tool_warning(self, caplog): assert any("unknown_tool" in record.message for record in caplog.records) toolset.close() + @pytest.mark.skipif( + not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"), + reason="Requires Haystack >= 2.22.0 for inputs_from_state validation", + ) + async def test_toolset_state_config_invalid_parameter_raises_error(self): + """Test that ValueError is raised when inputs_from_state references non-existent parameter. + + Requires Haystack >= 2.22.0 which validates inputs_from_state parameter names. + With Haystack < 2.22.0, this test is skipped and invalid parameter mappings will + only fail at runtime when the tool is invoked. + """ + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + + with pytest.raises(ValueError, match="unknown parameter"): + MCPToolset( + server_info=server_info, + tool_names=["add"], + connection_timeout=10, + invocation_timeout=10, + eager_connect=True, + inputs_from_state={ + "add": {"state_key": "non_existent_param"}, # 'add' tool has 'a' and 'b' parameters + }, + ) + async def test_toolset_no_state_config(self, calculator_toolset): """Test that tools have no state config when none is provided.""" toolset = calculator_toolset From 4fcce696d984d0003d4af9da9141dfa34aa7c368 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 12 Jan 2026 13:36:19 +0100 Subject: [PATCH 3/9] Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- .../mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 160277d7a1..f1c96fcc30 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -188,7 +188,7 @@ class MCPToolset(Toolset): server_info=StdioServerInfo(command="uvx", args=["mcp-server-git"]), tool_names=["git_status", "git_diff", "git_log"], - # Map state keys to tool parameters for each tool + # Maps the state key "repository" to the tool parameter "repo_path" for each tool inputs_from_state={ "git_status": {"repository": "repo_path"}, "git_diff": {"repository": "repo_path"}, From b2eb651f4a42b9f4b83990651e1fdba49c7526d1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 12 Jan 2026 13:36:29 +0100 Subject: [PATCH 4/9] Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- .../mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index f1c96fcc30..efc6267cca 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -194,7 +194,7 @@ class MCPToolset(Toolset): "git_diff": {"repository": "repo_path"}, "git_log": {"repository": "repo_path"}, }, - # Map tool outputs to state keys + # Map tool outputs to state keys for each tool outputs_to_state={ "git_status": {"status_result": {"source": "status"}}, "git_diff": {"diff_result": {}}, From e20efdd0d1b3a1cc4030e9c74c909397eb5e23e0 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 12 Jan 2026 13:37:07 +0100 Subject: [PATCH 5/9] Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- .../mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index efc6267cca..97ea825787 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -486,7 +486,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset": server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict) # Deserialize state configuration parameters - inputs_from_state = inner_data.get("inputs_from_state") or {} + inputs_from_state = inner_data.get("inputs_from_state") outputs_to_state = _deserialize_state_config(inner_data.get("outputs_to_state")) outputs_to_string = _deserialize_state_config(inner_data.get("outputs_to_string")) From d2cb4f0b39d4c213ea085d462dcfcf90d735a73e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 12 Jan 2026 13:59:19 +0100 Subject: [PATCH 6/9] PR touches --- .../mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 97ea825787..c30a656157 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -196,8 +196,8 @@ class MCPToolset(Toolset): }, # Map tool outputs to state keys for each tool outputs_to_state={ - "git_status": {"status_result": {"source": "status"}}, - "git_diff": {"diff_result": {}}, + "git_status": {"status_result": {"source": "status"}}, # Extract "status" from output + "git_diff": {"diff_result": {}}, # use full output with default handling }, ) ``` From 6887a0b34e47cfdee8f7f5d855d25c4c4fe8c364 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 15 Jan 2026 14:39:03 +0100 Subject: [PATCH 7/9] Add MCP tool/Agent state io integration test --- .../tools/mcp/mcp_tool.py | 54 ++++++++- .../tools/mcp/mcp_toolset.py | 33 +++++- .../mcp/tests/mcp_servers_fixtures.py | 8 +- .../mcp/tests/test_mcp_integration.py | 110 +++++++++++++++++- 4 files changed, 191 insertions(+), 14 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index d23c585c6c..4e0f8e663b 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import json import threading import warnings from abc import ABC, abstractmethod @@ -1048,12 +1049,13 @@ def _connect_and_initialize(self, tool_name: str) -> types.Tool: return tool - def _invoke_tool(self, **kwargs: Any) -> str: + def _invoke_tool(self, **kwargs: Any) -> str | dict[str, Any]: """ Synchronous tool invocation. :param kwargs: Arguments to pass to the tool - :returns: JSON string representation of the tool invocation result + :returns: JSON string or dictionary representation of the tool invocation result. + Returns a dictionary when outputs_to_state is configured to enable state updates. """ logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}") try: @@ -1070,6 +1072,26 @@ async def invoke(): logger.debug(f"TOOL: About to run invoke for '{self.name}'") result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout) logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}") + + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if hasattr(self, "outputs_to_state") and self.outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). + content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + return result except (MCPError, TimeoutError) as e: logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}") @@ -1081,19 +1103,41 @@ async def invoke(): message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}" raise MCPInvocationError(message, self.name, kwargs) from e - async def ainvoke(self, **kwargs: Any) -> str: + async def ainvoke(self, **kwargs: Any) -> str | dict[str, Any]: """ Asynchronous tool invocation. :param kwargs: Arguments to pass to the tool - :returns: JSON string representation of the tool invocation result + :returns: JSON string or dictionary representation of the tool invocation result. + Returns a dictionary when outputs_to_state is configured to enable state updates. :raises MCPInvocationError: If the tool invocation fails :raises TimeoutError: If the operation times out """ try: self.warm_up() client = cast(MCPClient, self._client) - return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if hasattr(self, "outputs_to_state") and self.outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). + content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + + return result except asyncio.TimeoutError as e: message = f"Tool invocation timed out after {self._invocation_timeout} seconds" raise TimeoutError(message) from e diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index c30a656157..536dc83e7a 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json from collections.abc import Callable from typing import Any, cast from urllib.parse import urlparse @@ -326,14 +327,34 @@ def create_invoke_tool( mcp_client: MCPClient, tool_name: str, tool_timeout: float, + outputs_to_state: dict[str, Any] | None = None, ) -> Callable[..., Any]: """Return a closure that keeps a strong reference to *owner_toolset* alive.""" def invoke_tool(**kwargs: Any) -> Any: _ = owner_toolset # strong reference so GC can't collect the toolset too early - return AsyncExecutor.get_instance().run( + result = AsyncExecutor.get_instance().run( mcp_client.call_tool(tool_name, kwargs), timeout=tool_timeout ) + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). + content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + return result return invoke_tool @@ -343,18 +364,22 @@ def invoke_tool(**kwargs: Any) -> Any: # Skip tools not in the tool_names list if tool_names is provided if self.tool_names is not None and tool_info.name not in self.tool_names: logger.debug( - "Skipping tool '{name}' as it's not in the requested tool_names list", name=tool_info.name + "Skipping tool '{tool_name}' as it's not in the requested tool_names list", + tool_name=tool_info.name, ) continue # Use the helper function to create the invoke_tool function + tool_outputs_to_state = self.outputs_to_state.get(tool_info.name) tool = Tool( name=tool_info.name, description=tool_info.description or "", parameters=tool_info.inputSchema, - function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout), + function=create_invoke_tool( + self, client, tool_info.name, self.invocation_timeout, tool_outputs_to_state + ), inputs_from_state=self.inputs_from_state.get(tool_info.name), - outputs_to_state=self.outputs_to_state.get(tool_info.name), + outputs_to_state=tool_outputs_to_state, outputs_to_string=self.outputs_to_string.get(tool_info.name), ) haystack_tools.append(tool) diff --git a/integrations/mcp/tests/mcp_servers_fixtures.py b/integrations/mcp/tests/mcp_servers_fixtures.py index ae988214fa..57cba8a458 100644 --- a/integrations/mcp/tests/mcp_servers_fixtures.py +++ b/integrations/mcp/tests/mcp_servers_fixtures.py @@ -8,15 +8,15 @@ @calculator_mcp.tool() -def add(a: int, b: int) -> int: +def add(a: int, b: int) -> dict: """Add two integers.""" - return a + b + return {"result": a + b} @calculator_mcp.tool() -def subtract(a: int, b: int) -> int: +def subtract(a: int, b: int) -> dict: """Subtract integer b from integer a.""" - return a - b + return {"result": a - b} @calculator_mcp.tool() diff --git a/integrations/mcp/tests/test_mcp_integration.py b/integrations/mcp/tests/test_mcp_integration.py index 484a34109b..3fd5732723 100644 --- a/integrations/mcp/tests/test_mcp_integration.py +++ b/integrations/mcp/tests/test_mcp_integration.py @@ -7,7 +7,9 @@ import time import pytest +import pytest_asyncio from haystack import Pipeline, logging +from haystack.components.agents import Agent from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.tools import ToolInvoker from haystack.dataclasses import ChatMessage, ChatRole @@ -16,12 +18,13 @@ MCPConnectionError, MCPError, MCPTool, + MCPToolset, SSEServerInfo, StdioServerInfo, ) from .mcp_memory_transport import InMemoryServerInfo -from .mcp_servers_fixtures import echo_mcp +from .mcp_servers_fixtures import calculator_mcp, echo_mcp logger = logging.getLogger(__name__) @@ -242,3 +245,108 @@ def test_mcp_tool_error_handling_integration(self): assert any(text in error_message.lower() for text in ["failed", "connection", "initialize"]), ( f"Error message '{error_message}' should contain connection failure information" ) + + +@pytest_asyncio.fixture +async def calculator_toolset_with_state_config(mcp_tool_cleanup): + """Fixture that provides an MCPToolset with state configuration for integration testing. + + Configuration: + - add: No inputs_from_state, writes result to 'sum' state key + - subtract: Reads 'sum' from state (written by add), writes to 'difference' state key + """ + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["add", "subtract"], + connection_timeout=45, + invocation_timeout=60, + eager_connect=True, + inputs_from_state={ + # add tool takes normal parameters (not from state) + # subtract tool reads 'sum' from state and maps to parameter 'a' + "subtract": {"sum": "a"}, + }, + outputs_to_state={ + # Extract from structuredContent.result for both tools + "add": {"sum": {"source": "result"}}, + "subtract": {"difference": {"source": "result"}}, + }, + ) + return mcp_tool_cleanup(toolset) + + +@pytest.mark.integration +class TestMCPToolsetStateConfiguration: + """Integration tests for MCPToolset with state configuration features.""" + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_with_state_config): + """ + Test that outputs_to_state and inputs_from_state work with Agent state management. + + This test verifies the complete state propagation workflow in a single agent run: + 1. Agent calls add tool which writes 'sum' to state via outputs_to_state + 2. Agent calls subtract tool which reads 'sum' from state via inputs_from_state + + Both tools are called in sequence during a single agent execution, demonstrating + how tools communicate through Agent state. + """ + toolset = calculator_toolset_with_state_config + + # Verify state configurations + add_tool = next(tool for tool in toolset.tools if tool.name == "add") + subtract_tool = next(tool for tool in toolset.tools if tool.name == "subtract") + + assert add_tool.inputs_from_state is None # add takes normal parameters + assert add_tool.outputs_to_state == {"sum": {"source": "result"}} # writes sum to state + assert subtract_tool.inputs_from_state == {"sum": "a"} # reads 'sum' from state + assert subtract_tool.outputs_to_state == {"difference": {"source": "result"}} # writes difference to state + + # Create Agent with state schema + agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=toolset.tools, + state_schema={ + "sum": {"type": int}, + "difference": {"type": int}, + }, + ) + + # Create pipeline + pipeline = Pipeline() + pipeline.add_component("agent", agent) + + # Run agent - it will call both tools in sequence during this single execution + # 1. First, add tool calculates 20+5 and writes sum=25 to state + # 2. Then, subtract tool reads sum from state and calculates sum-10 + result = pipeline.run( + { + "agent": { + "messages": [ + ChatMessage.from_user( + "First, use the add tool to calculate 20 + 5. " + "Then use the subtract tool to subtract 10 from the result." + ) + ], + } + } + ) + + # Verify both state values were written by the tools + assert "sum" in result["agent"], "Expected 'sum' to be written to state by add tool's outputs_to_state" + sum_value = result["agent"]["sum"] + assert sum_value == 25, f"Expected sum=25 (20+5), got {sum_value}" + + assert "difference" in result["agent"], ( + "Expected 'difference' to be written to state by subtract tool's outputs_to_state" + ) + difference_value = result["agent"]["difference"] + assert difference_value == 15, f"Expected difference=15 (25-10), got {difference_value}" + + logger.info("✓ State propagation successful in single agent run!") + logger.info(f" - add(20, 5) wrote sum={sum_value} to state via outputs_to_state") + logger.info( + f" - subtract(sum={sum_value} via inputs_from_state, 10) " + f"wrote difference={difference_value} via outputs_to_state" + ) From fa89a100b2f7119a00db46816a927e87fb6be5a1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 15 Jan 2026 15:15:57 +0100 Subject: [PATCH 8/9] Test collision fixes --- .../tools/mcp/mcp_tool.py | 4 +- .../mcp/tests/mcp_servers_fixtures.py | 27 ++++++++-- .../mcp/tests/test_mcp_integration.py | 49 ++++++++----------- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index 4e0f8e663b..87c5bb47bd 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -1075,7 +1075,7 @@ async def invoke(): # Parse JSON to dict only when outputs_to_state is configured. # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. - if hasattr(self, "outputs_to_state") and self.outputs_to_state: + if self.outputs_to_state: parsed = json.loads(result) # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. @@ -1120,7 +1120,7 @@ async def ainvoke(self, **kwargs: Any) -> str | dict[str, Any]: # Parse JSON to dict only when outputs_to_state is configured. # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. - if hasattr(self, "outputs_to_state") and self.outputs_to_state: + if self.outputs_to_state: parsed = json.loads(result) # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. diff --git a/integrations/mcp/tests/mcp_servers_fixtures.py b/integrations/mcp/tests/mcp_servers_fixtures.py index 57cba8a458..d7d54bf2ee 100644 --- a/integrations/mcp/tests/mcp_servers_fixtures.py +++ b/integrations/mcp/tests/mcp_servers_fixtures.py @@ -8,15 +8,15 @@ @calculator_mcp.tool() -def add(a: int, b: int) -> dict: +def add(a: int, b: int) -> int: """Add two integers.""" - return {"result": a + b} + return a + b @calculator_mcp.tool() -def subtract(a: int, b: int) -> dict: +def subtract(a: int, b: int) -> int: """Subtract integer b from integer a.""" - return {"result": a - b} + return a - b @calculator_mcp.tool() @@ -25,6 +25,25 @@ def divide_by_zero(a: int) -> float: return a / 0 +################################################ +# State IO Calculator MCP Server (returns dicts for state propagation) +################################################ + +state_calculator_mcp = FastMCP("StateCalculator") + + +@state_calculator_mcp.tool() +def state_add(a: int, b: int) -> dict: + """Add two integers.""" + return {"result": a + b} + + +@state_calculator_mcp.tool() +def state_subtract(a: int, b: int) -> dict: + """Subtract integer b from integer a.""" + return {"result": a - b} + + ################################################ # Echo MCP Server ################################################ diff --git a/integrations/mcp/tests/test_mcp_integration.py b/integrations/mcp/tests/test_mcp_integration.py index 3fd5732723..bb6ab750a3 100644 --- a/integrations/mcp/tests/test_mcp_integration.py +++ b/integrations/mcp/tests/test_mcp_integration.py @@ -24,7 +24,7 @@ ) from .mcp_memory_transport import InMemoryServerInfo -from .mcp_servers_fixtures import calculator_mcp, echo_mcp +from .mcp_servers_fixtures import echo_mcp, state_calculator_mcp logger = logging.getLogger(__name__) @@ -255,22 +255,22 @@ async def calculator_toolset_with_state_config(mcp_tool_cleanup): - add: No inputs_from_state, writes result to 'sum' state key - subtract: Reads 'sum' from state (written by add), writes to 'difference' state key """ - server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server) toolset = MCPToolset( server_info=server_info, - tool_names=["add", "subtract"], + tool_names=["state_add", "state_subtract"], connection_timeout=45, invocation_timeout=60, eager_connect=True, inputs_from_state={ - # add tool takes normal parameters (not from state) - # subtract tool reads 'sum' from state and maps to parameter 'a' - "subtract": {"sum": "a"}, + # state_add tool takes normal parameters (not from state) + # state_subtract tool reads 'sum' from state and maps to parameter 'a' + "state_subtract": {"sum": "a"}, }, outputs_to_state={ - # Extract from structuredContent.result for both tools - "add": {"sum": {"source": "result"}}, - "subtract": {"difference": {"source": "result"}}, + # Extract from content[].text result for both tools + "state_add": {"sum": {"source": "result"}}, + "state_subtract": {"difference": {"source": "result"}}, }, ) return mcp_tool_cleanup(toolset) @@ -286,8 +286,8 @@ def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_wit Test that outputs_to_state and inputs_from_state work with Agent state management. This test verifies the complete state propagation workflow in a single agent run: - 1. Agent calls add tool which writes 'sum' to state via outputs_to_state - 2. Agent calls subtract tool which reads 'sum' from state via inputs_from_state + 1. Agent calls state_add tool which writes 'sum' to state via outputs_to_state + 2. Agent calls state_subtract tool which reads 'sum' from state via inputs_from_state Both tools are called in sequence during a single agent execution, demonstrating how tools communicate through Agent state. @@ -295,10 +295,10 @@ def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_wit toolset = calculator_toolset_with_state_config # Verify state configurations - add_tool = next(tool for tool in toolset.tools if tool.name == "add") - subtract_tool = next(tool for tool in toolset.tools if tool.name == "subtract") + add_tool = next(tool for tool in toolset.tools if tool.name == "state_add") + subtract_tool = next(tool for tool in toolset.tools if tool.name == "state_subtract") - assert add_tool.inputs_from_state is None # add takes normal parameters + assert add_tool.inputs_from_state is None # state_add takes normal parameters assert add_tool.outputs_to_state == {"sum": {"source": "result"}} # writes sum to state assert subtract_tool.inputs_from_state == {"sum": "a"} # reads 'sum' from state assert subtract_tool.outputs_to_state == {"difference": {"source": "result"}} # writes difference to state @@ -318,15 +318,15 @@ def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_wit pipeline.add_component("agent", agent) # Run agent - it will call both tools in sequence during this single execution - # 1. First, add tool calculates 20+5 and writes sum=25 to state - # 2. Then, subtract tool reads sum from state and calculates sum-10 + # 1. First, state_add tool calculates 20+5 and writes sum=25 to state + # 2. Then, state_subtract tool reads sum from state and calculates sum-10 result = pipeline.run( { "agent": { "messages": [ ChatMessage.from_user( - "First, use the add tool to calculate 20 + 5. " - "Then use the subtract tool to subtract 10 from the result." + "First, use the state_add tool to calculate 20 + 5. " + "Then use the state_subtract tool to subtract 10 from the result." ) ], } @@ -334,19 +334,10 @@ def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_wit ) # Verify both state values were written by the tools - assert "sum" in result["agent"], "Expected 'sum' to be written to state by add tool's outputs_to_state" + assert "sum" in result["agent"], "Expected 'sum' to be written to state by state_add tool" sum_value = result["agent"]["sum"] assert sum_value == 25, f"Expected sum=25 (20+5), got {sum_value}" - assert "difference" in result["agent"], ( - "Expected 'difference' to be written to state by subtract tool's outputs_to_state" - ) + assert "difference" in result["agent"], "Expected 'difference' to be written to state by state_subtract tool" difference_value = result["agent"]["difference"] assert difference_value == 15, f"Expected difference=15 (25-10), got {difference_value}" - - logger.info("✓ State propagation successful in single agent run!") - logger.info(f" - add(20, 5) wrote sum={sum_value} to state via outputs_to_state") - logger.info( - f" - subtract(sum={sum_value} via inputs_from_state, 10) " - f"wrote difference={difference_value} via outputs_to_state" - ) From 4b078f2b080e9a5c9d25e60a959e1a52af4d4fc4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 15 Jan 2026 15:31:43 +0100 Subject: [PATCH 9/9] PR feedback - mpangrazzi --- .../tools/mcp/mcp_toolset.py | 6 ++++- integrations/mcp/tests/test_mcp_toolset.py | 27 +++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 536dc83e7a..3d9479d695 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -29,13 +29,17 @@ logger = logging.getLogger(__name__) -def _serialize_state_config(config: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]] | None: +def _serialize_state_config(config: dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]] | None: """ Serialize a state configuration dictionary, converting any callable handlers to their string representation. Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}}) and outputs_to_string (tool_name -> {source, handler}). + Note: The keys "source" and "handler" are reserved and used internally to distinguish between + outputs_to_string format and outputs_to_state format. Do not use these as state keys in + outputs_to_state configurations. + :param config: The state configuration dictionary to serialize :returns: The serialized configuration dictionary, or None if empty """ diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index 7b1eb004e3..eaeba5ae2c 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -292,7 +292,7 @@ async def test_toolset_with_state_config(self, calculator_toolset_with_state_con assert add_tool.outputs_to_string is not None assert subtract_tool.outputs_to_string is None - async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config): + async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config, mcp_tool_cleanup): """Test serialization and deserialization of MCPToolset with state configuration.""" toolset = calculator_toolset_with_state_config @@ -308,20 +308,19 @@ async def test_toolset_state_config_serde(self, calculator_toolset_with_state_co # Handler should be serialized as a string assert isinstance(toolset_dict["data"]["outputs_to_string"]["add"]["handler"], str) - # Test deserialization - with patch("haystack_integrations.tools.mcp.mcp_toolset.MCPToolset.__init__", return_value=None) as mock_init: - MCPToolset.from_dict(toolset_dict) + # Test deserialization with full roundtrip + new_toolset = MCPToolset.from_dict(toolset_dict) + mcp_tool_cleanup(new_toolset) - mock_init.assert_called_once() - _, kwargs = mock_init.call_args - assert kwargs["inputs_from_state"] == { - "add": {"first_number": "a"}, - "subtract": {"first_number": "a", "second_number": "b"}, - } - assert "add" in kwargs["outputs_to_state"] - assert "add" in kwargs["outputs_to_string"] - # Handler should be deserialized back to a callable - assert callable(kwargs["outputs_to_string"]["add"]["handler"]) + # Verify state configs are correctly deserialized + assert new_toolset.inputs_from_state == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert "add" in new_toolset.outputs_to_state + assert "add" in new_toolset.outputs_to_string + # Handler should be deserialized back to a callable + assert callable(new_toolset.outputs_to_string["add"]["handler"]) async def test_toolset_state_config_unknown_tool_warning(self, caplog): """Test that a warning is logged when state config references unknown tools.