diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index d23c585c6c..87c5bb47bd 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import json import threading import warnings from abc import ABC, abstractmethod @@ -1048,12 +1049,13 @@ def _connect_and_initialize(self, tool_name: str) -> types.Tool: return tool - def _invoke_tool(self, **kwargs: Any) -> str: + def _invoke_tool(self, **kwargs: Any) -> str | dict[str, Any]: """ Synchronous tool invocation. :param kwargs: Arguments to pass to the tool - :returns: JSON string representation of the tool invocation result + :returns: JSON string or dictionary representation of the tool invocation result. + With outputs_to_state set: first text block parsed as JSON, else its raw text, else the full parsed response. """ logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}") try: @@ -1070,6 +1072,26 @@ async def invoke(): logger.debug(f"TOOL: About to run invoke for '{self.name}'") result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout) logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}") + + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if self.outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). 
+ content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + return result except (MCPError, TimeoutError) as e: logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}") @@ -1081,19 +1103,41 @@ async def invoke(): message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}" raise MCPInvocationError(message, self.name, kwargs) from e - async def ainvoke(self, **kwargs: Any) -> str: + async def ainvoke(self, **kwargs: Any) -> str | dict[str, Any]: """ Asynchronous tool invocation. :param kwargs: Arguments to pass to the tool - :returns: JSON string representation of the tool invocation result + :returns: JSON string or dictionary representation of the tool invocation result. + With outputs_to_state set: first text block parsed as JSON, else its raw text, else the full parsed response. :raises MCPInvocationError: If the tool invocation fails :raises TimeoutError: If the operation times out """ try: self.warm_up() client = cast(MCPClient, self._client) - return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if self.outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). 
+ content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + + return result except asyncio.TimeoutError as e: message = f"Tool invocation timed out after {self._invocation_timeout} seconds" raise TimeoutError(message) from e diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 529e7e487c..3d9479d695 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json from collections.abc import Callable from typing import Any, cast from urllib.parse import urlparse @@ -11,6 +12,7 @@ from haystack import logging from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.tools import Tool, Toolset +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from .mcp_tool import ( AsyncExecutor, @@ -27,6 +29,88 @@ logger = logging.getLogger(__name__) +def _serialize_state_config(config: dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]] | None: + """ + Serialize a state configuration dictionary, converting any callable handlers to their string representation. + + Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}}) + and outputs_to_string (tool_name -> {source, handler}). + + Note: The keys "source" and "handler" are reserved and used internally to distinguish between + outputs_to_string format and outputs_to_state format. Do not use these as state keys in + outputs_to_state configurations. 
+ + :param config: The state configuration dictionary to serialize + :returns: The serialized configuration dictionary, or None if empty + """ + if not config: + return None + + serialized = {} + for tool_name, tool_config in config.items(): + if not tool_config: + continue + + # Check if this is outputs_to_string format (flat with optional source/handler) + # or outputs_to_state format (nested with state keys) + if "source" in tool_config or "handler" in tool_config: + # outputs_to_string format: {source?, handler?} + serialized_tool_config = tool_config.copy() + if "handler" in tool_config and callable(tool_config["handler"]): + serialized_tool_config["handler"] = serialize_callable(tool_config["handler"]) + serialized[tool_name] = serialized_tool_config + else: + # outputs_to_state format: {state_key -> {source?, handler?}} + serialized_tool_config = {} + for state_key, state_config in tool_config.items(): + serialized_state_config = state_config.copy() + if "handler" in state_config and callable(state_config["handler"]): + serialized_state_config["handler"] = serialize_callable(state_config["handler"]) + serialized_tool_config[state_key] = serialized_state_config + serialized[tool_name] = serialized_tool_config + + return serialized if serialized else None + + +def _deserialize_state_config(config: dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]]: + """ + Deserialize a state configuration dictionary, converting any serialized handlers back to callables. + + Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}}) + and outputs_to_string (tool_name -> {source, handler}). 
+ + :param config: The state configuration dictionary to deserialize + :returns: The deserialized configuration dictionary + """ + if not config: + return {} + + deserialized = {} + for tool_name, tool_config in config.items(): + if not tool_config: + continue + + # Check if this is outputs_to_string format (flat with optional source/handler) + # or outputs_to_state format (nested with state keys) + if "source" in tool_config or "handler" in tool_config: + # outputs_to_string format: {source?, handler?} + deserialized_tool_config = tool_config.copy() + if "handler" in tool_config and isinstance(tool_config["handler"], str): + deserialized_tool_config["handler"] = deserialize_callable(tool_config["handler"]) + deserialized[tool_name] = deserialized_tool_config + else: + # outputs_to_state format: {state_key -> {source?, handler?}} + deserialized_tool_config = {} + for state_key, state_config in tool_config.items(): + deserialized_state_config = state_config.copy() + if "handler" in state_config and isinstance(state_config["handler"], str): + deserialized_state_config["handler"] = deserialize_callable(state_config["handler"]) + deserialized_tool_config[state_key] = deserialized_state_config + deserialized[tool_name] = deserialized_tool_config + + return deserialized + + class MCPToolset(Toolset): """ A Toolset that connects to an MCP (Model Context Protocol) server and provides @@ -99,6 +183,30 @@ class MCPToolset(Toolset): # Use the toolset as shown in the pipeline example above ``` + Example with state configuration for Agent integration: + ```python + from haystack_integrations.tools.mcp import MCPToolset, StdioServerInfo + + # Create the toolset with per-tool state configuration + # This enables tools to read from and write to the Agent's State + toolset = MCPToolset( + server_info=StdioServerInfo(command="uvx", args=["mcp-server-git"]), + tool_names=["git_status", "git_diff", "git_log"], + + # Maps the state key "repository" to the tool parameter "repo_path" for 
each tool + inputs_from_state={ + "git_status": {"repository": "repo_path"}, + "git_diff": {"repository": "repo_path"}, + "git_log": {"repository": "repo_path"}, + }, + # Map tool outputs to state keys for each tool + outputs_to_state={ + "git_status": {"status_result": {"source": "status"}}, # Extract "status" from output + "git_diff": {"diff_result": {}}, # use full output with default handling + }, + ) + ``` + Example using SSE (deprecated): ```python from haystack_integrations.tools.mcp import MCPToolset, SSEServerInfo @@ -121,6 +229,9 @@ def __init__( connection_timeout: float = 30.0, invocation_timeout: float = 30.0, eager_connect: bool = False, + inputs_from_state: dict[str, dict[str, str]] | None = None, + outputs_to_state: dict[str, dict[str, dict[str, Any]]] | None = None, + outputs_to_string: dict[str, dict[str, Any]] | None = None, ): """ Initialize the MCP toolset. @@ -132,7 +243,25 @@ def __init__( :param invocation_timeout: Default timeout in seconds for tool invocations :param eager_connect: If True, connect to server and load tools during initialization. If False (default), defer connection to warm_up. + :param inputs_from_state: Optional dictionary mapping tool names to their inputs_from_state config. + Each config maps state keys to tool parameter names. + Tool names should match available tools from the server; a warning is logged for + unknown tools. Note: With Haystack >= 2.22.0, parameter names are validated; + ValueError is raised for invalid parameters. With earlier versions, invalid + parameters fail at runtime. + Example: `{"git_status": {"repository": "repo_path"}}` + :param outputs_to_state: Optional dictionary mapping tool names to their outputs_to_state config. + Each config defines how tool outputs map to state keys with optional handlers. + Tool names should match available tools from the server; a warning is logged for + unknown tools. 
+ Example: `{"git_status": {"status_result": {"source": "status"}}}` + :param outputs_to_string: Optional dictionary mapping tool names to their outputs_to_string config. + Each config defines how tool outputs are converted to strings. + Tool names should match available tools from the server; a warning is logged for + unknown tools. + Example: `{"git_diff": {"source": "diff", "handler": format_diff}}` :raises MCPToolNotFoundError: If any of the specified tool names are not found on the server + :raises ValueError: If parameter names in inputs_from_state are invalid (Haystack >= 2.22.0 only) """ # Store configuration self.server_info = server_info @@ -140,6 +269,9 @@ def __init__( self.connection_timeout = connection_timeout self.invocation_timeout = invocation_timeout self.eager_connect = eager_connect + self.inputs_from_state = inputs_from_state or {} + self.outputs_to_state = outputs_to_state or {} + self.outputs_to_string = outputs_to_string or {} self._warmup_called = False if not eager_connect: @@ -199,14 +331,34 @@ def create_invoke_tool( mcp_client: MCPClient, tool_name: str, tool_timeout: float, + outputs_to_state: dict[str, Any] | None = None, ) -> Callable[..., Any]: """Return a closure that keeps a strong reference to *owner_toolset* alive.""" def invoke_tool(**kwargs: Any) -> Any: _ = owner_toolset # strong reference so GC can't collect the toolset too early - return AsyncExecutor.get_instance().run( + result = AsyncExecutor.get_instance().run( mcp_client.call_tool(tool_name, kwargs), timeout=tool_timeout ) + # Parse JSON to dict only when outputs_to_state is configured. + # ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise. + if outputs_to_state: + parsed = json.loads(result) + + # Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc. + # Parse only first TextContent block (ToolInvoker requires dict, not list). 
+ content = parsed.get("content", []) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError): + return text + + # No TextContent found, return full parsed response as fallback + return parsed + return result return invoke_tool @@ -216,27 +368,37 @@ def invoke_tool(**kwargs: Any) -> Any: # Skip tools not in the tool_names list if tool_names is provided if self.tool_names is not None and tool_info.name not in self.tool_names: logger.debug( - "Skipping tool '{name}' as it's not in the requested tool_names list", name=tool_info.name + "Skipping tool '{tool_name}' as it's not in the requested tool_names list", + tool_name=tool_info.name, ) continue # Use the helper function to create the invoke_tool function + tool_outputs_to_state = self.outputs_to_state.get(tool_info.name) tool = Tool( name=tool_info.name, description=tool_info.description or "", parameters=tool_info.inputSchema, - function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout), + function=create_invoke_tool( + self, client, tool_info.name, self.invocation_timeout, tool_outputs_to_state + ), + inputs_from_state=self.inputs_from_state.get(tool_info.name), + outputs_to_state=tool_outputs_to_state, + outputs_to_string=self.outputs_to_string.get(tool_info.name), ) haystack_tools.append(tool) + # Validate state configs reference known tools + self._validate_state_configs({tool.name for tool in haystack_tools}) + return haystack_tools except Exception as e: # We need to close because we could connect properly, retrieve tools yet - # fail because of an MCPToolNotFoundError + # fail because of validation errors self.close() - if isinstance(e, MCPToolNotFoundError): - raise # re-raise MCPToolNotFoundError as is to show original message + if isinstance(e, (MCPToolNotFoundError, ValueError)): + raise # re-raise validation errors as is to show original 
message # Create informative error message for SSE connection errors # Common error handling for HTTP-based transports @@ -292,6 +454,31 @@ def invoke_tool(**kwargs: Any) -> Any: raise MCPConnectionError(message=message, server_info=self.server_info, operation="initialize") from e + def _validate_state_configs(self, available_tool_names: set[str]) -> None: + """ + Validate that state configuration tool names exist in the toolset. + + Logs a warning for any tool names in the state configs that don't match + available tools in the toolset. + + :param available_tool_names: Set of tool names that are available in the toolset + """ + configs: list[tuple[str, dict[str, Any]]] = [ + ("inputs_from_state", self.inputs_from_state), + ("outputs_to_state", self.outputs_to_state), + ("outputs_to_string", self.outputs_to_string), + ] + for config_name, config in configs: + if config: + unknown_tools = set(config.keys()) - available_tool_names + if unknown_tools: + logger.warning( + "{config_name} references unknown tools: {unknown_tools}. Available tools: {available_tools}", + config_name=config_name, + unknown_tools=unknown_tools, + available_tools=available_tool_names, + ) + def to_dict(self) -> dict[str, Any]: """ Serialize the MCPToolset to a dictionary. 
@@ -306,6 +493,9 @@ def to_dict(self) -> dict[str, Any]: "connection_timeout": self.connection_timeout, "invocation_timeout": self.invocation_timeout, "eager_connect": self.eager_connect, + "inputs_from_state": self.inputs_from_state if self.inputs_from_state else None, + "outputs_to_state": _serialize_state_config(self.outputs_to_state), + "outputs_to_string": _serialize_state_config(self.outputs_to_string), }, } @@ -324,13 +514,21 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset": server_info_class = import_class_by_name(server_info_dict["type"]) server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict) + # Deserialize state configuration parameters + inputs_from_state = inner_data.get("inputs_from_state") + outputs_to_state = _deserialize_state_config(inner_data.get("outputs_to_state")) + outputs_to_string = _deserialize_state_config(inner_data.get("outputs_to_string")) + # Create a new MCPToolset instance return cls( server_info=server_info, tool_names=inner_data.get("tool_names"), connection_timeout=inner_data.get("connection_timeout", 30.0), invocation_timeout=inner_data.get("invocation_timeout", 30.0), - eager_connect=inner_data.get("eager_connect", True), + eager_connect=inner_data.get("eager_connect", False), + inputs_from_state=inputs_from_state if inputs_from_state else None, + outputs_to_state=outputs_to_state if outputs_to_state else None, + outputs_to_string=outputs_to_string if outputs_to_string else None, ) def close(self): diff --git a/integrations/mcp/tests/mcp_servers_fixtures.py b/integrations/mcp/tests/mcp_servers_fixtures.py index ae988214fa..d7d54bf2ee 100644 --- a/integrations/mcp/tests/mcp_servers_fixtures.py +++ b/integrations/mcp/tests/mcp_servers_fixtures.py @@ -25,6 +25,25 @@ def divide_by_zero(a: int) -> float: return a / 0 +################################################ +# State IO Calculator MCP Server (returns dicts for state propagation) +################################################ + 
+state_calculator_mcp = FastMCP("StateCalculator") + + +@state_calculator_mcp.tool() +def state_add(a: int, b: int) -> dict: + """Add two integers.""" + return {"result": a + b} + + +@state_calculator_mcp.tool() +def state_subtract(a: int, b: int) -> dict: + """Subtract integer b from integer a.""" + return {"result": a - b} + + ################################################ # Echo MCP Server ################################################ diff --git a/integrations/mcp/tests/test_mcp_integration.py b/integrations/mcp/tests/test_mcp_integration.py index 484a34109b..bb6ab750a3 100644 --- a/integrations/mcp/tests/test_mcp_integration.py +++ b/integrations/mcp/tests/test_mcp_integration.py @@ -7,7 +7,9 @@ import time import pytest +import pytest_asyncio from haystack import Pipeline, logging +from haystack.components.agents import Agent from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.tools import ToolInvoker from haystack.dataclasses import ChatMessage, ChatRole @@ -16,12 +18,13 @@ MCPConnectionError, MCPError, MCPTool, + MCPToolset, SSEServerInfo, StdioServerInfo, ) from .mcp_memory_transport import InMemoryServerInfo -from .mcp_servers_fixtures import echo_mcp +from .mcp_servers_fixtures import echo_mcp, state_calculator_mcp logger = logging.getLogger(__name__) @@ -242,3 +245,99 @@ def test_mcp_tool_error_handling_integration(self): assert any(text in error_message.lower() for text in ["failed", "connection", "initialize"]), ( f"Error message '{error_message}' should contain connection failure information" ) + + +@pytest_asyncio.fixture +async def calculator_toolset_with_state_config(mcp_tool_cleanup): + """Fixture that provides an MCPToolset with state configuration for integration testing. 
+ + Configuration: + - state_add: No inputs_from_state, writes result to 'sum' state key + - state_subtract: Reads 'sum' from state (written by state_add), writes to 'difference' state key + """ + server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["state_add", "state_subtract"], + connection_timeout=45, + invocation_timeout=60, + eager_connect=True, + inputs_from_state={ + # state_add tool takes normal parameters (not from state) + # state_subtract tool reads 'sum' from state and maps to parameter 'a' + "state_subtract": {"sum": "a"}, + }, + outputs_to_state={ + # Extract from content[].text result for both tools + "state_add": {"sum": {"source": "result"}}, + "state_subtract": {"difference": {"source": "result"}}, + }, + ) + return mcp_tool_cleanup(toolset) + + +@pytest.mark.integration +class TestMCPToolsetStateConfiguration: + """Integration tests for MCPToolset with state configuration features.""" + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + def test_toolset_with_multiple_tools_state_chaining(self, calculator_toolset_with_state_config): + """ + Test that outputs_to_state and inputs_from_state work with Agent state management. + + This test verifies the complete state propagation workflow in a single agent run: + 1. Agent calls state_add tool which writes 'sum' to state via outputs_to_state + 2. Agent calls state_subtract tool which reads 'sum' from state via inputs_from_state + + Both tools are called in sequence during a single agent execution, demonstrating + how tools communicate through Agent state. 
+ """ + toolset = calculator_toolset_with_state_config + + # Verify state configurations + add_tool = next(tool for tool in toolset.tools if tool.name == "state_add") + subtract_tool = next(tool for tool in toolset.tools if tool.name == "state_subtract") + + assert add_tool.inputs_from_state is None # state_add takes normal parameters + assert add_tool.outputs_to_state == {"sum": {"source": "result"}} # writes sum to state + assert subtract_tool.inputs_from_state == {"sum": "a"} # reads 'sum' from state + assert subtract_tool.outputs_to_state == {"difference": {"source": "result"}} # writes difference to state + + # Create Agent with state schema + agent = Agent( + chat_generator=OpenAIChatGenerator(model="gpt-4.1"), + tools=toolset.tools, + state_schema={ + "sum": {"type": int}, + "difference": {"type": int}, + }, + ) + + # Create pipeline + pipeline = Pipeline() + pipeline.add_component("agent", agent) + + # Run agent - it will call both tools in sequence during this single execution + # 1. First, state_add tool calculates 20+5 and writes sum=25 to state + # 2. Then, state_subtract tool reads sum from state and calculates sum-10 + result = pipeline.run( + { + "agent": { + "messages": [ + ChatMessage.from_user( + "First, use the state_add tool to calculate 20 + 5. " + "Then use the state_subtract tool to subtract 10 from the result." 
+ ) + ], + } + } + ) + + # Verify both state values were written by the tools + assert "sum" in result["agent"], "Expected 'sum' to be written to state by state_add tool" + sum_value = result["agent"]["sum"] + assert sum_value == 25, f"Expected sum=25 (20+5), got {sum_value}" + + assert "difference" in result["agent"], "Expected 'difference' to be written to state by state_subtract tool" + difference_value = result["agent"]["difference"] + assert difference_value == 15, f"Expected difference=15 (25-10), got {difference_value}" diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index 625df58a71..eaeba5ae2c 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -7,7 +7,6 @@ import time from unittest.mock import patch -import haystack import pytest import pytest_asyncio from haystack import logging @@ -24,6 +23,10 @@ SSEServerInfo, StreamableHttpServerInfo, ) +from haystack_integrations.tools.mcp.mcp_toolset import ( + _deserialize_state_config, + _serialize_state_config, +) # Import in-memory transport and fixtures from .mcp_memory_transport import InMemoryServerInfo @@ -78,6 +81,38 @@ async def calculator_toolset_with_tool_filter(mcp_tool_cleanup): return mcp_tool_cleanup(toolset) +def format_result(result): + """Sample handler function for testing.""" + return f"FORMATTED: {result}" + + +@pytest_asyncio.fixture +async def calculator_toolset_with_state_config(mcp_tool_cleanup): + """Fixture that provides an MCPToolset with state configuration.""" + + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["add", "subtract"], + connection_timeout=45, + invocation_timeout=60, + eager_connect=True, + inputs_from_state={ + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + }, + outputs_to_state={ + "add": {"sum_result": {"source": "content"}}, + 
"subtract": {"diff_result": {}}, + }, + outputs_to_string={ + "add": {"source": "content", "handler": format_result}, + }, + ) + + return mcp_tool_cleanup(toolset) + + @pytest.mark.asyncio class TestMCPToolset: """Tests for the MCPToolset class.""" @@ -233,6 +268,120 @@ async def test_toolset_tool_not_found(self): eager_connect=True, ) + async def test_toolset_with_state_config(self, calculator_toolset_with_state_config): + """Test that MCPToolset correctly passes state configuration to tools.""" + toolset = calculator_toolset_with_state_config + + # Verify toolset has state configs stored + assert toolset.inputs_from_state == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert "add" in toolset.outputs_to_state + assert "subtract" in toolset.outputs_to_state + assert "add" in toolset.outputs_to_string + + # Verify tools have correct state configurations + add_tool = next(tool for tool in toolset.tools if tool.name == "add") + subtract_tool = next(tool for tool in toolset.tools if tool.name == "subtract") + + assert add_tool.inputs_from_state == {"first_number": "a"} + assert subtract_tool.inputs_from_state == {"first_number": "a", "second_number": "b"} + assert add_tool.outputs_to_state == {"sum_result": {"source": "content"}} + assert subtract_tool.outputs_to_state == {"diff_result": {}} + assert add_tool.outputs_to_string is not None + assert subtract_tool.outputs_to_string is None + + async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config, mcp_tool_cleanup): + """Test serialization and deserialization of MCPToolset with state configuration.""" + toolset = calculator_toolset_with_state_config + + toolset_dict = toolset.to_dict() + + # Verify state configs are serialized + assert toolset_dict["data"]["inputs_from_state"] == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert toolset_dict["data"]["outputs_to_state"] is not None + 
assert toolset_dict["data"]["outputs_to_string"] is not None + # Handler should be serialized as a string + assert isinstance(toolset_dict["data"]["outputs_to_string"]["add"]["handler"], str) + + # Test deserialization with full roundtrip + new_toolset = MCPToolset.from_dict(toolset_dict) + mcp_tool_cleanup(new_toolset) + + # Verify state configs are correctly deserialized + assert new_toolset.inputs_from_state == { + "add": {"first_number": "a"}, + "subtract": {"first_number": "a", "second_number": "b"}, + } + assert "add" in new_toolset.outputs_to_state + assert "add" in new_toolset.outputs_to_string + # Handler should be deserialized back to a callable + assert callable(new_toolset.outputs_to_string["add"]["handler"]) + + async def test_toolset_state_config_unknown_tool_warning(self, caplog): + """Test that a warning is logged when state config references unknown tools. + + Note: This test validates unknown tool names at the MCPToolset level. + For parameter validation (unknown parameter names), see test_toolset_state_config_invalid_parameter_raises_error + which requires Haystack >= 2.22.0. 
+ """ + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + + with caplog.at_level("WARNING"): + toolset = MCPToolset( + server_info=server_info, + tool_names=["add"], # Only include add + connection_timeout=10, + invocation_timeout=10, + eager_connect=True, + inputs_from_state={ + "add": {"first_number": "a"}, + "unknown_tool": {"some_key": "some_param"}, # This tool doesn't exist + }, + ) + + # The warning should be logged + assert any("unknown_tool" in record.message for record in caplog.records) + toolset.close() + + @pytest.mark.skipif( + not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"), + reason="Requires Haystack >= 2.22.0 for inputs_from_state validation", + ) + async def test_toolset_state_config_invalid_parameter_raises_error(self): + """Test that ValueError is raised when inputs_from_state references non-existent parameter. + + Requires Haystack >= 2.22.0 which validates inputs_from_state parameter names. + With Haystack < 2.22.0, this test is skipped and invalid parameter mappings will + only fail at runtime when the tool is invoked. 
+ """ + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + + with pytest.raises(ValueError, match="unknown parameter"): + MCPToolset( + server_info=server_info, + tool_names=["add"], + connection_timeout=10, + invocation_timeout=10, + eager_connect=True, + inputs_from_state={ + "add": {"state_key": "non_existent_param"}, # 'add' tool has 'a' and 'b' parameters + }, + ) + + async def test_toolset_no_state_config(self, calculator_toolset): + """Test that tools have no state config when none is provided.""" + toolset = calculator_toolset + + for tool in toolset.tools: + assert tool.inputs_from_state is None + assert tool.outputs_to_state is None + assert tool.outputs_to_string is None + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration async def test_pipeline_warmup_with_mcp_toolset(self): @@ -464,19 +613,18 @@ def subtract(a: int, b: int) -> int: if os.path.exists(server_script_path): os.remove(server_script_path) - def test_pipeline_deserialization_fails_without_github_token(self, monkeypatch): + def test_pipeline_deserialization_succeeds_with_lazy_connection(self, monkeypatch): """ - Test that pipeline deserialization + MCPToolset initialization fails when GitHub - token is not resolved during deserialization. + Test that pipeline deserialization succeeds with lazy connection (eager_connect=False). - The issue: - - Setup: Agent pipeline template with MCPToolset with a token from env var (PERSONAL_ACCESS_TOKEN_GITHUB) - - MCPToolset tries to connect immediately during __init__ after validation - - Secrets get resolved during validation, after MCPToolset is initialized - - Connection fails because token can't be resolved in __init__ - - Pipeline deserialization fails with DeserializationError + With lazy connection (the default), MCPToolset defers connection until warm_up() is called. 
+ This allows pipelines to be deserialized even when the server is not available or + credentials are not yet resolved. - This test demonstrates why we need warmup for MCPToolset on first use rather than during deserialization. + This test demonstrates that: + - Pipeline deserialization succeeds even with an invalid token + - MCPToolset creates a placeholder tool during initialization + - Actual connection happens later during warm_up() """ pipeline_yaml = """ components: @@ -528,7 +676,103 @@ def test_pipeline_deserialization_fails_without_github_token(self, monkeypatch): connections: [] """ monkeypatch.setenv("PERSONAL_ACCESS_TOKEN_GITHUB", "SOME_OBVIOUSLY_INVALID_TOKEN") - # Attempt to deserialize the pipeline - this will fail because MCPToolset - # tries to connect immediately and the token isn't available - with pytest.raises(haystack.core.errors.DeserializationError): - Pipeline.loads(pipeline_yaml) + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + + # Deserialization should succeed because eager_connect defaults to False + # With lazy connection, MCPToolset creates a placeholder tool and doesn't try to connect + pipeline = Pipeline.loads(pipeline_yaml) + + # Verify the pipeline was created successfully + assert pipeline is not None + agent = pipeline.get_component("agent") + assert agent is not None + + # The key point is that deserialization succeeded even with an invalid token + # because the connection is deferred until warm_up() is called + + +class TestStateConfigHelpers: + """Tests for the state configuration serialization helper functions.""" + + def test_serialize_outputs_to_string_with_handler(self): + """Test serializing outputs_to_string config with a handler function.""" + config = { + "add": {"source": "content", "handler": format_result}, + "subtract": {"source": "diff"}, + } + + serialized = _serialize_state_config(config) + + assert serialized is not None + assert "add" in serialized + assert "subtract" in serialized + 
assert isinstance(serialized["add"]["handler"], str) # Handler serialized to string + assert serialized["subtract"]["source"] == "diff" + assert "handler" not in serialized["subtract"] # No handler for subtract + + def test_serialize_outputs_to_state_with_handler(self): + """Test serializing outputs_to_state config with a handler function.""" + config = { + "add": { + "sum_result": {"source": "content", "handler": format_result}, + "raw_result": {}, + }, + } + + serialized = _serialize_state_config(config) + + assert serialized is not None + assert "add" in serialized + assert isinstance(serialized["add"]["sum_result"]["handler"], str) + assert serialized["add"]["raw_result"] == {} + + def test_serialize_empty_config(self): + """Test that empty config returns None.""" + assert _serialize_state_config({}) is None + assert _serialize_state_config(None) is None + + def test_deserialize_outputs_to_string_with_handler(self): + """Test deserializing outputs_to_string config with a handler function.""" + # First serialize to get the correct handler path + original = {"add": {"source": "content", "handler": format_result}} + serialized = _serialize_state_config(original) + + # Now deserialize + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert callable(deserialized["add"]["handler"]) + assert deserialized["add"]["source"] == "content" + + def test_deserialize_outputs_to_state_with_handler(self): + """Test deserializing outputs_to_state config with a handler function.""" + # First serialize to get the correct handler path + original = {"add": {"sum_result": {"source": "content", "handler": format_result}}} + serialized = _serialize_state_config(original) + + # Now deserialize + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert callable(deserialized["add"]["sum_result"]["handler"]) + + def test_deserialize_empty_config(self): + """Test that empty config returns empty dict.""" + assert 
_deserialize_state_config({}) == {} + assert _deserialize_state_config(None) == {} + + def test_roundtrip_serialization(self): + """Test that serialization and deserialization are inverse operations.""" + original = { + "add": {"source": "content", "handler": format_result}, + "subtract": {"source": "diff"}, + } + + serialized = _serialize_state_config(original) + deserialized = _deserialize_state_config(serialized) + + assert "add" in deserialized + assert "subtract" in deserialized + assert deserialized["add"]["source"] == "content" + assert callable(deserialized["add"]["handler"]) + assert deserialized["subtract"]["source"] == "diff"