diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index f784e82a1d..d23c585c6c 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -1103,6 +1103,30 @@ async def ainvoke(self, **kwargs: Any) -> str: message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}" raise MCPInvocationError(message, self.name, kwargs) from e + def _get_valid_inputs(self) -> set[str]: + """ + Return the set of valid input parameter names from the MCP tool schema. + + Used to validate that `inputs_from_state` only references parameters that actually exist. + Unlike the default implementation that introspects the function signature, + this returns parameters from the MCP tool's JSON schema. + + When eager_connect=False and we have placeholder parameters, returns an empty set + to skip validation until warm_up() is called. + + :returns: Set of valid input parameter names from the MCP tool schema. + """ + # Get parameters from the JSON schema (not from function introspection) + # MCPTool uses _invoke_tool(**kwargs) so introspection would only find 'kwargs' + properties = self.parameters.get("properties", {}) + + # If we have placeholder parameters (eager_connect=False), return empty set to skip validation + # Validation will happen during warm_up when real schema is fetched + if not properties: + return set() + + return set(properties.keys()) + def warm_up(self) -> None: """Connect and fetch the tool schema if eager_connect is turned off.""" with self._lock: @@ -1111,6 +1135,19 @@ def warm_up(self) -> None: tool = self._connect_and_initialize(self.name) self.parameters = tool.inputSchema + # Validate inputs_from_state now that we have the real schema + # Note: Duplicates Tool.__post_init__() logic, but needed here for early error detection + # when eager_connect=False (validation was skipped during __init__ via empty _get_valid_inputs()) + if self._inputs_from_state: + valid_inputs = set(self.parameters.get("properties", {}).keys()) + for state_key, param_name in self._inputs_from_state.items(): + if param_name not in valid_inputs: + msg = ( + f"inputs_from_state maps '{state_key}' to unknown parameter '{param_name}'. " + f"Valid parameters are: {valid_inputs}." + ) + raise ValueError(msg) + # Remove inputs_from_state keys from parameters schema if present # This matches the behavior of ComponentTool if self._inputs_from_state and "properties" in self.parameters: diff --git a/integrations/mcp/tests/test_mcp_tool.py b/integrations/mcp/tests/test_mcp_tool.py index 5786334617..db325be593 100644 --- a/integrations/mcp/tests/test_mcp_tool.py +++ b/integrations/mcp/tests/test_mcp_tool.py @@ -169,12 +169,13 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) # Create tool with state-mapping parameters + # The 'add' tool has parameters 'a' and 'b', so we map to 'a' tool = MCPTool( name="add", server_info=server_info, eager_connect=False, outputs_to_string={"source": "result"}, - inputs_from_state={"filter": "query_filter"}, + inputs_from_state={"state_a": "a"}, outputs_to_state={"result": {"source": "output"}}, ) mcp_tool_cleanup(tool) @@ -184,7 +185,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): # Verify state-mapping parameters are serialized assert tool_dict["data"]["outputs_to_string"] == {"source": "result"} - assert tool_dict["data"]["inputs_from_state"] == {"filter": "query_filter"} + assert tool_dict["data"]["inputs_from_state"] == {"state_a": "a"} assert tool_dict["data"]["outputs_to_state"] == {"result": {"source": "output"}} # Test deserialization (from_dict) @@ -193,7 +194,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): # Verify state-mapping parameters are restored assert new_tool._outputs_to_string == {"source": "result"} - assert new_tool._inputs_from_state == {"filter": "query_filter"} + assert new_tool._inputs_from_state == {"state_a": "a"} assert new_tool._outputs_to_state == {"result": {"source": "output"}} @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")