Skip to content

Commit 34ad073

Browse files
committed
Squashed 'astrbot-sdk/' changes from 56943300b..0a9c86345
0a9c86345 chore: refresh vendor snapshot [skip ci] b5d9b934b Merge pull request #105 from united-pooh:dev c07f04e63 feat: 更新多个客户端和模块,增强类型注解和文档说明 REVERT: 56943300b chore: refresh vendor snapshot [skip ci] git-subtree-dir: astrbot-sdk git-subtree-split: 0a9c86345ea2192154580d0ef054b72b99892b9e
1 parent 4e8009d commit 34ad073

15 files changed

Lines changed: 145 additions & 100 deletions

File tree

src/astrbot_sdk/clients/_proxy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
self._caller_plugin_id = caller_plugin_id
8181
self._request_scope_id = request_scope_id
8282

83-
def _get_descriptor(self, name: str):
83+
def _get_descriptor(self, name: str) -> _CapabilityDescriptorLike | None:
8484
"""获取能力描述符。
8585
8686
Args:

src/astrbot_sdk/clients/files.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""文件服务客户端。
2+
3+
提供文件令牌注册和令牌反查能力,封装 `system.file.*` capabilities。
4+
"""
5+
16
from __future__ import annotations
27

38
from dataclasses import dataclass
@@ -8,6 +13,8 @@
813

914
@dataclass(slots=True)
1015
class FileRegistration:
16+
"""文件注册结果。"""
17+
1118
token: str
1219
url: str
1320

@@ -20,32 +27,44 @@ def from_payload(cls, payload: dict[str, Any]) -> FileRegistration:
2027

2128

2229
class FileServiceClient:
30+
"""文件服务能力客户端。"""
31+
2332
def __init__(self, proxy: CapabilityProxy) -> None:
2433
self._proxy = proxy
2534

26-
async def register_file(
35+
async def _register(
2736
self,
2837
path: str,
29-
timeout: float | None = None,
30-
) -> str:
38+
*,
39+
timeout: float | None,
40+
) -> FileRegistration:
3141
output = await self._proxy.call(
3242
"system.file.register",
3343
{"path": str(path), "timeout": timeout},
3444
)
35-
return FileRegistration.from_payload(output).token
45+
return FileRegistration.from_payload(output)
46+
47+
async def register_file(
48+
self,
49+
path: str,
50+
timeout: float | None = None,
51+
) -> str:
52+
"""注册本地文件并返回文件令牌。"""
53+
54+
return (await self._register(path, timeout=timeout)).token
3655

3756
async def register_file_url(
3857
self,
3958
path: str,
4059
timeout: float | None = None,
4160
) -> str:
42-
output = await self._proxy.call(
43-
"system.file.register",
44-
{"path": str(path), "timeout": timeout},
45-
)
46-
return FileRegistration.from_payload(output).url
61+
"""注册本地文件并返回公开访问 URL。"""
62+
63+
return (await self._register(path, timeout=timeout)).url
4764

