Skip to content

Commit 7fbe91f

Browse files
committed
feat: add dynamic registration probe plugin with capabilities for skill and MCP management
- Implemented DynamicRegistrationProbe class with methods for registering, listing, and unregistering skills. - Added capabilities for managing global MCP servers including registration, listing, enabling, disabling, and unregistration. - Created plugin.yaml for dynamic_registration_probe with necessary metadata. - Added runtime probe skill documentation. - Developed unit tests for context API round trip, including file, platform, provider, and session management. - Implemented tests for dynamic registration and lifecycle of skills and MCP servers. - Ensured proper handling of global MCP risk acknowledgment in plugin capabilities.
1 parent d397953 commit 7fbe91f

10 files changed

Lines changed: 1263 additions & 0 deletions
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from pathlib import Path
2+
3+
from astrbot_sdk import Context, Star, acknowledge_global_mcp_risk
4+
from astrbot_sdk.decorators import provide_capability
5+
6+
7+
@acknowledge_global_mcp_risk
8+
class DynamicRegistrationProbe(Star):
9+
@staticmethod
10+
def _skill_dir() -> Path:
11+
return Path(__file__).resolve().parent / "skills" / "runtime_probe"
12+
13+
@staticmethod
14+
def _skill_payload(record) -> dict:
15+
return {
16+
"name": record.name,
17+
"description": record.description,
18+
"path": record.path,
19+
"skill_dir": record.skill_dir,
20+
}
21+
22+
@staticmethod
23+
def _mcp_payload(record) -> dict | None:
24+
if record is None:
25+
return None
26+
return {
27+
"name": record.name,
28+
"scope": record.scope.value,
29+
"active": record.active,
30+
"running": record.running,
31+
"config": dict(record.config),
32+
"tools": list(record.tools),
33+
"errlogs": list(record.errlogs),
34+
"last_error": record.last_error,
35+
}
36+
37+
@provide_capability(
38+
"dynamic_probe.skill.register",
39+
description="Register the probe skill through ctx.skills",
40+
)
41+
async def register_skill_capability(self, payload: dict, ctx: Context) -> dict:
42+
description = str(payload.get("description", "Runtime probe skill"))
43+
record = await ctx.skills.register(
44+
name=str(payload.get("name", "dynamic_probe.runtime_probe")),
45+
path=str(self._skill_dir()),
46+
description=description,
47+
)
48+
return self._skill_payload(record)
49+
50+
@provide_capability(
51+
"dynamic_probe.skill.list",
52+
description="List registered probe skills through ctx.skills",
53+
)
54+
async def list_skill_capability(self, payload: dict, ctx: Context) -> dict:
55+
del payload
56+
items = await ctx.skills.list()
57+
return {"skills": [self._skill_payload(item) for item in items]}
58+
59+
@provide_capability(
60+
"dynamic_probe.skill.unregister",
61+
description="Unregister the probe skill through ctx.skills",
62+
)
63+
async def unregister_skill_capability(self, payload: dict, ctx: Context) -> dict:
64+
removed = await ctx.skills.unregister(
65+
str(payload.get("name", "dynamic_probe.runtime_probe"))
66+
)
67+
return {"removed": bool(removed)}
68+
69+
@provide_capability(
70+
"dynamic_probe.mcp.global.register",
71+
description="Register a global MCP server through ctx.mcp",
72+
)
73+
async def register_global_mcp_capability(self, payload: dict, ctx: Context) -> dict:
74+
record = await ctx.mcp.register_global_server(
75+
str(payload.get("name", "probe-global")),
76+
dict(payload.get("config", {"mock_tools": ["inspect"]})),
77+
timeout=float(payload.get("timeout", 0.2)),
78+
)
79+
return {"server": self._mcp_payload(record)}
80+
81+
@provide_capability(
82+
"dynamic_probe.mcp.global.get",
83+
description="Get a global MCP server through ctx.mcp",
84+
)
85+
async def get_global_mcp_capability(self, payload: dict, ctx: Context) -> dict:
86+
record = await ctx.mcp.get_global_server(
87+
str(payload.get("name", "probe-global"))
88+
)
89+
return {"server": self._mcp_payload(record)}
90+
91+
@provide_capability(
92+
"dynamic_probe.mcp.global.list",
93+
description="List global MCP servers through ctx.mcp",
94+
)
95+
async def list_global_mcp_capability(self, payload: dict, ctx: Context) -> dict:
96+
del payload
97+
records = await ctx.mcp.list_global_servers()
98+
return {"servers": [self._mcp_payload(record) for record in records]}
99+
100+
@provide_capability(
101+
"dynamic_probe.mcp.global.disable",
102+
description="Disable a global MCP server through ctx.mcp",
103+
)
104+
async def disable_global_mcp_capability(self, payload: dict, ctx: Context) -> dict:
105+
record = await ctx.mcp.disable_global_server(
106+
str(payload.get("name", "probe-global"))
107+
)
108+
return {"server": self._mcp_payload(record)}
109+
110+
@provide_capability(
111+
"dynamic_probe.mcp.global.enable",
112+
description="Enable a global MCP server through ctx.mcp",
113+
)
114+
async def enable_global_mcp_capability(self, payload: dict, ctx: Context) -> dict:
115+
record = await ctx.mcp.enable_global_server(
116+
str(payload.get("name", "probe-global")),
117+
timeout=float(payload.get("timeout", 0.2)),
118+
)
119+
return {"server": self._mcp_payload(record)}
120+
121+
@provide_capability(
122+
"dynamic_probe.mcp.global.unregister",
123+
description="Unregister a global MCP server through ctx.mcp",
124+
)
125+
async def unregister_global_mcp_capability(
126+
self,
127+
payload: dict,
128+
ctx: Context,
129+
) -> dict:
130+
record = await ctx.mcp.unregister_global_server(
131+
str(payload.get("name", "probe-global"))
132+
)
133+
return {"server": self._mcp_payload(record)}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_schema_version: 2
2+
name: dynamic_registration_probe
3+
author: tests
4+
version: 1.0.0
5+
desc: Dynamic registration probe plugin
6+
7+
runtime:
8+
python: "3.12"
9+
10+
components:
11+
- class: main:DynamicRegistrationProbe
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Runtime Probe Skill
2+
3+
This skill exists to validate runtime registration through the SDK context.

