forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_sampling_callback.py
More file actions
155 lines (131 loc) · 5.93 KB
/
test_sampling_callback.py
File metadata and controls
155 lines (131 loc) · 5.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import anyio
import pytest
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
SamplingMessage,
TextContent,
)
@pytest.mark.anyio
async def test_sampling_callback():
    """Round-trip a server-side create_message through the client's sampling
    callback, then confirm the call errors cleanly when no callback exists."""
    from mcp.server.fastmcp import FastMCP

    server = FastMCP("test")

    # Canned LLM reply the client-side callback hands back verbatim.
    expected = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="This is a response from the sampling callback"),
        model="test-model",
        stopReason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession, None],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return expected

    @server.tool("test_sampling")
    async def test_sampling_tool(message: str):
        # Server side: request a sample from the connected client session.
        reply = await server.get_context().session.create_message(
            messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
            max_tokens=100,
        )
        assert reply == expected
        return True

    tool_args = {"message": "Test message for sampling"}

    # With a callback registered, the tool succeeds and returns True -> "true".
    async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
        result = await client_session.call_tool("test_sampling", tool_args)
        assert result.isError is False
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "true"

    # Without one, the server's create_message surfaces as a tool error.
    async with create_session(server._mcp_server) as client_session:
        result = await client_session.call_tool("test_sampling", tool_args)
        assert result.isError is True
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
@pytest.mark.anyio
async def test_concurrent_sampling_callback():
    """Test multiple concurrent sampling calls using time-sort verification."""
    from mcp.server.fastmcp import FastMCP

    server = FastMCP("test")

    # Each simulated LLM call appends its delay when it finishes; if the calls
    # truly overlap, shorter delays land first regardless of start order.
    finished: list[float] = []

    async def sampling_callback(
        context: RequestContext[ClientSession, None],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        assert isinstance(params.messages[0].content, TextContent)
        text = params.messages[0].content.text
        if not text.startswith("delay_"):
            return CreateMessageResult(
                role="assistant",
                content=TextContent(type="text", text="Default response"),
                model="test-model",
                stopReason="endTurn",
            )
        # The message encodes the simulated response latency, e.g. "delay_0.3".
        wait = float(text.split("_")[1])
        await anyio.sleep(wait)
        finished.append(wait)
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=f"Response after {wait}s"),
            model="test-model",
            stopReason="endTurn",
        )

    @server.tool("concurrent_sampling_tool")
    async def concurrent_sampling_tool():
        """Tool that makes multiple concurrent sampling calls."""
        replies = {}

        async def sample_with_delay(label: str, wait: float):
            replies[label] = await server.get_context().session.create_message(
                messages=[
                    SamplingMessage(
                        role="user",
                        content=TextContent(type="text", text=f"delay_{wait}"),
                    )
                ],
                max_tokens=100,
            )

        # Launch out of duration order (0.6, 0.2, 0.4); concurrency means they
        # should complete sorted by duration (0.2, 0.4, 0.6).
        async with anyio.create_task_group() as tg:
            tg.start_soon(sample_with_delay, "slow_call", 0.6)  # finishes last
            tg.start_soon(sample_with_delay, "fast_call", 0.2)  # finishes first
            tg.start_soon(sample_with_delay, "medium_call", 0.4)  # finishes middle

        # All three completed once the task group exits; join in start order.
        return " | ".join(
            replies[label].content.text for label in ("slow_call", "fast_call", "medium_call")
        )

    # Drive the tool and verify both the combined output and the finish order.
    async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
        result = await client_session.call_tool("concurrent_sampling_tool", {})
        assert result.isError is False
        assert isinstance(result.content[0], TextContent)

        # Joined in start order, each piece reporting its own delay.
        assert result.content[0].text == "Response after 0.6s | Response after 0.2s | Response after 0.4s"

        # Time-sort check: completion order proves the calls ran concurrently.
        assert len(finished) == 3
        assert finished == [
            0.2,
            0.4,
            0.6,
        ], f"Expected [0.2, 0.4, 0.6] but got {finished}"