4865
async def handle_file(self, token: str) -> str:
66+
"""将文件令牌解析回本地文件路径。"""
67+
4968
output = await self._proxy.call(
5069
"system.file.handle",
5170
{"token": str(token)},

src/astrbot_sdk/clients/managers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ async def append(
649649
session: MessageSession,
650650
*,
651651
parts: list[BaseMessageComponent],
652-
sender: MessageHistorySender,
652+
sender: MessageHistorySender | dict[str, Any],
653653
metadata: dict[str, Any] | None = None,
654654
idempotency_key: str | None = None,
655655
) -> MessageHistoryRecord:

src/astrbot_sdk/clients/mcp.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
"""MCP 管理客户端。
2+
3+
提供本地 MCP 服务、全局 MCP 服务和临时 MCP session 的 SDK 封装。
4+
"""
5+
16
from __future__ import annotations
27

38
from contextlib import AbstractAsyncContextManager
49
from dataclasses import dataclass, field
510
from enum import Enum
11+
from types import TracebackType
612
from typing import Any
713

814
from ..errors import AstrBotError
@@ -16,6 +22,8 @@ class MCPServerScope(str, Enum):
1622

1723
@dataclass(slots=True)
1824
class MCPServerRecord:
25+
"""MCP 服务快照。"""
26+
1927
name: str
2028
scope: MCPServerScope
2129
active: bool
@@ -69,7 +77,33 @@ def from_payload(
6977
)
7078

7179

80+
def _server_records_from_payload(items: Any) -> list[MCPServerRecord]:
81+
if not isinstance(items, list):
82+
return []
83+
return [
84+
record
85+
for record in (
86+
MCPServerRecord.from_payload(item) if isinstance(item, dict) else None
87+
for item in items
88+
)
89+
if record is not None
90+
]
91+
92+
93+
def _require_server_record(
94+
payload: dict[str, Any],
95+
*,
96+
action: str,
97+
) -> MCPServerRecord:
98+
record = MCPServerRecord.from_payload(payload.get("server"))
99+
if record is None:
100+
raise ValueError(f"{action} returned no server")
101+
return record
102+
103+
72104
class MCPSession(AbstractAsyncContextManager["MCPSession"]):
105+
"""临时 MCP session 的异步上下文封装。"""
106+
73107
def __init__(
74108
self,
75109
proxy: CapabilityProxy,
@@ -106,7 +140,12 @@ async def __aenter__(self) -> MCPSession:
106140
)
107141
return self
108142

109-
async def __aexit__(self, exc_type, exc, tb) -> None:
143+
async def __aexit__(
144+
self,
145+
exc_type: type[BaseException] | None,
146+
exc: BaseException | None,
147+
tb: TracebackType | None,
148+
) -> None:
110149
session_id = self._session_id
111150
self._session_id = None
112151
self._tools = []
@@ -162,6 +201,8 @@ def _require_session_id(self) -> str:
162201

163202

164203
class MCPManagerClient:
204+
"""MCP 服务管理客户端。"""
205+
165206
def __init__(self, proxy: CapabilityProxy) -> None:
166207
self._proxy = proxy
167208

@@ -171,31 +212,15 @@ async def get_server(self, name: str) -> MCPServerRecord | None:
171212

172213
async def list_servers(self) -> list[MCPServerRecord]:
173214
output = await self._proxy.call("mcp.local.list", {})
174-
items = output.get("servers")
175-
if not isinstance(items, list):
176-
return []
177-
return [
178-
record
179-
for record in (
180-
MCPServerRecord.from_payload(item) if isinstance(item, dict) else None
181-
for item in items
182-
)
183-
if record is not None
184-
]
215+
return _server_records_from_payload(output.get("servers"))
185216

186217
async def enable_server(self, name: str) -> MCPServerRecord:
187218
output = await self._proxy.call("mcp.local.enable", {"name": str(name)})
188-
record = MCPServerRecord.from_payload(output.get("server"))
189-
if record is None:
190-
raise ValueError("mcp.local.enable returned no server")
191-
return record
219+
return _require_server_record(output, action="mcp.local.enable")
192220

193221
async def disable_server(self, name: str) -> MCPServerRecord:
194222
output = await self._proxy.call("mcp.local.disable", {"name": str(name)})
195-
record = MCPServerRecord.from_payload(output.get("server"))
196-
if record is None:
197-
raise ValueError("mcp.local.disable returned no server")
198-
return record
223+
return _require_server_record(output, action="mcp.local.disable")
199224

200225
async def wait_until_ready(
201226
self,
@@ -207,10 +232,7 @@ async def wait_until_ready(
207232
"mcp.local.wait_until_ready",
208233
{"name": str(name), "timeout": float(timeout)},
209234
)
210-
record = MCPServerRecord.from_payload(output.get("server"))
211-
if record is None:
212-
raise ValueError("mcp.local.wait_until_ready returned no server")
213-
return record
235+
return _require_server_record(output, action="mcp.local.wait_until_ready")
214236

215237
def session(
216238
self,
@@ -241,28 +263,15 @@ async def register_global_server(
241263
"timeout": float(timeout),
242264
},
243265
)
244-
record = MCPServerRecord.from_payload(output.get("server"))
245-
if record is None:
246-
raise ValueError("mcp.global.register returned no server")
247-
return record
266+
return _require_server_record(output, action="mcp.global.register")
248267

249268
async def get_global_server(self, name: str) -> MCPServerRecord | None:
250269
output = await self._proxy.call("mcp.global.get", {"name": str(name)})
251270
return MCPServerRecord.from_payload(output.get("server"))
252271

253272
async def list_global_servers(self) -> list[MCPServerRecord]:
254273
output = await self._proxy.call("mcp.global.list", {})
255-
items = output.get("servers")
256-
if not isinstance(items, list):
257-
return []
258-
return [
259-
record
260-
for record in (
261-
MCPServerRecord.from_payload(item) if isinstance(item, dict) else None
262-
for item in items
263-
)
264-
if record is not None
265-
]
274+
return _server_records_from_payload(output.get("servers"))
266275

267276
async def enable_global_server(
268277
self,
@@ -274,24 +283,15 @@ async def enable_global_server(
274283
"mcp.global.enable",
275284
{"name": str(name), "timeout": float(timeout)},
276285
)
277-
record = MCPServerRecord.from_payload(output.get("server"))
278-
if record is None:
279-
raise ValueError("mcp.global.enable returned no server")
280-
return record
286+
return _require_server_record(output, action="mcp.global.enable")
281287

282288
async def disable_global_server(self, name: str) -> MCPServerRecord:
283289
output = await self._proxy.call("mcp.global.disable", {"name": str(name)})
284-
record = MCPServerRecord.from_payload(output.get("server"))
285-
if record is None:
286-
raise ValueError("mcp.global.disable returned no server")
287-
return record
290+
return _require_server_record(output, action="mcp.global.disable")
288291

289292
async def unregister_global_server(self, name: str) -> MCPServerRecord:
290293
output = await self._proxy.call("mcp.global.unregister", {"name": str(name)})
291-
record = MCPServerRecord.from_payload(output.get("server"))
292-
if record is None:
293-
raise ValueError("mcp.global.unregister returned no server")
294-
return record
294+
return _require_server_record(output, action="mcp.global.unregister")
295295

296296

297297
__all__ = [

src/astrbot_sdk/clients/memory.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
self._namespace = join_memory_namespace(namespace)
5757

5858
def namespace(self, *parts: Any) -> MemoryClient:
59-
"""Create a derived client that operates inside a child namespace."""
59+
"""创建一个工作在子命名空间中的派生客户端。"""
6060

6161
return MemoryClient(
6262
self._proxy,
@@ -203,7 +203,7 @@ async def list_keys(
203203
*,
204204
namespace: str | None = None,
205205
) -> list[str]:
206-
"""List keys in the exact namespace using case-insensitive ordering."""
206+
"""列出指定精确命名空间下的全部键。"""
207207

208208
payload: dict[str, Any] = {
209209
"namespace": self._resolve_exact_namespace(namespace)
@@ -220,7 +220,7 @@ async def exists(
220220
*,
221221
namespace: str | None = None,
222222
) -> bool:
223-
"""Check whether a key exists in the exact namespace."""
223+
"""检查指定精确命名空间中是否存在某个键。"""
224224

225225
payload: dict[str, Any] = {"key": key}
226226
payload["namespace"] = self._resolve_exact_namespace(namespace)
@@ -251,7 +251,7 @@ async def clear_namespace(
251251
namespace: str | None = None,
252252
include_descendants: bool = False,
253253
) -> int:
254-
"""Delete memories in a namespace and optionally its descendants."""
254+
"""清空命名空间中的记忆项,可选递归清空子命名空间。"""
255255

256256
payload: dict[str, Any] = {
257257
"namespace": self._resolve_exact_namespace(namespace),
@@ -364,7 +364,7 @@ async def count(
364364
namespace: str | None = None,
365365
include_descendants: bool = False,
366366
) -> int:
367-
"""Count memories in a namespace and optionally its descendants."""
367+
"""统计命名空间中的记忆项数量,可选包含子命名空间。"""
368368

369369
payload: dict[str, Any] = {
370370
"namespace": self._resolve_exact_namespace(namespace),
@@ -409,24 +409,18 @@ async def stats(
409409
"total_items": output.get("total_items", 0),
410410
"total_bytes": output.get("total_bytes"),
411411
}
412-
if "namespace" in output:
413-
stats["namespace"] = output.get("namespace")
414-
if "namespace_count" in output:
415-
stats["namespace_count"] = output.get("namespace_count")
416-
if "fts_enabled" in output:
417-
stats["fts_enabled"] = output.get("fts_enabled")
418-
if "vector_backend" in output:
419-
stats["vector_backend"] = output.get("vector_backend")
420-
if "vector_indexes" in output:
421-
stats["vector_indexes"] = output.get("vector_indexes")
422-
if "plugin_id" in output:
423-
stats["plugin_id"] = output.get("plugin_id")
424-
if "ttl_entries" in output:
425-
stats["ttl_entries"] = output.get("ttl_entries")
426-
if "indexed_items" in output:
427-
stats["indexed_items"] = output.get("indexed_items")
428-
if "embedded_items" in output:
429-
stats["embedded_items"] = output.get("embedded_items")
430-
if "dirty_items" in output:
431-
stats["dirty_items"] = output.get("dirty_items")
412+
for key in (
413+
"namespace",
414+
"namespace_count",
415+
"fts_enabled",
416+
"vector_backend",
417+
"vector_indexes",
418+
"plugin_id",
419+
"ttl_entries",
420+
"indexed_items",
421+
"embedded_items",
422+
"dirty_items",
423+
):
424+
if key in output:
425+
stats[key] = output.get(key)
432426
return stats

src/astrbot_sdk/clients/permission.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Permission capability clients."""
1+
"""权限能力客户端。"""
22

33
from __future__ import annotations
44

@@ -10,6 +10,8 @@
1010

1111

1212
class PermissionCheckResult(BaseModel):
13+
"""权限检查结果。"""
14+
1315
model_config = ConfigDict(extra="forbid")
1416

1517
is_admin: bool
@@ -26,6 +28,8 @@ def from_payload(
2628

2729

2830
class PermissionClient:
31+
"""权限查询客户端。"""
32+
2933
def __init__(self, proxy: CapabilityProxy) -> None:
3034
self._proxy = proxy
3135

@@ -52,6 +56,8 @@ async def get_admins(self) -> list[str]:
5256

5357

5458
class PermissionManagerClient:
59+
"""权限管理客户端。"""
60+
5561
def __init__(
5662
self,
5763
proxy: CapabilityProxy,

0 commit comments

Comments
 (0)