Skip to content

Commit f6de722

Browse files
committed
Fix session plugin disable for hooks and tools
1 parent 70872cd commit f6de722

7 files changed

Lines changed: 279 additions & 66 deletions

File tree

astrbot/core/astr_agent_tool_exec.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,35 @@
3737
from astrbot.core.platform.message_session import MessageSession
3838
from astrbot.core.provider.entites import ProviderRequest
3939
from astrbot.core.provider.register import llm_tools
40+
from astrbot.core.star.session_plugin_manager import SessionPluginManager
41+
from astrbot.core.star.star import star_map
4042
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
4143
from astrbot.core.utils.history_saver import persist_agent_history
4244
from astrbot.core.utils.image_ref_utils import is_supported_image_ref
4345
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
4446

4547

4648
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
49+
@classmethod
50+
def _tool_enabled_for_session(
51+
cls,
52+
tool: FunctionTool,
53+
session_config: dict | None,
54+
) -> bool:
55+
mp = tool.handler_module_path
56+
if not mp:
57+
return True
58+
59+
plugin = star_map.get(mp)
60+
if not plugin:
61+
return True
62+
63+
return SessionPluginManager.is_plugin_enabled_for_session_config(
64+
plugin.name,
65+
session_config,
66+
reserved=plugin.reserved,
67+
)
68+
4769
@classmethod
4870
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
4971
if image_urls_raw is None:
@@ -193,14 +215,17 @@ def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
193215
return {}
194216

