Skip to content

Commit dcd4f69

Browse files
committed
feat: Enhance HTTP API registration with plugin namespace validation
- Added validation to ensure HTTP routes and handler capabilities belong to the current plugin namespace in the SDK. - Updated tests to reflect changes in API registration routes, ensuring they include the plugin ID as a prefix. - Introduced new error handling for invalid route and capability registrations. - Refactored existing tests to accommodate the new route structure and added tests for new validation logic. - Improved the structure of SDK record creation for better readability and maintainability.
1 parent 5f2ad81 commit dcd4f69

17 files changed

Lines changed: 475 additions & 91 deletions

astrbot-sdk/src/astrbot_sdk/_internal/testing_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def __init__(self, router: MockCapabilityRouter) -> None:
389389
role="core",
390390
version="local",
391391
)
392-
self.remote_capabilities = list(router.descriptors())
392+
self.remote_capabilities = list(router.all_descriptors())
393393
self.remote_capability_map = {
394394
item.name: item for item in self.remote_capabilities
395395
}

astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ async def _run_capability(
361361
)
362362
if inspect.isawaitable(result):
363363
result = await result
364+
if inspect.isasyncgen(result):
365+
return StreamExecution(
366+
iterator=self._iterate_generator(result),
367+
finalize=lambda chunks: {"items": chunks},
368+
)
364369
if isinstance(result, StreamExecution):
365370
return result
366371
raise AstrBotError.protocol_error(

astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,11 @@ def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
647647
queue.put_nowait(event)
648648

649649
def descriptors(self) -> list[CapabilityDescriptor]:
650+
return [
651+
entry.descriptor for entry in self._registrations.values() if entry.exposed
652+
]
653+
654+
def all_descriptors(self) -> list[CapabilityDescriptor]:
650655
return [entry.descriptor for entry in self._registrations.values()]
651656

652657
def contains(self, name: str) -> bool:

astrbot-sdk/src/astrbot_sdk/runtime/peer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ async def invoke_stream(
454454
)
455455

456456
async def iterator() -> AsyncIterator[EventMessage]:
457+
terminal_received = False
457458
try:
458459
while True:
459460
item = await queue.get()
@@ -467,15 +468,30 @@ async def iterator() -> AsyncIterator[EventMessage]:
467468
yield item
468469
continue
469470
if item.phase == "completed":
471+
terminal_received = True
470472
if include_completed:
471473
yield item
472474
break
473475
if item.phase == "failed":
476+
terminal_received = True
474477
raise AstrBotError.from_payload(
475478
item.error.model_dump() if item.error else {}
476479
)
477480
finally:
478481
self._pending_streams.pop(request_id, None)
482+
if not terminal_received:
483+
try:
484+
await self.cancel(
485+
request_id,
486+
reason="consumer_closed_stream_early",
487+
)
488+
except Exception as exc:
489+
logger.debug(
490+
"Failed to cancel stream after consumer closed early: "
491+
"request_id=%s error=%s",
492+
request_id,
493+
exc,
494+
)
479495

480496
return iterator()
481497

astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ async def cancel(self, request_id: str) -> None:
390390
async def _handle_initialize(self, _message) -> InitializeOutput:
391391
return InitializeOutput(
392392
peer=PeerInfo(name="astrbot-supervisor", role="core", version="v4"),
393-
capabilities=self.capability_router.descriptors(),
393+
capabilities=self.capability_router.all_descriptors(),
394394
metadata={
395395
"group_id": self.group_id,
396396
"plugins": [plugin.name for plugin in self.plugins],

astrbot/core/sdk_bridge/dispatch_engine.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ async def dispatch_message_event(
365365
args={},
366366
)
367367
if isinstance(output, dict):
368+
handler_result = extract_sdk_handler_result(output)
368369
if "sdk_local_extras" in output:
369370
self.bridge._persist_sdk_local_extras_from_handler(
370371
overlay,
@@ -390,6 +391,15 @@ async def dispatch_message_event(
390391
event_result,
391392
result_payload,
392393
)
394+
if handler_result["stop"]:
395+
event.stop_event()
396+
if handler_result["call_llm"]:
397+
overlay.requested_llm = True
398+
overlay.should_call_llm = True
399+
if handler_result["sent_message"]:
400+
event._has_send_oper = True
401+
if handler_result["sent_message"] or handler_result["stop"]:
402+
overlay.should_call_llm = False
393403
except Exception as exc:
394404
logger.warning(
395405
"SDK event handler failed: plugin=%s handler=%s error=%s",
@@ -468,6 +478,13 @@ async def dispatch_waiter_event(
468478
handler_result = extract_sdk_handler_result(
469479
output if isinstance(output, dict) else {}
470480
)
481+
if isinstance(output, dict) and "sdk_local_extras" in output:
482+
self.bridge._persist_sdk_local_extras_from_handler(
483+
overlay,
484+
output.get("sdk_local_extras"),
485+
plugin_id=record.plugin_id,
486+
handler_id="__sdk_session_waiter__",
487+
)
471488
result.executed_handlers.append(
472489
{"plugin_id": record.plugin_id, "handler_id": "__sdk_session_waiter__"}
473490
)

astrbot/core/sdk_bridge/lifecycle_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def reload_all(self, *, reset_restart_budget: bool = False) -> None:
7474
reset_restart_budget=reset_restart_budget,
7575
)
7676
self.bridge.refresh_command_compatibility_issues()
77-
await self.bridge._refresh_native_platform_commands({"telegram"})
77+
await self.bridge._refresh_native_platform_commands()
7878

7979
async def reload_plugin(self, plugin_id: str) -> None:
8080
async with self._reload_lock:
@@ -90,7 +90,7 @@ async def reload_plugin(self, plugin_id: str) -> None:
9090
reset_restart_budget=True,
9191
)
9292
self.bridge.refresh_command_compatibility_issues()
93-
await self.bridge._refresh_native_platform_commands({"telegram"})
93+
await self.bridge._refresh_native_platform_commands()
9494
return
9595
raise ValueError(f"SDK plugin not found: {plugin_id}")
9696

@@ -105,7 +105,7 @@ async def turn_off_plugin(self, plugin_id: str) -> None:
105105
record.failure_reason = ""
106106
self.bridge._set_disabled_override(plugin_id, disabled=True)
107107
self.bridge.refresh_command_compatibility_issues()
108-
await self.bridge._refresh_native_platform_commands({"telegram"})
108+
await self.bridge._refresh_native_platform_commands()
109109

110110
async def turn_on_plugin(self, plugin_id: str) -> None:
111111
async with self._reload_lock:
@@ -122,7 +122,7 @@ async def turn_on_plugin(self, plugin_id: str) -> None:
122122
reset_restart_budget=True,
123123
)
124124
self.bridge.refresh_command_compatibility_issues()
125-
await self.bridge._refresh_native_platform_commands({"telegram"})
125+
await self.bridge._refresh_native_platform_commands()
126126
return
127127
raise ValueError(f"SDK plugin not found: {plugin_id}")
128128

astrbot/core/sdk_bridge/registry_manager.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Any
88

9+
from astrbot_sdk._internal.plugin_ids import (
10+
capability_belongs_to_plugin,
11+
http_route_belongs_to_plugin,
12+
plugin_capability_prefix,
13+
plugin_http_route_root,
14+
)
915
from astrbot_sdk.errors import AstrBotError
1016

1117
from astrbot.core import logger
@@ -268,6 +274,8 @@ def register_http_api(
268274
raise AstrBotError.invalid_input(
269275
"http.register_api requires handler_capability"
270276
)
277+
self._validate_http_route_namespace(normalized_route, plugin_id)
278+
self._validate_http_handler_namespace(handler_capability, plugin_id)
271279
self.bridge._ensure_http_route_available(
272280
plugin_id=plugin_id,
273281
route=normalized_route,
@@ -297,6 +305,32 @@ def register_http_api(
297305
handler_capability,
298306
)
299307

308+
@staticmethod
309+
def _validate_http_route_namespace(route: str, plugin_id: str) -> None:
310+
if http_route_belongs_to_plugin(route, plugin_id):
311+
return
312+
route_root = plugin_http_route_root(plugin_id)
313+
raise AstrBotError.invalid_input(
314+
"http.register_api requires route to use the current plugin namespace: "
315+
f"route={route!r}, plugin_id={plugin_id!r}, expected={route_root!r} "
316+
f"or {route_root + '/...'}"
317+
)
318+
319+
@staticmethod
320+
def _validate_http_handler_namespace(
321+
handler_capability: str,
322+
plugin_id: str,
323+
) -> None:
324+
if capability_belongs_to_plugin(handler_capability, plugin_id):
325+
return
326+
expected_prefix = plugin_capability_prefix(plugin_id)
327+
raise AstrBotError.invalid_input(
328+
"http.register_api requires handler_capability to belong to the current "
329+
"plugin: "
330+
f"capability={handler_capability!r}, plugin_id={plugin_id!r}, "
331+
f"expected_prefix={expected_prefix!r}"
332+
)
333+
300334
def unregister_http_api(
301335
self,
302336
*,

astrbot/core/sdk_bridge/request_runtime.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,10 @@ def set_result_payload_on_overlay(
289289
else None
290290
)
291291
overlay.result_is_set = True
292-
self.set_overlay_stop_state(overlay, stopped=False)
292+
self.set_overlay_stop_state(
293+
overlay,
294+
stopped=bool(normalized_payload.get("stop", False)),
295+
)
293296

294297
def sync_overlay_payload_from_result_object(
295298
self,
@@ -494,10 +497,13 @@ def legacy_result_to_sdk_payload(
494497
if isinstance(result.chain, MessageChain)
495498
else result.chain
496499
)
497-
return {
500+
payload = {
498501
"type": "chain" if chain else "empty",
499502
"chain": SdkRequestRuntime.components_to_sdk_payload(chain),
500503
}
504+
if result.is_stopped():
505+
payload["stop"] = True
506+
return payload
501507

502508
@staticmethod
503509
def components_to_sdk_payload(
@@ -738,6 +744,10 @@ def apply_sdk_result_payload(
738744
result.chain = updated.chain
739745
result.use_t2i_ = updated.use_t2i_
740746
result.type = updated.type
747+
if bool(payload.get("stop", False)):
748+
result.stop_event()
749+
else:
750+
result.continue_event()
741751
return result
742752

743753
def get_effective_result(

tests/test_sdk/unit/_context_api_roundtrip.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ruff: noqa: E402
12
from __future__ import annotations
23

34
import asyncio
@@ -45,7 +46,14 @@ def install(name: str, attrs: dict[str, object]) -> None:
4546
_install_optional_dependency_stubs()
4647

4748
from astrbot_sdk._internal.invocation_context import current_caller_plugin_id
49+
from astrbot_sdk._internal.plugin_ids import (
50+
capability_belongs_to_plugin,
51+
http_route_belongs_to_plugin,
52+
plugin_capability_prefix,
53+
plugin_http_route_root,
54+
)
4855
from astrbot_sdk.context import Context
56+
from astrbot_sdk.errors import AstrBotError
4957
from astrbot_sdk.message.components import component_to_payload_sync
5058
from astrbot_sdk.runtime._streaming import StreamExecution
5159

@@ -432,6 +440,21 @@ def register_http_api(
432440
) -> None:
433441
normalized_route = self._normalize_route(route)
434442
normalized_methods = self._normalize_methods(methods)
443+
if not http_route_belongs_to_plugin(normalized_route, plugin_id):
444+
route_root = plugin_http_route_root(plugin_id)
445+
raise AstrBotError.invalid_input(
446+
"http.register_api requires route to use the current plugin "
447+
f"namespace: route={normalized_route!r}, plugin_id={plugin_id!r}, "
448+
f"expected={route_root!r} or {route_root + '/...'}"
449+
)
450+
if not capability_belongs_to_plugin(str(handler_capability), plugin_id):
451+
expected_prefix = plugin_capability_prefix(plugin_id)
452+
raise AstrBotError.invalid_input(
453+
"http.register_api requires handler_capability to belong to the "
454+
"current plugin: "
455+
f"capability={handler_capability!r}, plugin_id={plugin_id!r}, "
456+
f"expected_prefix={expected_prefix!r}"
457+
)
435458
existing = [
436459
item
437460
for item in self.http_routes.get(plugin_id, [])
@@ -1055,7 +1078,7 @@ def __init__(self, bridge: CoreCapabilityBridge) -> None:
10551078
self._request_counter = 0
10561079
self.remote_peer = object()
10571080
self.remote_capability_map = {
1058-
descriptor.name: descriptor for descriptor in bridge.descriptors()
1081+
descriptor.name: descriptor for descriptor in bridge.all_descriptors()
10591082
}
10601083

10611084
def _next_request_id(self) -> str:

0 commit comments

Comments
 (0)