forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsession_group.py
More file actions
271 lines (216 loc) · 9.6 KB
/
session_group.py
File metadata and controls
271 lines (216 loc) · 9.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
SessionGroup concurrently manages multiple MCP session connections.
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.
This abstraction can handle naming collisions using a custom, user-provided
hook.
"""
import contextlib
from collections.abc import Callable
from datetime import timedelta
from typing import Any, TypeAlias
from pydantic import BaseModel
import mcp
from mcp import types
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.exceptions import McpError
class SseServerParameters(BaseModel):
    """Parameters for initializing an sse_client.

    Every field is forwarded to ``sse_client`` under the same keyword name.
    """

    # Endpoint URL of the SSE server.
    url: str
    # Extra headers to attach to requests; None sends no extra headers.
    headers: dict[str, Any] | None = None
    # HTTP timeout applied to regular (non-streaming) operations.
    timeout: float = 5
    # Timeout applied while reading from the SSE stream.
    sse_read_timeout: float = 5 * 60
class StreamableHttpParameters(BaseModel):
    """Parameters for initializing a streamablehttp_client.

    Every field is forwarded to ``streamablehttp_client`` under the same
    keyword name.
    """

    # Endpoint URL of the streamable HTTP server.
    url: str
    # Extra headers to attach to requests; None sends no extra headers.
    headers: dict[str, Any] | None = None
    # HTTP timeout applied to regular (non-streaming) operations.
    timeout: timedelta = timedelta(seconds=30)
    # Timeout applied while reading SSE events from the server.
    sse_read_timeout: timedelta = timedelta(minutes=5)
    # Whether the client session is terminated when the transport closes.
    terminate_on_close: bool = True
# Any one of the transport configurations accepted by
# ClientSessionGroup.connect_to_server: stdio, SSE, or streamable HTTP.
ServerParameters: TypeAlias = (
    StdioServerParameters
    | SseServerParameters
    | StreamableHttpParameters
)
class ClientSessionGroup:
    """Client for managing connections to multiple MCP servers.

    This class encapsulates the management of server connections and
    aggregates the tools, resources, and prompts of every connected server
    into group-level dictionaries.

    Functionality that is not aggregated here (for example, resource
    subscription) is delegated to the underlying ``mcp.ClientSession``.
    Call the corresponding method directly on the session returned by
    ``connect_to_server``.

    The group can be used as an async context manager. If the group created
    its own exit stack (no ``exit_stack`` argument was given), leaving the
    ``async with`` block closes every transport and session the group opened.
    """

    class _ComponentNames(BaseModel):
        """Reverse index of the group-level names one session contributed."""

        # NOTE: relies on pydantic copying mutable defaults per instance, so
        # these sets are not shared between sessions.
        prompts: set[str] = set()
        resources: set[str] = set()
        tools: set[str] = set()

    # Standard MCP components, keyed by group-level (possibly hooked) name.
    _prompts: dict[str, types.Prompt]
    _resources: dict[str, types.Resource]
    _tools: dict[str, types.Tool]

    # Client-server connection management.
    _sessions: dict[mcp.ClientSession, _ComponentNames]
    _tool_to_session: dict[str, mcp.ClientSession]
    _exit_stack: contextlib.AsyncExitStack
    # True when the group created _exit_stack itself and must close it.
    _owns_exit_stack: bool

    # Optional fn consuming (component_name, serverInfo) for custom names.
    # This provides a means to mitigate naming conflicts across servers.
    # Example: (tool_name, serverInfo) => f"{serverInfo.name}.{tool_name}"
    _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
    _component_name_hook: _ComponentNameHook | None

    def __init__(
        self,
        exit_stack: contextlib.AsyncExitStack | None = None,
        component_name_hook: _ComponentNameHook | None = None,
    ) -> None:
        """Initializes the session group.

        Args:
            exit_stack: Stack on which transports and sessions opened by this
                group are registered. When omitted, a fresh stack is created
                for this instance (so groups never share one) and is closed
                by ``__aexit__``. A caller-supplied stack is never closed by
                the group.
            component_name_hook: Optional callable mapping
                ``(component_name, server_info)`` to the name used as the key
                in the group's aggregated dictionaries.
        """
        self._tools = {}
        self._resources = {}
        self._prompts = {}
        self._sessions = {}
        self._tool_to_session = {}
        # A default of AsyncExitStack() in the signature would be evaluated
        # once and shared by every group; create one per instance instead.
        self._owns_exit_stack = exit_stack is None
        self._exit_stack = (
            contextlib.AsyncExitStack() if exit_stack is None else exit_stack
        )
        self._component_name_hook = component_name_hook

    async def __aenter__(self) -> "ClientSessionGroup":
        """Returns the group itself for use in ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Closes the group's own exit stack, if it owns one.

        When the stack was supplied by the caller, it is left open and the
        group's state is left untouched; the caller remains responsible for
        closing it. Exceptions are never suppressed.
        """
        if self._owns_exit_stack:
            await self._exit_stack.aclose()
            # Every session is now closed, so drop the stale references.
            self._tools.clear()
            self._resources.clear()
            self._prompts.clear()
            self._sessions.clear()
            self._tool_to_session.clear()

    @property
    def prompts(self) -> dict[str, types.Prompt]:
        """Returns the prompts as a dictionary of names to prompts."""
        return self._prompts

    @property
    def resources(self) -> dict[str, types.Resource]:
        """Returns the resources as a dictionary of names to resources."""
        return self._resources

    @property
    def tools(self) -> dict[str, types.Tool]:
        """Returns the tools as a dictionary of names to tools."""
        return self._tools

    async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
        """Executes a tool given its group-level name and arguments.

        Args:
            name: Key of the tool in ``self.tools`` (after any name hook).
            args: Arguments forwarded to the tool.

        Raises:
            KeyError: If no tool named ``name`` is registered in the group.
        """
        session = self._tool_to_session[name]
        # The stored Tool keeps the server's own name; the group key may have
        # been rewritten by the component name hook, and the server only
        # recognises its own name.
        server_tool_name = self._tools[name].name
        return await session.call_tool(server_tool_name, args)

    def disconnect_from_server(self, session: mcp.ClientSession) -> None:
        """Removes a single MCP server's session and components from the group.

        This only drops the group's references. The session and its
        transport remain registered on the exit stack and are closed when
        that stack is closed.

        Raises:
            McpError: If ``session`` is not managed by this group.
        """
        if session not in self._sessions:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message="Provided session is not being managed.",
                )
            )
        component_names = self._sessions[session]

        # Remove prompts associated with the session.
        for name in component_names.prompts:
            del self._prompts[name]
        # Remove resources associated with the session.
        for name in component_names.resources:
            del self._resources[name]
        # Remove tools and their routing entries so call_tool can no longer
        # reach the disconnected session.
        for name in component_names.tools:
            del self._tools[name]
            del self._tool_to_session[name]

        del self._sessions[session]

    async def connect_to_server(
        self,
        server_params: ServerParameters,
    ) -> mcp.ClientSession:
        """Connects to a single MCP server and aggregates its components.

        Args:
            server_params: Transport configuration for the server.

        Returns:
            The initialized session for the server.

        Raises:
            McpError: If a (possibly hooked) prompt, resource, or tool name
                collides with one already in the group, or with another
                component of the same kind from this server. Nothing is added
                to the group in that case, but the session opened for the
                server stays registered on the exit stack until that stack is
                closed.
        """
        # Establish server connection and create session.
        server_info, session = await self._establish_session(server_params)

        # Reverse index so disconnect_from_server can remove exactly the
        # components this session contributed.
        component_names = self._ComponentNames()

        # Stage components in temporary dicts so the aggregate dicts are left
        # untouched if any step below fails.
        prompts_temp: dict[str, types.Prompt] = {}
        resources_temp: dict[str, types.Resource] = {}
        tools_temp: dict[str, types.Tool] = {}
        tool_to_session_temp: dict[str, mcp.ClientSession] = {}

        # Query the server for its prompts.
        for prompt in (await session.list_prompts()).prompts:
            name = self._component_name(prompt.name, server_info)
            self._ensure_name_available(
                name, "prompts", self._prompts, prompts_temp
            )
            prompts_temp[name] = prompt
            component_names.prompts.add(name)

        # Query the server for its resources.
        for resource in (await session.list_resources()).resources:
            name = self._component_name(resource.name, server_info)
            self._ensure_name_available(
                name, "resources", self._resources, resources_temp
            )
            resources_temp[name] = resource
            component_names.resources.add(name)

        # Query the server for its tools.
        for tool in (await session.list_tools()).tools:
            name = self._component_name(tool.name, server_info)
            self._ensure_name_available(name, "tools", self._tools, tools_temp)
            tools_temp[name] = tool
            tool_to_session_temp[name] = session
            component_names.tools.add(name)

        # All checks passed: commit to the aggregate dicts.
        self._sessions[session] = component_names
        self._prompts.update(prompts_temp)
        self._resources.update(resources_temp)
        self._tools.update(tools_temp)
        self._tool_to_session.update(tool_to_session_temp)
        return session

    @staticmethod
    def _ensure_name_available(
        name: str, kind: str, *registries: dict[str, Any]
    ) -> None:
        """Raises McpError if ``name`` is already a key in any of ``registries``.

        Args:
            name: Group-level component name to check.
            kind: Plural component kind used in the error message
                ("prompts", "resources" or "tools").
            registries: Dictionaries whose keys are already-claimed names.
        """
        if any(name in registry for registry in registries):
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{name} already exists in group {kind}.",
                )
            )

    async def _establish_session(
        self, server_params: ServerParameters
    ) -> tuple[types.Implementation, mcp.ClientSession]:
        """Opens a transport and an initialized ClientSession for one server.

        Both the transport context and the session are entered on the group's
        exit stack, so they are closed when that stack is closed.

        Returns:
            The server's ``Implementation`` info reported during
            initialization, and the session.
        """
        # Create read and write streams that facilitate io with the server.
        if isinstance(server_params, StdioServerParameters):
            client = mcp.stdio_client(server_params)
            read, write = await self._exit_stack.enter_async_context(client)
        elif isinstance(server_params, SseServerParameters):
            client = sse_client(
                url=server_params.url,
                headers=server_params.headers,
                timeout=server_params.timeout,
                sse_read_timeout=server_params.sse_read_timeout,
            )
            read, write = await self._exit_stack.enter_async_context(client)
        else:
            client = streamablehttp_client(
                url=server_params.url,
                headers=server_params.headers,
                timeout=server_params.timeout,
                sse_read_timeout=server_params.sse_read_timeout,
                terminate_on_close=server_params.terminate_on_close,
            )
            # The streamable HTTP client yields a third value that is unused here.
            read, write, _ = await self._exit_stack.enter_async_context(client)

        session = await self._exit_stack.enter_async_context(
            mcp.ClientSession(read, write)
        )
        result = await session.initialize()
        return result.serverInfo, session

    def _component_name(self, name: str, server_info: types.Implementation) -> str:
        """Returns the group-level name for a component.

        Applies the user-provided component name hook when one is set;
        otherwise the server's own name is used unchanged.
        """
        if self._component_name_hook is not None:
            return self._component_name_hook(name, server_info)
        return name