Skip to content

Commit 33eb42c

Browse files
committed
feat: support remote websocket workers in sdk runtime
1 parent 0fe1155 commit 33eb42c

10 files changed

Lines changed: 866 additions & 199 deletions

File tree

astrbot-sdk/src/astrbot_sdk/cli.py

Lines changed: 149 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,39 @@ def _build_plugin(plugin_dir: Path, output_dir: Path | None) -> None:
11441144
click.echo(f"artifact: {archive_path}")
11451145

11461146

1147+
def _run_websocket_worker_entrypoint(
1148+
*,
1149+
worker_id: str | None,
1150+
plugin_dirs: tuple[Path, ...],
1151+
host: str,
1152+
port: int,
1153+
path: str,
1154+
tls_ca_file: Path,
1155+
tls_cert_file: Path,
1156+
tls_key_file: Path,
1157+
) -> None:
1158+
resolved_plugin_dirs = list(plugin_dirs) if plugin_dirs else [Path.cwd()]
1159+
_run_async_entrypoint(
1160+
run_websocket_server(
1161+
worker_id=worker_id,
1162+
plugin_dirs=resolved_plugin_dirs,
1163+
host=host,
1164+
port=port,
1165+
path=path,
1166+
tls_ca_file=tls_ca_file,
1167+
tls_cert_file=tls_cert_file,
1168+
tls_key_file=tls_key_file,
1169+
),
1170+
log_message=f"启动 WebSocket Worker,端口:{port}",
1171+
context={
1172+
"worker_id": worker_id,
1173+
"plugin_dirs": resolved_plugin_dirs,
1174+
"port": port,
1175+
"path": path,
1176+
},
1177+
)
1178+
1179+
11471180
@click.group()
11481181
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose output")
11491182
@click.pass_context
@@ -1161,20 +1194,37 @@ def cli(ctx, verbose: bool) -> None:
11611194
type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
11621195
help="Directory containing plugin folders",
11631196
)
1197+
@click.option(
1198+
"--workers-manifest",
1199+
default=None,
1200+
type=click.Path(file_okay=True, dir_okay=False, path_type=Path),
1201+
help="Supervisor manifest describing remote websocket workers",
1202+
)
11641203
@click.option(
11651204
"--protocol-stdout",
11661205
default=None,
11671206
type=str,
11681207
help="Redirect runtime protocol stdout to console, silent, or a file path",
11691208
)
1170-
def run(plugins_dir: Path, protocol_stdout: str | None) -> None:
1209+
def run(
1210+
plugins_dir: Path,
1211+
workers_manifest: Path | None,
1212+
protocol_stdout: str | None,
1213+
) -> None:
11711214
"""Start the plugin supervisor over stdio."""
11721215
transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout)
11731216
try:
11741217
_run_async_entrypoint(
1175-
run_supervisor(plugins_dir=plugins_dir, stdout=transport_stdout),
1218+
run_supervisor(
1219+
plugins_dir=plugins_dir,
1220+
stdout=transport_stdout,
1221+
workers_manifest=workers_manifest,
1222+
),
11761223
log_message=f"启动插件主管进程,插件目录:{plugins_dir}",
1177-
context={"plugins_dir": plugins_dir},
1224+
context={
1225+
"plugins_dir": plugins_dir,
1226+
"workers_manifest": workers_manifest,
1227+
},
11781228
)
11791229
finally:
11801230
if opened_stdout is not None:
@@ -1362,12 +1412,101 @@ def worker(
13621412
opened_stdout.close()
13631413

13641414

1415+
@cli.command("serve-worker")
1416+
@click.option("--worker-id", default=None, type=str, help="Stable websocket worker id")
1417+
@click.option(
1418+
"--plugin-dir",
1419+
"plugin_dirs",
1420+
multiple=True,
1421+
type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
1422+
help="Plugin directory to serve; repeat to host multiple plugins in one worker",
1423+
)
1424+
@click.option("--host", default="127.0.0.1", show_default=True)
1425+
@click.option("--port", default=8765, type=int, show_default=True)
1426+
@click.option("--path", default="/", show_default=True)
1427+
@click.option(
1428+
"--tls-ca-file",
1429+
required=True,
1430+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1431+
)
1432+
@click.option(
1433+
"--tls-cert-file",
1434+
required=True,
1435+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1436+
)
1437+
@click.option(
1438+
"--tls-key-file",
1439+
required=True,
1440+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1441+
)
1442+
def serve_worker(
1443+
worker_id: str | None,
1444+
plugin_dirs: tuple[Path, ...],
1445+
host: str,
1446+
port: int,
1447+
path: str,
1448+
tls_ca_file: Path,
1449+
tls_cert_file: Path,
1450+
tls_key_file: Path,
1451+
) -> None:
1452+
"""Serve one or more plugins as a standalone websocket worker."""
1453+
_run_websocket_worker_entrypoint(
1454+
worker_id=worker_id,
1455+
plugin_dirs=plugin_dirs,
1456+
host=host,
1457+
port=port,
1458+
path=path,
1459+
tls_ca_file=tls_ca_file,
1460+
tls_cert_file=tls_cert_file,
1461+
tls_key_file=tls_key_file,
1462+
)
1463+
1464+
13651465
@cli.command(hidden=True)
1366-
@click.option("--port", default=8765, type=int, help="WebSocket server port")
1367-
def websocket(port: int) -> None:
1368-
"""WebSocket runtime entrypoint kept for standalone bridge scenarios."""
1369-
_run_async_entrypoint(
1370-
run_websocket_server(port=port),
1371-
log_message=f"启动 WebSocket 服务器,端口:{port}",
1372-
context={"port": port},
1466+
@click.option("--worker-id", default=None, type=str)
1467+
@click.option(
1468+
"--plugin-dir",
1469+
"plugin_dirs",
1470+
multiple=True,
1471+
type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
1472+
)
1473+
@click.option("--host", default="127.0.0.1", show_default=True)
1474+
@click.option("--port", default=8765, type=int, show_default=True)
1475+
@click.option("--path", default="/", show_default=True)
1476+
@click.option(
1477+
"--tls-ca-file",
1478+
required=True,
1479+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1480+
)
1481+
@click.option(
1482+
"--tls-cert-file",
1483+
required=True,
1484+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1485+
)
1486+
@click.option(
1487+
"--tls-key-file",
1488+
required=True,
1489+
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=Path),
1490+
)
1491+
def websocket(
1492+
worker_id: str | None,
1493+
plugin_dirs: tuple[Path, ...],
1494+
host: str,
1495+
port: int,
1496+
path: str,
1497+
tls_ca_file: Path,
1498+
tls_cert_file: Path,
1499+
tls_key_file: Path,
1500+
) -> None:
1501+
"""Deprecated websocket runtime wrapper for standalone worker scenarios."""
1502+
logger.warning("'astr websocket' is deprecated; use 'astr serve-worker' instead")
1503+
_run_websocket_worker_entrypoint(
1504+
worker_id=worker_id,
1505+
plugin_dirs=plugin_dirs,
1506+
host=host,
1507+
port=port,
1508+
path=path,
1509+
tls_ca_file=tls_ca_file,
1510+
tls_cert_file=tls_cert_file,
1511+
tls_key_file=tls_key_file,
13731512
)

astrbot-sdk/src/astrbot_sdk/llm/agents.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def from_payload(cls, payload: dict[str, Any]) -> AgentSpec:
2828

2929

3030
class BaseAgentRunner(ABC):
31-
"""P0.5 agent registration surface.
31+
""" agent registration surface.
3232
33-
P0.5 only supports agent registration metadata. Actual execution remains
33+
only supports agent registration metadata. Actual execution remains
3434
owned by the core tool loop and is not directly callable from SDK plugins.
3535
"""
3636

astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@
2828
_sdk_source_dir,
2929
_wait_for_shutdown,
3030
)
31-
from .transport import StdioTransport, WebSocketServerTransport
32-
from .worker import GroupWorkerRuntime, PluginWorkerRuntime
31+
from .transport import (
32+
StdioTransport,
33+
WebSocketServerTransport,
34+
build_websocket_server_ssl_context,
35+
)
36+
from .worker import GroupWorkerRuntime, PluginWorkerRuntime, _load_plugin_specs
3337

3438
__all__ = [
3539
"GroupWorkerRuntime",
@@ -52,6 +56,7 @@ async def run_supervisor(
5256
stdin: IO[str] | None = None,
5357
stdout: IO[str] | None = None,
5458
env_manager: PluginEnvironmentManager | None = None,
59+
workers_manifest: Path | None = None,
5560
) -> None:
5661
transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport(
5762
stdin,
@@ -62,6 +67,7 @@ async def run_supervisor(
6267
transport=transport,
6368
plugins_dir=plugins_dir,
6469
env_manager=env_manager,
70+
workers_manifest=workers_manifest,
6571
)
6672

6773
try:
@@ -115,15 +121,47 @@ async def run_plugin_worker(
115121

116122
async def run_websocket_server(
117123
*,
124+
worker_id: str | None = None,
118125
host: str = "127.0.0.1",
119126
port: int = 8765,
120127
path: str = "/",
121-
plugin_dir: Path | None = None,
128+
plugin_dirs: list[Path] | None = None,
129+
tls_ca_file: Path | None = None,
130+
tls_cert_file: Path | None = None,
131+
tls_key_file: Path | None = None,
122132
) -> None:
123-
runtime = PluginWorkerRuntime(
124-
plugin_dir=plugin_dir or Path.cwd(),
125-
transport=WebSocketServerTransport(host=host, port=port, path=path),
133+
resolved_plugin_dirs = [path.resolve() for path in (plugin_dirs or [Path.cwd()])]
134+
if tls_ca_file is None or tls_cert_file is None or tls_key_file is None:
135+
raise ValueError(
136+
"tls_ca_file, tls_cert_file, and tls_key_file are required for websocket workers"
137+
)
138+
transport = WebSocketServerTransport(
139+
host=host,
140+
port=port,
141+
path=path,
142+
ssl_context=build_websocket_server_ssl_context(
143+
ca_file=tls_ca_file,
144+
cert_file=tls_cert_file,
145+
key_file=tls_key_file,
146+
),
126147
)
148+
resolved_worker_id = worker_id
149+
if resolved_worker_id is None and len(resolved_plugin_dirs) == 1:
150+
resolved_worker_id = _load_plugin_specs([resolved_plugin_dirs[0]])[0].name
151+
if len(resolved_plugin_dirs) == 1:
152+
runtime = PluginWorkerRuntime(
153+
plugin_dir=resolved_plugin_dirs[0],
154+
worker_id=resolved_worker_id,
155+
transport=transport,
156+
)
157+
else:
158+
if resolved_worker_id is None:
159+
raise ValueError("worker_id is required when serving multiple plugins")
160+
runtime = GroupWorkerRuntime(
161+
plugin_dirs=resolved_plugin_dirs,
162+
worker_id=resolved_worker_id,
163+
transport=transport,
164+
)
127165
try:
128166
await runtime.start()
129167
stop_event = asyncio.Event()

astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,26 @@ def upsert_plugin(
346346
normalized_metadata.setdefault("support_platforms", [])
347347
normalized_metadata.setdefault("astrbot_version", None)
348348
local_mcp_servers = normalized_metadata.pop("local_mcp_servers", {})
349-
self._plugins[name] = _RegisteredPlugin(
350-
metadata=normalized_metadata,
351-
config=dict(config or {}),
352-
handlers=[],
353-
local_mcp_servers={
349+
normalized_servers = (
350+
{
354351
str(server_name): dict(server_payload)
355352
for server_name, server_payload in local_mcp_servers.items()
356353
if str(server_name).strip() and isinstance(server_payload, dict)
357354
}
358355
if isinstance(local_mcp_servers, dict)
359-
else {},
356+
else {}
357+
)
358+
existing = self._plugins.get(name)
359+
if existing is not None:
360+
existing.metadata = normalized_metadata
361+
existing.config = dict(config or {})
362+
existing.local_mcp_servers = normalized_servers
363+
return
364+
self._plugins[name] = _RegisteredPlugin(
365+
metadata=normalized_metadata,
366+
config=dict(config or {}),
367+
handlers=[],
368+
local_mcp_servers=normalized_servers,
360369
)
361370

362371
def set_plugin_handlers(

astrbot-sdk/src/astrbot_sdk/runtime/peer.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def _select_negotiated_protocol_version(
133133
remote_metadata: dict[str, Any],
134134
local_supported_versions: Sequence[str],
135135
) -> str | None:
136+
"""从双方支持的版本中选出最佳兼容版本。
137+
138+
协商策略:优先精确匹配,否则在同主版本号范围内选双方都支持的最高版本。
139+
排除比请求版本更高的候选,因为远端能提供高于我们请求的版本说明我们本地
140+
尚未实现该版本协议,无法正确处理对应的协议消息。
141+
"""
136142
if requested_version in local_supported_versions:
137143
return requested_version
138144
requested_key = _parse_protocol_version(requested_version)
@@ -241,7 +247,15 @@ async def start(self) -> None:
241247
self._transport_watch_task = asyncio.create_task(self._watch_transport_closed())
242248

243249
async def stop(self) -> None:
244-
"""关闭 `Peer` 并清理所有挂起中的请求、流和入站任务。"""
250+
"""关闭 `Peer` 并清理所有挂起中的请求、流和入站任务。
251+
252+
重入安全性:transport.stop() 关闭底层连接时会触发原始消息处理器的
253+
异常路径,该路径调用 _fail_connection() -> _schedule_stop() -> stop(),
254+
形成间接递归。_stopping 标志和 _stop_task 引用共同防止重复清理资源。
255+
使用 asyncio.shield 等待是因为:如果当前任务在等待另一个 stop() 完成
256+
期间被取消,shield 保护内部 stop_task 不被连带取消,避免 Peer 停留在
257+
半关闭状态。
258+
"""
245259
if self._closed.is_set():
246260
return
247261
current_task = asyncio.current_task()
@@ -275,7 +289,9 @@ async def stop(self) -> None:
275289
await self.transport.stop()
276290
self._closed.set()
277291
finally:
278-
# 只在当前 task 就是 stop_task 时才清除引用,避免误清其他 task 的记录
292+
# 只在当前 task 就是 stop_task 时才清除引用,避免误清其他 task 的记录。
293+
# 场景:A 任务正在 stop() 中,B 任务也进入了 stop() 并等待 A 完成,
294+
# 如果 B 在 finally 中清除了 _stop_task,A 还未执行完就会失去引用。
279295
if self._stop_task is current_task:
280296
self._stop_task = None
281297

@@ -564,8 +580,12 @@ def _on_invoke_done(
564580
exc = _task.exception()
565581
if exc is None:
566582
return
567-
# 后台 invoke 理论上应把错误编码成协议消息;若异常仍逃逸,通常说明
568-
# 回复发送失败或连接状态异常,必须立刻标记连接失效,避免对端永久等待。
583+
# 为什么整个连接都要失败?正常情况下 invoke handler 会把错误编码成
584+
# ResultMessage 发回给对端。如果异常仍然逃逸,说明要么回复发送失败
585+
# (transport 已断),要么 handler 实现有 bug。无论哪种情况,连接的
586+
# 消息交换契约已不可靠,继续使用可能导致对端无限等待或数据丢失。
587+
# 采用"单点故障 → 全连接失败"策略而非隔离单个 handler,是因为协议层
588+
# 无法保证后续消息的正确性。
569589
logger.error(
570590
"Peer inbound invoke task crashed unexpectedly: "
571591
"request_id={} error={!r}",

0 commit comments

Comments
 (0)