Skip to content

Commit e2ec224

Browse files
committed
add pluggable transport examples
1 parent 8914490 commit e2ec224

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Example demonstrating how the high-level `Client` class interacts with
2+
the `BaseClientSession` abstraction for callbacks.
3+
"""
4+
5+
import asyncio
6+
7+
from mcp import Client
8+
from mcp.client.base_client_session import BaseClientSession
9+
from mcp.server.mcpserver import MCPServer
10+
from mcp.shared._context import RequestContext
11+
from mcp.types import (
12+
CreateMessageRequestParams,
13+
CreateMessageResult,
14+
TextContent,
15+
)
16+
17+
18+
async def main():
19+
# 1. Create a simple server with a tool that requires sampling
20+
server = MCPServer("ExampleServer")
21+
22+
@server.tool("ask_assistant")
23+
async def ask_assistant(message: str) -> str:
24+
# The tool asks the client to sample a message (requires the sampling callback)
25+
print(f"[Server] Received request: {message}")
26+
result = await server.get_context().session.create_message(
27+
messages=[{"role": "user", "content": {"type": "text", "text": message}}],
28+
max_tokens=100,
29+
)
30+
return f"Assistant replied: {result.content.text}"
31+
32+
# 2. Define a callback typed against the abstract `BaseClientSession`.
33+
# Notice that we are NOT tied to `ClientSession` streams here!
34+
# Because of the contravariance assigned to `ClientSessionT_contra` in the
35+
# Protocol, this callback is a completely valid mathematical subtype of the
36+
# `SamplingFnT[ClientSession]` expected by `Client` during instantiation.
37+
async def abstract_sampling_callback(
38+
context: RequestContext[BaseClientSession], params: CreateMessageRequestParams
39+
) -> CreateMessageResult:
40+
print("[Client Callback] Server requested sampling via abstract callback!")
41+
42+
# We can safely use `BaseClientSession` abstract methods on `context.session`.
43+
return CreateMessageResult(
44+
role="assistant",
45+
content=TextContent(type="text", text="Hello from the abstract callback!"),
46+
model="gpt-test",
47+
stop_reason="endTurn",
48+
)
49+
50+
# 3. Instantiate the Client, injecting our abstract callback.
51+
# The SDK automatically handles the underlying streams and creates the concrete
52+
# `ClientSession`, which safely fulfills the `BaseClientSession` contract our
53+
# callback expects.
54+
async with Client(server, sampling_callback=abstract_sampling_callback) as client:
55+
print("Executing tool 'ask_assistant' from the Client...")
56+
result = await client.call_tool("ask_assistant", {"message": "Please say hello"})
57+
58+
if not result.is_error:
59+
for content in result.content:
60+
if isinstance(content, TextContent):
61+
print(f"Server Tool Output: {content.text}")
62+
63+
64+
if __name__ == "__main__":
65+
asyncio.run(main())
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Example demonstrating how to implement a custom transport
2+
that complies with `BaseClientSession` without using read/write streams or JSON-RPC.
3+
"""
4+
5+
import asyncio
6+
from typing import Any
7+
8+
from mcp import types
9+
from mcp.client.base_client_session import BaseClientSession
10+
from mcp.shared.session import ProgressFnT
11+
12+
13+
class CustomDirectSession:
14+
"""A custom MCP session that communicates with a hypothetical internal API
15+
rather than using streaming JSON-RPC.
16+
17+
It satisfies the `BaseClientSession` protocol simply by implementing the required
18+
methods – no inheritance from `BaseSession` or stream initialization required!
19+
"""
20+
21+
async def initialize(self) -> types.InitializeResult:
22+
print("[CustomSession] Initializing custom transport...")
23+
return types.InitializeResult(
24+
protocolVersion="2024-11-05",
25+
capabilities=types.ServerCapabilities(),
26+
serverInfo=types.Implementation(name="CustomDirectServer", version="1.0.0"),
27+
)
28+
29+
async def list_tools(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListToolsResult:
30+
print("[CustomSession] Fetching tools...")
31+
return types.ListToolsResult(
32+
tools=[
33+
types.Tool(
34+
name="direct_tool",
35+
description="A tool executed via direct internal Python call",
36+
inputSchema={"type": "object", "properties": {}},
37+
)
38+
]
39+
)
40+
41+
async def call_tool(
42+
self,
43+
name: str,
44+
arguments: dict[str, Any] | None = None,
45+
read_timeout_seconds: float | None = None,
46+
progress_callback: ProgressFnT | None = None,
47+
*,
48+
meta: types.RequestParamsMeta | None = None,
49+
) -> types.CallToolResult:
50+
print(f"[CustomSession] Executing tool '{name}'...")
51+
return types.CallToolResult(
52+
content=[
53+
types.TextContent(
54+
type="text", text=f"Hello from the custom transport! Tool '{name}' executed successfully."
55+
)
56+
]
57+
)
58+
59+
# Note: To fully satisfy the structural protocol of BaseClientSession for static
60+
# type checking (mypy/pyright), all protocol methods must be defined.
61+
# Here we stub the remaining methods for brevity.
62+
async def send_ping(self, *, meta: types.RequestParamsMeta | None = None) -> types.EmptyResult:
63+
return types.EmptyResult()
64+
65+
async def send_request(self, *args: Any, **kwargs: Any) -> Any:
66+
raise NotImplementedError()
67+
68+
async def send_notification(self, *args: Any, **kwargs: Any) -> None:
69+
raise NotImplementedError()
70+
71+
async def send_progress_notification(self, *args: Any, **kwargs: Any) -> None:
72+
raise NotImplementedError()
73+
74+
async def list_resources(self, *args: Any, **kwargs: Any) -> Any:
75+
raise NotImplementedError()
76+
77+
async def list_resource_templates(self, *args: Any, **kwargs: Any) -> Any:
78+
raise NotImplementedError()
79+
80+
async def read_resource(self, *args: Any, **kwargs: Any) -> Any:
81+
raise NotImplementedError()
82+
83+
async def subscribe_resource(self, *args: Any, **kwargs: Any) -> Any:
84+
raise NotImplementedError()
85+
86+
async def unsubscribe_resource(self, *args: Any, **kwargs: Any) -> Any:
87+
raise NotImplementedError()
88+
89+
async def list_prompts(self, *args: Any, **kwargs: Any) -> Any:
90+
raise NotImplementedError()
91+
92+
async def get_prompt(self, *args: Any, **kwargs: Any) -> Any:
93+
raise NotImplementedError()
94+
95+
async def complete(self, *args: Any, **kwargs: Any) -> Any:
96+
raise NotImplementedError()
97+
98+
async def set_logging_level(self, *args: Any, **kwargs: Any) -> Any:
99+
raise NotImplementedError()
100+
101+
async def send_roots_list_changed(self, *args: Any, **kwargs: Any) -> None:
102+
raise NotImplementedError()
103+
104+
105+
# ---------------------------------------------------------------------------
106+
# Using the session with code strictly typed against BaseClientSession
107+
# ---------------------------------------------------------------------------
108+
109+
async def interact_with_mcp(session: BaseClientSession) -> None:
110+
"""This function doesn't know or care if the session is communicating
111+
via stdio streams, SSE, or a custom internal API!
112+
It only depends on the abstract `BaseClientSession` methods.
113+
"""
114+
115+
# 1. Initialize
116+
init_result = await session.initialize()
117+
print(f"Connected to: {init_result.serverInfo.name}@{init_result.serverInfo.version}")
118+
119+
# 2. List Tools
120+
tools_result = await session.list_tools()
121+
for tool in tools_result.tools:
122+
print(f"Found tool: {tool.name} - {tool.description}")
123+
124+
# 3. Call Tool
125+
if tools_result.tools:
126+
call_result = await session.call_tool(tools_result.tools[0].name, arguments={})
127+
for content in call_result.content:
128+
if isinstance(content, types.TextContent):
129+
print(f"Tool Output: {content.text}")
130+
131+
132+
async def main():
133+
# Instantiate our custom non-streaming transport session
134+
custom_session = CustomDirectSession()
135+
136+
# Pass it to the generic runner!
137+
await interact_with_mcp(custom_session)
138+
139+
140+
if __name__ == "__main__":
141+
asyncio.run(main())

0 commit comments

Comments
 (0)