88import pytest
99
1010from mcp .client .session import ClientSession
11- from mcp .server .mcpserver import MCPServer
11+ from mcp .server .mcpserver import Context , MCPServer
12+ from mcp .shared ._context import RequestContext
1213from mcp .shared .dispatcher import (
1314 JSONRPCDispatcher ,
1415 OnErrorFn ,
1718)
1819from mcp .shared .memory import create_client_server_memory_streams
1920from mcp .shared .message import MessageMetadata
20- from mcp .types import ErrorData , RequestId
21+ from mcp .types import (
22+ CreateMessageRequestParams ,
23+ CreateMessageResult ,
24+ ErrorData ,
25+ RequestId ,
26+ SamplingMessage ,
27+ TextContent ,
28+ )
2129
2230pytestmark = pytest .mark .anyio
2331
@@ -35,6 +43,7 @@ def __init__(self, inner: JSONRPCDispatcher) -> None:
3543 self ._inner = inner
3644 self .sent_requests : list [dict [str , Any ]] = []
3745 self .sent_notifications : list [dict [str , Any ]] = []
46+ self .sent_responses : list [dict [str , Any ] | ErrorData ] = []
3847
3948 def set_handlers (self , on_request : OnRequestFn , on_notification : OnNotificationFn , on_error : OnErrorFn ) -> None :
4049 self ._inner .set_handlers (on_request , on_notification , on_error )
@@ -59,16 +68,33 @@ async def send_notification(
5968 await self ._inner .send_notification (notification , related_request_id )
6069
6170 async def send_response (self , request_id : RequestId , response : dict [str , Any ] | ErrorData ) -> None :
62- await self ._inner .send_response (request_id , response ) # pragma: no cover
71+ self .sent_responses .append (response )
72+ await self ._inner .send_response (request_id , response )
6373
6474
6575async def test_client_session_accepts_custom_dispatcher ():
66- """ClientSession round-trips through a custom dispatcher end-to-end."""
76+ """ClientSession round-trips through a custom dispatcher end-to-end, including
77+ a server-initiated request (sampling) so all five dispatcher methods fire."""
6778 app = MCPServer ("test" )
6879
6980 @app .tool ()
70- def greet (name : str ) -> str :
71- return f"Hello, { name } !"
81+ async def ask (question : str , ctx : Context ) -> str :
82+ answer = await ctx .session .create_message (
83+ messages = [SamplingMessage (role = "user" , content = TextContent (type = "text" , text = question ))],
84+ max_tokens = 10 ,
85+ )
86+ assert isinstance (answer .content , TextContent )
87+ return answer .content .text
88+
89+ async def sampling_callback (
90+ context : RequestContext [ClientSession ], params : CreateMessageRequestParams
91+ ) -> CreateMessageResult :
92+ return CreateMessageResult (
93+ role = "assistant" ,
94+ content = TextContent (type = "text" , text = "42" ),
95+ model = "test" ,
96+ stop_reason = "endTurn" ,
97+ )
7298
7399 async with create_client_server_memory_streams () as (client_streams , server_streams ):
74100 client_read , client_write = client_streams
@@ -83,17 +109,20 @@ def greet(name: str) -> str:
83109 server = app ._lowlevel_server # type: ignore[reportPrivateUsage]
84110 tg .start_soon (lambda : server .run (server_read , server_write , server .create_initialization_options ()))
85111
86- async with ClientSession (dispatcher = spy ) as session :
112+ async with ClientSession (dispatcher = spy , sampling_callback = sampling_callback ) as session :
87113 await session .initialize ()
88- result = await session .call_tool ("greet " , {"name " : "world " })
89- assert result .content [0 ].text == "Hello, world! " # type: ignore[union-attr]
114+ result = await session .call_tool ("ask " , {"question " : "meaning of life? " })
115+ assert result .content [0 ].text == "42 " # type: ignore[union-attr]
90116
91117 tg .cancel_scope .cancel ()
92118
93- # Initialize + call_tool + list_tools (output-schema refresh after the call).
119+ # initialize, tools/call (triggers sampling on the server), tools/list (schema refresh)
94120 assert [r ["method" ] for r in spy .sent_requests ] == ["initialize" , "tools/call" , "tools/list" ]
95- # InitializedNotification.
96121 assert [n ["method" ] for n in spy .sent_notifications ] == ["notifications/initialized" ]
122+ # The server's sampling/createMessage request hit us; our response went back through the spy.
123+ assert len (spy .sent_responses ) == 1
124+ response = spy .sent_responses [0 ]
125+ assert isinstance (response , dict ) and response ["model" ] == "test"
97126
98127
99128async def test_base_session_requires_streams_or_dispatcher ():
0 commit comments