Skip to content

Commit 454d9ee

Browse files
authored
Proper tool validation in mcp (#2654)
1 parent ecb6203 commit 454d9ee

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,30 @@ async def ainvoke(self, **kwargs: Any) -> str:
11031103
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
11041104
raise MCPInvocationError(message, self.name, kwargs) from e
11051105

1106+
def _get_valid_inputs(self) -> set[str]:
1107+
"""
1108+
Return the set of valid input parameter names from the MCP tool schema.
1109+
1110+
Used to validate that `inputs_from_state` only references parameters that actually exist.
1111+
Unlike the default implementation that introspects the function signature,
1112+
this returns parameters from the MCP tool's JSON schema.
1113+
1114+
When eager_connect=False and we have placeholder parameters, returns an empty set
1115+
to skip validation until warm_up() is called.
1116+
1117+
:returns: Set of valid input parameter names from the MCP tool schema.
1118+
"""
1119+
# Get parameters from the JSON schema (not from function introspection)
1120+
# MCPTool uses _invoke_tool(**kwargs) so introspection would only find 'kwargs'
1121+
properties = self.parameters.get("properties", {})
1122+
1123+
# If we have placeholder parameters (eager_connect=False), return empty set to skip validation
1124+
# Validation will happen during warm_up when real schema is fetched
1125+
if not properties:
1126+
return set()
1127+
1128+
return set(properties.keys())
1129+
11061130
def warm_up(self) -> None:
11071131
"""Connect and fetch the tool schema if eager_connect is turned off."""
11081132
with self._lock:
@@ -1111,6 +1135,19 @@ def warm_up(self) -> None:
11111135
tool = self._connect_and_initialize(self.name)
11121136
self.parameters = tool.inputSchema
11131137

1138+
# Validate inputs_from_state now that we have the real schema
1139+
# Note: Duplicates Tool.__post_init__() logic, but needed here for early error detection
1140+
# when eager_connect=False (validation was skipped during __init__ via empty _get_valid_inputs())
1141+
if self._inputs_from_state:
1142+
valid_inputs = set(self.parameters.get("properties", {}).keys())
1143+
for state_key, param_name in self._inputs_from_state.items():
1144+
if param_name not in valid_inputs:
1145+
msg = (
1146+
f"inputs_from_state maps '{state_key}' to unknown parameter '{param_name}'. "
1147+
f"Valid parameters are: {valid_inputs}."
1148+
)
1149+
raise ValueError(msg)
1150+
11141151
# Remove inputs_from_state keys from parameters schema if present
11151152
# This matches the behavior of ComponentTool
11161153
if self._inputs_from_state and "properties" in self.parameters:

integrations/mcp/tests/test_mcp_tool.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,13 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
169169
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
170170

171171
# Create tool with state-mapping parameters
172+
# The 'add' tool has parameters 'a' and 'b', so we map to 'a'
172173
tool = MCPTool(
173174
name="add",
174175
server_info=server_info,
175176
eager_connect=False,
176177
outputs_to_string={"source": "result"},
177-
inputs_from_state={"filter": "query_filter"},
178+
inputs_from_state={"state_a": "a"},
178179
outputs_to_state={"result": {"source": "output"}},
179180
)
180181
mcp_tool_cleanup(tool)
@@ -184,7 +185,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
184185

185186
# Verify state-mapping parameters are serialized
186187
assert tool_dict["data"]["outputs_to_string"] == {"source": "result"}
187-
assert tool_dict["data"]["inputs_from_state"] == {"filter": "query_filter"}
188+
assert tool_dict["data"]["inputs_from_state"] == {"state_a": "a"}
188189
assert tool_dict["data"]["outputs_to_state"] == {"result": {"source": "output"}}
189190

190191
# Test deserialization (from_dict)
@@ -193,7 +194,7 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
193194

194195
# Verify state-mapping parameters are restored
195196
assert new_tool._outputs_to_string == {"source": "result"}
196-
assert new_tool._inputs_from_state == {"filter": "query_filter"}
197+
assert new_tool._inputs_from_state == {"state_a": "a"}
197198
assert new_tool._outputs_to_state == {"result": {"source": "output"}}
198199

199200
@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")

0 commit comments

Comments
 (0)