|
| 1 | +"""Concurrency over a single client session: multiple requests in flight at once, in both directions.""" |
| 2 | + |
| 3 | +import anyio |
| 4 | +import pytest |
| 5 | +from inline_snapshot import snapshot |
| 6 | + |
| 7 | +from mcp import Client |
| 8 | +from mcp.client import ClientRequestContext |
| 9 | +from mcp.server.mcpserver import Context, MCPServer |
| 10 | +from mcp.types import ( |
| 11 | + CallToolResult, |
| 12 | + CreateMessageRequestParams, |
| 13 | + CreateMessageResult, |
| 14 | + SamplingMessage, |
| 15 | + TextContent, |
| 16 | +) |
| 17 | + |
# Module-level pytest marker: applies the `anyio` mark to every test in this file, which is how
# the `async def` tests below get run (presumably via the anyio pytest plugin -- confirm config).
pytestmark = pytest.mark.anyio
| 19 | + |
| 20 | + |
async def test_concurrent_tool_calls_resolve_out_of_order_to_their_own_callers() -> None:
    """Three tool calls pending together on one session are each answered with their own result,
    even when the responses arrive in the opposite order to the requests.

    SDK-defined contract: pins that the client request machinery keeps several calls in flight
    at once and correlates out-of-order responses to the right caller. Every handler signals
    that it has started and then blocks on a private release event, so a session that
    serialized requests would never reach the later handlers and the test would hit the
    timeout instead of passing.
    """
    tags = ["a", "b", "c"]
    handler_started = {tag: anyio.Event() for tag in tags}
    handler_release = {tag: anyio.Event() for tag in tags}
    call_finished = {tag: anyio.Event() for tag in tags}
    finish_order: list[str] = []
    results_by_tag: dict[str, CallToolResult] = {}

    server = MCPServer("parking")

    @server.tool()
    async def park(tag: str) -> str:
        # Announce that this handler is running, then block until the test releases this tag.
        handler_started[tag].set()
        await handler_release[tag].wait()
        return f"result:{tag}"

    async with Client(server) as client:

        async def run_call(tag: str) -> None:
            # Store the result under the caller's own tag so any cross-wiring shows up below.
            outcome = await client.call_tool("park", {"tag": tag})
            results_by_tag[tag] = outcome
            finish_order.append(tag)
            call_finished[tag].set()

        with anyio.fail_after(5):
            async with anyio.create_task_group() as calls:  # pragma: no branch
                # Issue the calls one at a time, moving on only once the previous handler is
                # running: this pins the send order and ends with all three parked together.
                for tag in tags:
                    calls.start_soon(run_call, tag)
                    await handler_started[tag].wait()

                # All three are in flight and none has returned yet.
                assert finish_order == []

                # Unblock last-to-first, waiting for each call to finish before releasing the
                # next, so the completion order is deterministic.
                for tag in tags[::-1]:
                    handler_release[tag].set()
                    await call_finished[tag].wait()

    assert finish_order == ["c", "b", "a"]
    assert results_by_tag == snapshot(
        {
            "c": CallToolResult(content=[TextContent(text="result:c")], structured_content={"result": "result:c"}),
            "b": CallToolResult(content=[TextContent(text="result:b")], structured_content={"result": "result:b"}),
            "a": CallToolResult(content=[TextContent(text="result:a")], structured_content={"result": "result:a"}),
        }
    )
| 76 | + |
| 77 | + |
async def test_overlapping_sampling_requests_are_serviced_concurrently_by_the_client() -> None:
    """A server tool that issues two sampling requests at the same time receives both echoes:
    the client services overlapping inbound `create_message` requests concurrently rather than
    one after another in its receive loop.

    Regression pin for https://github.com/modelcontextprotocol/python-sdk/issues/2489 -- v1's
    `BaseSession` awaited each inbound request handler inline, so a second sampling callback
    could not begin until the first had returned. Here both callbacks must have started before
    either one is allowed to answer.
    """
    callback_entered = {tag: anyio.Event() for tag in ("x", "y")}
    allow_replies = anyio.Event()
    tool_results: list[CallToolResult] = []

    server = MCPServer("fan_out_server")

    @server.tool()
    async def fan_out(ctx: Context) -> str:
        replies: dict[str, str] = {}

        async def request_sample(tag: str) -> None:
            response = await ctx.session.create_message(
                messages=[SamplingMessage(role="user", content=TextContent(text=tag))],
                max_tokens=10,
            )
            assert isinstance(response.content, TextContent)
            replies[tag] = response.content.text

        # Both requests are started inside one task group; leaving it waits for both replies.
        async with anyio.create_task_group() as samplers:
            for tag in ("x", "y"):
                samplers.start_soon(request_sample, tag)
        return f"{replies['x']} {replies['y']}"

    async def sampling_callback(
        context: ClientRequestContext, params: CreateMessageRequestParams
    ) -> CreateMessageResult:
        prompt = params.messages[0].content
        assert isinstance(prompt, TextContent)
        # Record which request arrived, then hold the reply until the test opens the gate.
        callback_entered[prompt.text].set()
        await allow_replies.wait()
        return CreateMessageResult(
            role="assistant",
            content=TextContent(text=f"echo:{prompt.text}"),
            model="test-model",
            stop_reason="endTurn",
        )

    async with Client(server, sampling_callback=sampling_callback) as client:
        with anyio.fail_after(5):
            async with anyio.create_task_group() as runner:  # pragma: no branch

                async def call_fan_out() -> None:
                    outcome = await client.call_tool("fan_out", {})
                    tool_results.append(outcome)

                runner.start_soon(call_fan_out)

                # Wait until both callbacks are running before letting either reply -- a client
                # that handled inbound requests serially would never enter the second one.
                for tag in ("x", "y"):
                    await callback_entered[tag].wait()
                allow_replies.set()

    assert tool_results == snapshot(
        [CallToolResult(content=[TextContent(text="echo:x echo:y")], structured_content={"result": "echo:x echo:y"})]
    )
0 commit comments