195217
@classmethod
196-
def _build_handoff_toolset(
218+
async def _build_handoff_toolset(
197219
cls,
198220
run_context: ContextWrapper[AstrAgentContext],
199221
tools: list[str | FunctionTool] | None,
200222
) -> ToolSet | None:
201223
ctx = run_context.context.context
202224
event = run_context.context.event
203225
cfg = ctx.get_config(umo=event.unified_msg_origin)
226+
session_config = await SessionPluginManager.get_session_plugin_config(
227+
event.unified_msg_origin
228+
)
204229
provider_settings = cfg.get("provider_settings", {})
205230
runtime = str(provider_settings.get("computer_use_runtime", "local"))
206231
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
@@ -212,7 +237,10 @@ def _build_handoff_toolset(
212237
for registered_tool in llm_tools.func_list:
213238
if isinstance(registered_tool, HandoffTool):
214239
continue
215-
if registered_tool.active:
240+
if registered_tool.active and cls._tool_enabled_for_session(
241+
registered_tool,
242+
session_config,
243+
):
216244
toolset.add_tool(registered_tool)
217245
for runtime_tool in runtime_computer_tools.values():
218246
toolset.add_tool(runtime_tool)
@@ -225,14 +253,19 @@ def _build_handoff_toolset(
225253
for tool_name_or_obj in tools:
226254
if isinstance(tool_name_or_obj, str):
227255
registered_tool = llm_tools.get_func(tool_name_or_obj)
228-
if registered_tool and registered_tool.active:
256+
if (
257+
registered_tool
258+
and registered_tool.active
259+
and cls._tool_enabled_for_session(registered_tool, session_config)
260+
):
229261
toolset.add_tool(registered_tool)
230262
continue
231263
runtime_tool = runtime_computer_tools.get(tool_name_or_obj)
232264
if runtime_tool:
233265
toolset.add_tool(runtime_tool)
234266
elif isinstance(tool_name_or_obj, FunctionTool):
235-
toolset.add_tool(tool_name_or_obj)
267+
if cls._tool_enabled_for_session(tool_name_or_obj, session_config):
268+
toolset.add_tool(tool_name_or_obj)
236269
return None if toolset.empty() else toolset
237270

238271
@classmethod
@@ -264,7 +297,7 @@ async def _execute_handoff(
264297
tool_args["image_urls"] = image_urls
265298

266299
# Build handoff toolset from registered tools plus runtime computer tools.
267-
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
300+
toolset = await cls._build_handoff_toolset(run_context, tool.agent.tools)
268301

269302
ctx = run_context.context.context
270303
event = run_context.context.event

astrbot/core/astr_main_agent.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
6363
from astrbot.core.star.context import Context
6464
from astrbot.core.star.star_handler import star_map
65+
from astrbot.core.star.session_plugin_manager import SessionPluginManager
6566
from astrbot.core.tools.cron_tools import (
6667
CREATE_CRON_JOB_TOOL,
6768
DELETE_CRON_JOB_TOOL,
@@ -846,33 +847,49 @@ def _sanitize_context_by_modalities(
846847
req.contexts = sanitized_contexts
847848

848849

849-
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
850+
async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
850851
"""根据事件中的插件设置,过滤请求中的工具列表。
851852
852853
注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留,
853854
因为它们不属于任何插件,不应被插件过滤逻辑影响。
854855
"""
855-
if event.plugins_name is not None and req.func_tool:
856-
new_tool_set = ToolSet()
857-
for tool in req.func_tool.tools:
858-
if isinstance(tool, MCPTool):
859-
# 保留 MCP 工具
860-
new_tool_set.add_tool(tool)
861-
continue
862-
mp = tool.handler_module_path
863-
if not mp:
864-
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
865-
# 不应受到会话插件过滤影响。
866-
new_tool_set.add_tool(tool)
867-
continue
868-
plugin = star_map.get(mp)
869-
if not plugin:
870-
# 无法解析插件归属时,保守保留工具,避免误过滤。
871-
new_tool_set.add_tool(tool)
872-
continue
873-
if plugin.name in event.plugins_name or plugin.reserved:
874-
new_tool_set.add_tool(tool)
875-
req.func_tool = new_tool_set
856+
if not req.func_tool:
857+
return
858+
859+
session_config = await SessionPluginManager.get_session_plugin_config(
860+
event.unified_msg_origin
861+
)
862+
new_tool_set = ToolSet()
863+
for tool in req.func_tool.tools:
864+
if isinstance(tool, MCPTool):
865+
# 保留 MCP 工具
866+
new_tool_set.add_tool(tool)
867+
continue
868+
mp = tool.handler_module_path
869+
if not mp:
870+
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
871+
# 不应受到会话插件过滤影响。
872+
new_tool_set.add_tool(tool)
873+
continue
874+
plugin = star_map.get(mp)
875+
if not plugin:
876+
# 无法解析插件归属时,保守保留工具,避免误过滤。
877+
new_tool_set.add_tool(tool)
878+
continue
879+
if (
880+
event.plugins_name is not None
881+
and not plugin.reserved
882+
and plugin.name not in event.plugins_name
883+
):
884+
continue
885+
if not SessionPluginManager.is_plugin_enabled_for_session_config(
886+
plugin.name,
887+
session_config,
888+
reserved=plugin.reserved,
889+
):
890+
continue
891+
new_tool_set.add_tool(tool)
892+
req.func_tool = new_tool_set
876893

877894

878895
async def _handle_webchat(
@@ -1243,7 +1260,7 @@ async def build_main_agent(
12431260
req.session_id = event.unified_msg_origin
12441261

12451262
_modalities_fix(provider, req)
1246-
_plugin_tool_fix(event, req)
1263+
await _plugin_tool_fix(event, req)
12471264
_sanitize_context_by_modalities(config, provider, req)
12481265

12491266
if config.llm_safety_mode:

astrbot/core/pipeline/context_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from astrbot.core.platform.astr_message_event import AstrMessageEvent
88
from astrbot.core.star.star import star_map
99
from astrbot.core.star.star_handler import EventType, star_handlers_registry
10+
from astrbot.core.star.session_plugin_manager import SessionPluginManager
1011

1112

1213
async def call_handler(
@@ -89,19 +90,32 @@ async def call_event_hook(
8990
hook_type,
9091
plugins_name=event.plugins_name,
9192
)
93+
session_config = await SessionPluginManager.get_session_plugin_config(
94+
event.unified_msg_origin
95+
)
9296
for handler in handlers:
97+
plugin = star_map.get(handler.handler_module_path)
98+
if plugin and not SessionPluginManager.is_plugin_enabled_for_session_config(
99+
plugin.name,
100+
session_config,
101+
reserved=plugin.reserved,
102+
):
103+
logger.debug(
104+
f"插件 {plugin.name} 在会话 {event.unified_msg_origin} 中被禁用,跳过 hook {handler.handler_name}",
105+
)
106+
continue
93107
try:
94108
assert inspect.iscoroutinefunction(handler.handler)
95109
logger.debug(
96-
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
110+
f"hook({hook_type.name}) -> {plugin.name if plugin else handler.handler_module_path} - {handler.handler_name}",
97111
)
98112
await handler.handler(event, *args, **kwargs)
99113
except BaseException:
100114
logger.error(traceback.format_exc())
101115

102116
if event.is_stopped():
103117
logger.info(
104-
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
118+
f"{plugin.name if plugin else handler.handler_module_path} - {handler.handler_name} 终止了事件传播。",
105119
)
106120
return True
107121

astrbot/core/star/session_plugin_manager.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,65 @@ class SessionPluginManager:
88
"""管理会话级别的插件启停状态"""
99

1010
@staticmethod
11-
async def is_plugin_enabled_for_session(
12-
session_id: str,
13-
plugin_name: str,
14-
) -> bool:
15-
"""检查插件是否在指定会话中启用
16-
17-
Args:
18-
session_id: 会话ID (unified_msg_origin)
19-
plugin_name: 插件名称
20-
21-
Returns:
22-
bool: True表示启用,False表示禁用
23-
24-
"""
25-
# 获取会话插件配置
11+
async def get_session_plugin_config(session_id: str) -> dict:
12+
"""获取指定会话的插件配置。"""
2613
session_plugin_config = await sp.get_async(
2714
scope="umo",
2815
scope_id=session_id,
2916
key="session_plugin_config",
3017
default={},
3118
)
32-
session_config = session_plugin_config.get(session_id, {})
19+
return session_plugin_config.get(session_id, {})
20+
21+
@staticmethod
22+
def is_plugin_enabled_for_session_config(
23+
plugin_name: str | None,
24+
session_config: dict | None,
25+
*,
26+
reserved: bool = False,
27+
) -> bool:
28+
"""检查插件是否在指定会话配置中启用。"""
29+
if reserved or not plugin_name:
30+
return True
31+
32+
if not session_config:
33+
return True
3334

3435
enabled_plugins = session_config.get("enabled_plugins", [])
3536
disabled_plugins = session_config.get("disabled_plugins", [])
3637

37-
# 如果插件在禁用列表中,返回False
3838
if plugin_name in disabled_plugins:
3939
return False
4040

41-
# 如果插件在启用列表中,返回True
4241
if plugin_name in enabled_plugins:
4342
return True
4443

45-
# 如果都没有配置,默认为启用(兼容性考虑)
4644
return True
4745

46+
@staticmethod
47+
async def is_plugin_enabled_for_session(
48+
session_id: str,
49+
plugin_name: str,
50+
*,
51+
reserved: bool = False,
52+
) -> bool:
53+
"""检查插件是否在指定会话中启用
54+
55+
Args:
56+
session_id: 会话ID (unified_msg_origin)
57+
plugin_name: 插件名称
58+
59+
Returns:
60+
bool: True表示启用,False表示禁用
61+
62+
"""
63+
session_config = await SessionPluginManager.get_session_plugin_config(session_id)
64+
return SessionPluginManager.is_plugin_enabled_for_session_config(
65+
plugin_name,
66+
session_config,
67+
reserved=reserved,
68+
)
69+
4870
@staticmethod
4971
async def filter_handlers_by_session(
5072
event: AstrMessageEvent,
@@ -65,14 +87,7 @@ async def filter_handlers_by_session(
6587
session_id = event.unified_msg_origin
6688
filtered_handlers = []
6789

68-
session_plugin_config = await sp.get_async(
69-
scope="umo",
70-
scope_id=session_id,
71-
key="session_plugin_config",
72-
default={},
73-
)
74-
session_config = session_plugin_config.get(session_id, {})
75-
disabled_plugins = session_config.get("disabled_plugins", [])
90+
session_config = await SessionPluginManager.get_session_plugin_config(session_id)
7691

7792
for handler in handlers:
7893
# 获取处理器对应的插件
@@ -91,7 +106,11 @@ async def filter_handlers_by_session(
91106
continue
92107

93108
# 检查插件是否在当前会话中启用
94-
if plugin.name in disabled_plugins:
109+
if not SessionPluginManager.is_plugin_enabled_for_session_config(
110+
plugin.name,
111+
session_config,
112+
reserved=plugin.reserved,
113+
):
95114
logger.debug(
96115
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
97116
)

0 commit comments

Comments
 (0)