Skip to content

Commit 560de5c

Browse files
committed
Validate full sampling tool result history
1 parent 5d82649 commit 560de5c

3 files changed

Lines changed: 222 additions & 22 deletions

File tree

src/mcp/server/validation.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from mcp.shared.exceptions import MCPError
7-
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, Tool, ToolChoice
7+
from mcp.types import INVALID_PARAMS, ClientCapabilities, SamplingMessage, SamplingMessageContentBlock, Tool, ToolChoice
88

99

1010
def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool:
@@ -52,6 +52,7 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
5252
1. Messages with tool_result content contain ONLY tool_result content
5353
2. tool_result messages are preceded by a message with tool_use
5454
3. tool_result IDs match the tool_use IDs from the previous message
55+
4. Every tool_use message that is followed by another message is answered by matching tool_result content in that next message (a tool_use in the final message is not checked)
5556
5657
See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
5758
@@ -64,24 +65,26 @@ def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None:
6465
if not messages:
6566
return
6667

67-
last_content = messages[-1].content_as_list
68-
has_tool_results = any(c.type == "tool_result" for c in last_content)
69-
70-
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
71-
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
72-
73-
if has_tool_results:
74-
# Per spec: "SamplingMessage with tool result content blocks
75-
# MUST NOT contain other content types."
76-
if any(c.type != "tool_result" for c in last_content):
77-
raise ValueError("The last message must contain only tool_result content if any is present")
78-
if previous_content is None:
79-
raise ValueError("tool_result requires a previous message containing tool_use")
80-
if not has_previous_tool_use:
81-
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
82-
83-
if has_previous_tool_use and previous_content:
84-
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
85-
tool_result_ids = {c.tool_use_id for c in last_content if c.type == "tool_result"}
86-
if tool_use_ids != tool_result_ids:
87-
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
68+
previous_content: list[SamplingMessageContentBlock] | None = None
69+
for content in (message.content_as_list for message in messages):
70+
has_tool_results = any(c.type == "tool_result" for c in content)
71+
previous_tool_use_ids: set[str] = set()
72+
if previous_content is not None:
73+
previous_tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
74+
75+
if has_tool_results:
76+
# Per spec: "SamplingMessage with tool result content blocks
77+
# MUST NOT contain other content types."
78+
if any(c.type != "tool_result" for c in content):
79+
raise ValueError("A message must contain only tool_result content if any is present")
80+
if previous_content is None:
81+
raise ValueError("tool_result requires a previous message containing tool_use")
82+
if not previous_tool_use_ids:
83+
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
84+
85+
if previous_tool_use_ids:
86+
tool_result_ids = {c.tool_use_id for c in content if c.type == "tool_result"}
87+
if previous_tool_use_ids != tool_result_ids:
88+
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
89+
90+
previous_content = content

tests/server/test_session.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,118 @@ async def test_send_request_without_back_channel_or_related_id_fails_fast():
125125
assert dispatcher.requests[0][3] == 3
126126

127127