tests/test_sdk/unit/_context_api_roundtrip.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import json
45
import sys
56
import types
67
import uuid
@@ -245,6 +246,9 @@ def upsert_plugin(
245246
"version": str(metadata.get("version", "1.0.0")),
246247
"enabled": bool(metadata.get("enabled", True)),
247248
"reserved": bool(metadata.get("reserved", False)),
249+
"acknowledge_global_mcp_risk": bool(
250+
metadata.get("acknowledge_global_mcp_risk", False)
251+
),
248252
"support_platforms": list(metadata.get("support_platforms", [])),
249253
}
250254
self._plugin_configs.setdefault(plugin_id, dict(config or {}))
@@ -532,6 +536,70 @@ def unregister_skill(self, *, plugin_id: str, name: str) -> bool:
532536
def list_registered_skills(self, plugin_id: str) -> list[dict[str, str]]:
533537
return [dict(item) for item in self._skill_records.get(plugin_id, [])]
534538

539+
def acknowledges_global_mcp_risk(self, plugin_id: str) -> bool:
540+
metadata = self._plugin_metadata.get(str(plugin_id), {})
541+
return bool(metadata.get("acknowledge_global_mcp_risk", False))
542+
543+
def remove_plugin(self, plugin_id: str) -> None:
544+
normalized_plugin_id = str(plugin_id)
545+
self._plugin_metadata.pop(normalized_plugin_id, None)
546+
self._plugin_configs.pop(normalized_plugin_id, None)
547+
self._skill_records.pop(normalized_plugin_id, None)
548+
self._handlers_by_plugin.pop(normalized_plugin_id, None)
549+
self.http_routes.pop(normalized_plugin_id, None)
550+
self._latest_request_context_by_plugin.pop(normalized_plugin_id, None)
551+
request_ids = [
552+
request_id
553+
for request_id in self._request_contexts
554+
if self.resolve_request_plugin_id(request_id) == normalized_plugin_id
555+
]
556+
for request_id in request_ids:
557+
request_context = self._request_contexts.pop(request_id, None)
558+
self._request_overlays.pop(request_id, None)
559+
if request_context is None:
560+
continue
561+
self._request_contexts_by_token.pop(request_context.dispatch_token, None)
562+
563+
564+
class FakeFunctionToolManager:
565+
def __init__(self) -> None:
566+
self.func_list: list[object] = []
567+
self._config: dict[str, Any] = {"mcpServers": {}}
568+
self.mcp_server_runtime_view: dict[str, Any] = {}
569+
570+
def load_mcp_config(self) -> dict[str, Any]:
571+
return json.loads(json.dumps(self._config))
572+
573+
def save_mcp_config(self, config: dict[str, Any]) -> bool:
574+
self._config = json.loads(json.dumps(config))
575+
return True
576+
577+
async def enable_mcp_server(
578+
self,
579+
name: str,
580+
config: dict[str, Any],
581+
*_args,
582+
**_kwargs,
583+
) -> None:
584+
tools = [
585+
SimpleNamespace(name=str(tool_name))
586+
for tool_name in config.get("mock_tools", [f"{name}_tool"])
587+
if str(tool_name).strip()
588+
]
589+
self.mcp_server_runtime_view[str(name)] = SimpleNamespace(
590+
client=SimpleNamespace(tools=tools, server_errlogs=[]),
591+
)
592+
593+
async def disable_mcp_server(
594+
self,
595+
name: str | None = None,
596+
**_kwargs,
597+
) -> None:
598+
if name is None:
599+
self.mcp_server_runtime_view.clear()
600+
return
601+
self.mcp_server_runtime_view.pop(str(name), None)
602+
535603

