1+ """MCP 管理客户端。
2+
3+ 提供本地 MCP 服务、全局 MCP 服务和临时 MCP session 的 SDK 封装。
4+ """
5+
16from __future__ import annotations
27
38from contextlib import AbstractAsyncContextManager
49from dataclasses import dataclass , field
510from enum import Enum
11+ from types import TracebackType
612from typing import Any
713
814from ..errors import AstrBotError
@@ -16,6 +22,8 @@ class MCPServerScope(str, Enum):
1622
1723@dataclass (slots = True )
1824class MCPServerRecord :
25+ """MCP 服务快照。"""
26+
1927 name : str
2028 scope : MCPServerScope
2129 active : bool
@@ -69,7 +77,33 @@ def from_payload(
6977 )
7078
7179
80+ def _server_records_from_payload (items : Any ) -> list [MCPServerRecord ]:
81+ if not isinstance (items , list ):
82+ return []
83+ return [
84+ record
85+ for record in (
86+ MCPServerRecord .from_payload (item ) if isinstance (item , dict ) else None
87+ for item in items
88+ )
89+ if record is not None
90+ ]
91+
92+
93+ def _require_server_record (
94+ payload : dict [str , Any ],
95+ * ,
96+ action : str ,
97+ ) -> MCPServerRecord :
98+ record = MCPServerRecord .from_payload (payload .get ("server" ))
99+ if record is None :
100+ raise ValueError (f"{ action } returned no server" )
101+ return record
102+
103+
72104class MCPSession (AbstractAsyncContextManager ["MCPSession" ]):
105+ """临时 MCP session 的异步上下文封装。"""
106+
73107 def __init__ (
74108 self ,
75109 proxy : CapabilityProxy ,
@@ -106,7 +140,12 @@ async def __aenter__(self) -> MCPSession:
106140 )
107141 return self
108142
109- async def __aexit__ (self , exc_type , exc , tb ) -> None :
143+ async def __aexit__ (
144+ self ,
145+ exc_type : type [BaseException ] | None ,
146+ exc : BaseException | None ,
147+ tb : TracebackType | None ,
148+ ) -> None :
110149 session_id = self ._session_id
111150 self ._session_id = None
112151 self ._tools = []
@@ -162,6 +201,8 @@ def _require_session_id(self) -> str:
162201
163202
164203class MCPManagerClient :
204+ """MCP 服务管理客户端。"""
205+
165206 def __init__ (self , proxy : CapabilityProxy ) -> None :
166207 self ._proxy = proxy
167208
@@ -171,31 +212,15 @@ async def get_server(self, name: str) -> MCPServerRecord | None:
171212
172213 async def list_servers (self ) -> list [MCPServerRecord ]:
173214 output = await self ._proxy .call ("mcp.local.list" , {})
174- items = output .get ("servers" )
175- if not isinstance (items , list ):
176- return []
177- return [
178- record
179- for record in (
180- MCPServerRecord .from_payload (item ) if isinstance (item , dict ) else None
181- for item in items
182- )
183- if record is not None
184- ]
215+ return _server_records_from_payload (output .get ("servers" ))
185216
186217 async def enable_server (self , name : str ) -> MCPServerRecord :
187218 output = await self ._proxy .call ("mcp.local.enable" , {"name" : str (name )})
188- record = MCPServerRecord .from_payload (output .get ("server" ))
189- if record is None :
190- raise ValueError ("mcp.local.enable returned no server" )
191- return record
219+ return _require_server_record (output , action = "mcp.local.enable" )
192220
193221 async def disable_server (self , name : str ) -> MCPServerRecord :
194222 output = await self ._proxy .call ("mcp.local.disable" , {"name" : str (name )})
195- record = MCPServerRecord .from_payload (output .get ("server" ))
196- if record is None :
197- raise ValueError ("mcp.local.disable returned no server" )
198- return record
223+ return _require_server_record (output , action = "mcp.local.disable" )
199224
200225 async def wait_until_ready (
201226 self ,
@@ -207,10 +232,7 @@ async def wait_until_ready(
207232 "mcp.local.wait_until_ready" ,
208233 {"name" : str (name ), "timeout" : float (timeout )},
209234 )
210- record = MCPServerRecord .from_payload (output .get ("server" ))
211- if record is None :
212- raise ValueError ("mcp.local.wait_until_ready returned no server" )
213- return record
235+ return _require_server_record (output , action = "mcp.local.wait_until_ready" )
214236
215237 def session (
216238 self ,
@@ -241,28 +263,15 @@ async def register_global_server(
241263 "timeout" : float (timeout ),
242264 },
243265 )
244- record = MCPServerRecord .from_payload (output .get ("server" ))
245- if record is None :
246- raise ValueError ("mcp.global.register returned no server" )
247- return record
266+ return _require_server_record (output , action = "mcp.global.register" )
248267
249268 async def get_global_server (self , name : str ) -> MCPServerRecord | None :
250269 output = await self ._proxy .call ("mcp.global.get" , {"name" : str (name )})
251270 return MCPServerRecord .from_payload (output .get ("server" ))
252271
253272 async def list_global_servers (self ) -> list [MCPServerRecord ]:
254273 output = await self ._proxy .call ("mcp.global.list" , {})
255- items = output .get ("servers" )
256- if not isinstance (items , list ):
257- return []
258- return [
259- record
260- for record in (
261- MCPServerRecord .from_payload (item ) if isinstance (item , dict ) else None
262- for item in items
263- )
264- if record is not None
265- ]
274+ return _server_records_from_payload (output .get ("servers" ))
266275
267276 async def enable_global_server (
268277 self ,
@@ -274,24 +283,15 @@ async def enable_global_server(
274283 "mcp.global.enable" ,
275284 {"name" : str (name ), "timeout" : float (timeout )},
276285 )
277- record = MCPServerRecord .from_payload (output .get ("server" ))
278- if record is None :
279- raise ValueError ("mcp.global.enable returned no server" )
280- return record
286+ return _require_server_record (output , action = "mcp.global.enable" )
281287
282288 async def disable_global_server (self , name : str ) -> MCPServerRecord :
283289 output = await self ._proxy .call ("mcp.global.disable" , {"name" : str (name )})
284- record = MCPServerRecord .from_payload (output .get ("server" ))
285- if record is None :
286- raise ValueError ("mcp.global.disable returned no server" )
287- return record
290+ return _require_server_record (output , action = "mcp.global.disable" )
288291
289292 async def unregister_global_server (self , name : str ) -> MCPServerRecord :
290293 output = await self ._proxy .call ("mcp.global.unregister" , {"name" : str (name )})
291- record = MCPServerRecord .from_payload (output .get ("server" ))
292- if record is None :
293- raise ValueError ("mcp.global.unregister returned no server" )
294- return record
294+ return _require_server_record (output , action = "mcp.global.unregister" )
295295
296296
297297__all__ = [
0 commit comments