128+
@pytest.mark.anyio
129+
async def test_create_message_tool_result_validation():
130+
"""Test tool_use/tool_result validation in create_message."""
131+
dispatcher = StubDispatcher(
132+
result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}
133+
)
134+
session = _make_session(
135+
dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability()))
136+
)
137+
tool = types.Tool(name="test_tool", input_schema={"type": "object"})
138+
text = types.TextContent(type="text", text="hello")
139+
tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
140+
tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[])
141+
142+
# Case 1: tool_result mixed with other content
143+
with pytest.raises(ValueError, match="only tool_result content"):
144+
await session.create_message(
145+
messages=[
146+
types.SamplingMessage(role="user", content=text),
147+
types.SamplingMessage(role="assistant", content=tool_use),
148+
types.SamplingMessage(role="user", content=[tool_result, text]),
149+
],
150+
max_tokens=100,
151+
tools=[tool],
152+
)
153+
154+
# Case 2: tool_result without previous message
155+
with pytest.raises(ValueError, match="requires a previous message"):
156+
await session.create_message(
157+
messages=[types.SamplingMessage(role="user", content=tool_result)],
158+
max_tokens=100,
159+
tools=[tool],
160+
)
161+
162+
# Case 3: tool_result without previous tool_use
163+
with pytest.raises(ValueError, match="do not match any tool_use"):
164+
await session.create_message(
165+
messages=[
166+
types.SamplingMessage(role="user", content=text),
167+
types.SamplingMessage(role="user", content=tool_result),
168+
],
169+
max_tokens=100,
170+
tools=[tool],
171+
)
172+
173+
# Case 4: mismatched tool IDs
174+
with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
175+
await session.create_message(
176+
messages=[
177+
types.SamplingMessage(role="user", content=text),
178+
types.SamplingMessage(role="assistant", content=tool_use),
179+
types.SamplingMessage(
180+
role="user",
181+
content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
182+
),
183+
],
184+
max_tokens=100,
185+
tools=[tool],
186+
)
187+
188+
# Case 4b: earlier mismatched tool result with a later plain message
189+
with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"):
190+
await session.create_message(
191+
messages=[
192+
types.SamplingMessage(role="assistant", content=tool_use),
193+
types.SamplingMessage(
194+
role="user",
195+
content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]),
196+
),
197+
types.SamplingMessage(role="assistant", content=text),
198+
],
199+
max_tokens=100,
200+
tools=[tool],
201+
)
202+
203+
# Case 5: text-only message with tools (no tool_results) - passes validation
204+
await session.create_message(
205+
messages=[types.SamplingMessage(role="user", content=text)],
206+
max_tokens=100,
207+
tools=[tool],
208+
)
209+
210+
# Case 6: valid matching tool_result/tool_use IDs - passes validation
211+
await session.create_message(
212+
messages=[
213+
types.SamplingMessage(role="user", content=text),
214+
types.SamplingMessage(role="assistant", content=tool_use),
215+
types.SamplingMessage(role="user", content=tool_result),
216+
],
217+
max_tokens=100,
218+
tools=[tool],
219+
)
220+
221+
# Case 7: validation runs even without `tools` parameter
222+
# (tool loop continuation may omit tools while containing tool_result)
223+
with pytest.raises(ValueError, match="do not match any tool_use"):
224+
await session.create_message(
225+
messages=[
226+
types.SamplingMessage(role="user", content=text),
227+
types.SamplingMessage(role="user", content=tool_result),
228+
],
229+
max_tokens=100,
230+
)
231+
232+
# Case 8: empty messages list - skips validation entirely
233+
no_tools_session = _make_session(
234+
StubDispatcher(result={"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"}),
235+
capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())),
236+
)
237+
await no_tools_session.create_message(messages=[], max_tokens=100)
238+
239+
128240
@pytest.mark.anyio
129241
async def test_send_request_validates_result_alias_only():
130242
"""Peer results validate alias-only; a snake_case key from the wire is

tests/server/test_validation.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_ot
108108
validate_tool_use_result_messages(messages)
109109

110110

111+
def test_validate_tool_use_result_messages_raises_for_earlier_mixed_tool_result() -> None:
112+
"""Raises when an earlier message mixes tool_result with other content."""
113+
messages = [
114+
SamplingMessage(
115+
role="assistant",
116+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
117+
),
118+
SamplingMessage(
119+
role="user",
120+
content=[
121+
ToolResultContent(type="tool_result", tool_use_id="tool-1"),
122+
TextContent(type="text", text="also this"),
123+
],
124+
),
125+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
126+
]
127+
128+
with pytest.raises(ValueError, match="only tool_result content"):
129+
validate_tool_use_result_messages(messages)
130+
131+
111132
def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None:
112133
"""Raises when tool_result appears without preceding tool_use."""
113134
messages = [
@@ -146,6 +167,39 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_matc
146167
validate_tool_use_result_messages(messages)
147168

148169

170+
def test_validate_tool_use_result_messages_raises_when_earlier_tool_result_ids_dont_match_tool_use() -> None:
171+
"""Raises when an earlier tool_result does not match the previous tool_use."""
172+
messages = [
173+
SamplingMessage(
174+
role="assistant",
175+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
176+
),
177+
SamplingMessage(
178+
role="user",
179+
content=ToolResultContent(type="tool_result", tool_use_id="tool-2"),
180+
),
181+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
182+
]
183+
184+
with pytest.raises(ValueError, match="do not match"):
185+
validate_tool_use_result_messages(messages)
186+
187+
188+
def test_validate_tool_use_result_messages_raises_when_tool_use_is_not_answered() -> None:
189+
"""Raises when a tool_use is followed by a non-tool_result message."""
190+
messages = [
191+
SamplingMessage(
192+
role="assistant",
193+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
194+
),
195+
SamplingMessage(role="user", content=TextContent(type="text", text="not a result")),
196+
SamplingMessage(role="assistant", content=TextContent(type="text", text="done")),
197+
]
198+
199+
with pytest.raises(ValueError, match="do not match"):
200+
validate_tool_use_result_messages(messages)
201+
202+
149203
def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None:
150204
"""No error when tool_result IDs match tool_use IDs."""
151205
messages = [
@@ -159,3 +213,34 @@ def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_too
159213
),
160214
]
161215
validate_tool_use_result_messages(messages) # Should not raise
216+
217+
218+
def test_validate_tool_use_result_messages_no_error_for_multiple_tool_pairs() -> None:
219+
"""No error when every tool_use in the history has a matching tool_result."""
220+
messages = [
221+
SamplingMessage(role="user", content=TextContent(type="text", text="first")),
222+
SamplingMessage(
223+
role="assistant",
224+
content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
225+
),
226+
SamplingMessage(
227+
role="user",
228+
content=ToolResultContent(type="tool_result", tool_use_id="tool-1"),
229+
),
230+
SamplingMessage(
231+
role="assistant",
232+
content=[
233+
ToolUseContent(type="tool_use", id="tool-2", name="test", input={}),
234+
ToolUseContent(type="tool_use", id="tool-3", name="test", input={}),
235+
],
236+
),
237+
SamplingMessage(
238+
role="user",
239+
content=[
240+
ToolResultContent(type="tool_result", tool_use_id="tool-3"),
241+
ToolResultContent(type="tool_result", tool_use_id="tool-2"),
242+
],
243+
),
244+
]
245+
246+
validate_tool_use_result_messages(messages)

0 commit comments

Comments
 (0)