536604
@dataclass(slots=True)
537605
class FakeProviderMeta:
@@ -892,12 +960,14 @@ def __init__(
892960
self,
893961
*,
894962
plugin_bridge: FakePluginBridge,
963+
func_tool_manager: FakeFunctionToolManager,
895964
provider_manager: FakeProviderManager,
896965
platforms: list[FakePlatform],
897966
config: FakeConfig,
898967
message_history_manager: FakeMessageHistoryManager,
899968
) -> None:
900969
self._plugin_bridge = plugin_bridge
970+
self._func_tool_manager = func_tool_manager
901971
self.provider_manager = provider_manager
902972
self.platform_manager = SimpleNamespace(get_insts=lambda: list(platforms))
903973
self._config = config
@@ -922,6 +992,9 @@ async def send_message(self, session: str, message_chain: MessageChain) -> None:
922992
def get_config(self) -> FakeConfig:
923993
return self._config
924994

995+
def get_llm_tool_manager(self) -> FakeFunctionToolManager:
996+
return self._func_tool_manager
997+
925998
def get_all_stars(self) -> list[Any]:
926999
return [
9271000
SimpleNamespace(
@@ -1028,6 +1101,7 @@ class RoundTripRuntime:
10281101
bridge: CoreCapabilityBridge
10291102
peer: BridgeBackedPeer
10301103
plugin_bridge: FakePluginBridge
1104+
func_tool_manager: FakeFunctionToolManager
10311105
runtime_sp: FakeRuntimeSP
10321106
star_context: FakeStarContext
10331107
provider_manager: FakeProviderManager
@@ -1114,11 +1188,13 @@ def build_roundtrip_runtime(
11141188
file_token_service = FakeFileTokenService()
11151189
config = FakeConfig()
11161190
plugin_bridge = FakePluginBridge()
1191+
func_tool_manager = FakeFunctionToolManager()
11171192
chat_provider = FakeChatProvider("chat-provider-a", model="gpt-roundtrip")
11181193
provider_manager = FakeProviderManager(chat_provider)
11191194
message_history_manager = FakeMessageHistoryManager()
11201195
star_context = FakeStarContext(
11211196
plugin_bridge=plugin_bridge,
1197+
func_tool_manager=func_tool_manager,
11221198
provider_manager=provider_manager,
11231199
platforms=[FakePlatform()],
11241200
config=config,
@@ -1156,6 +1232,7 @@ def build_roundtrip_runtime(
11561232
bridge=bridge,
11571233
peer=peer,
11581234
plugin_bridge=plugin_bridge,
1235+
func_tool_manager=func_tool_manager,
11591236
runtime_sp=runtime_sp,
11601237
star_context=star_context,
11611238
provider_manager=provider_manager,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# ruff: noqa: E402
2+
"""Files 客户端 Core Bridge 集成测试。
3+
4+
测试覆盖 01_context_api.md 中 ctx.files 的所有方法:
5+
- register_file(): 注册文件并获取令牌
6+
- handle_file(): 通过令牌解析文件路径
7+
"""
8+
from __future__ import annotations
9+
10+
from pathlib import Path
11+
12+
import pytest
13+
14+
from tests.test_sdk.unit._context_api_roundtrip import build_roundtrip_runtime
15+
16+
17+
@pytest.mark.unit
18+
@pytest.mark.asyncio
19+
async def test_context_files_register_file_returns_token(tmp_path, monkeypatch):
20+
"""register_file 注册文件并返回 token。"""
21+
runtime = build_roundtrip_runtime(monkeypatch, tmp_path=tmp_path)
22+
ctx = runtime.make_context("plugin-a")
23+
24+
# 创建测试文件
25+
test_file = tmp_path / "test_image.jpg"
26+
test_file.write_text("fake image content")
27+
28+
token = await ctx.files.register_file(str(test_file))
29+
30+
assert token is not None
31+
assert token.startswith("file-token-")
32+
33+
34+
@pytest.mark.unit
35+
@pytest.mark.asyncio
36+
async def test_context_files_register_file_with_timeout(tmp_path, monkeypatch):
37+
"""register_file 支持 timeout 参数。"""
38+
runtime = build_roundtrip_runtime(monkeypatch, tmp_path=tmp_path)
39+
ctx = runtime.make_context("plugin-a")
40+
41+
test_file = tmp_path / "timeout_test.png"
42+
test_file.write_text("content")
43+
44+
token = await ctx.files.register_file(str(test_file), timeout=3600)
45+
46+
assert token is not None
47+
48+
49+
@pytest.mark.unit
50+
@pytest.mark.asyncio
51+
async def test_context_files_handle_file_resolves_token(tmp_path, monkeypatch):
52+
"""handle_file 通过 token 解析回原始文件路径。"""
53+
runtime = build_roundtrip_runtime(monkeypatch, tmp_path=tmp_path)
54+
ctx = runtime.make_context("plugin-a")
55+
56+
test_file = tmp_path / "resolve_test.txt"
57+
test_file.write_text("test content")
58+
59+
# 先注册
60+
token = await ctx.files.register_file(str(test_file))
61+
62+
# 再解析
63+
resolved_path = await ctx.files.handle_file(token)
64+
65+
assert Path(resolved_path) == test_file
66+
67+
68+
@pytest.mark.unit
69+
@pytest.mark.asyncio
70+
async def test_context_files_round_trip_workflow(tmp_path, monkeypatch):
71+
"""完整的文件注册和解析工作流。"""
72+
runtime = build_roundtrip_runtime(monkeypatch, tmp_path=tmp_path)
73+
ctx = runtime.make_context("plugin-a")
74+
75+
# 创建多个测试文件
76+
files = []
77+
for i in range(3):
78+
file_path = tmp_path / f"file_{i}.dat"
79+
file_path.write_text(f"content {i}")
80+
files.append(file_path)
81+
82+
# 注册所有文件
83+
tokens = []
84+
for file_path in files:
85+
token = await ctx.files.register_file(str(file_path))
86+
tokens.append(token)
87+
88+
# 验证每个 token 都能解析回正确的路径
89+
for token, expected_path in zip(tokens, files):
90+
resolved = await ctx.files.handle_file(token)
91+
assert Path(resolved) == expected_path

0 commit comments

Comments
 (0)