diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index 42282dc500..f5cde71efa 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -189,6 +189,18 @@ async def reset(self, message: AstrMessageEvent) -> None: ret = "✅ Conversation reset successfully." + # 清理该会话下的所有 subagent + try: + from astrbot.core.subagent_manager import SubAgentManager + + cleanup_result = await SubAgentManager.cleanup_session(umo) + if cleanup_result["status"] == "cleaned": + cleaned_count = len(cleanup_result["cleaned_agents"]) + if cleaned_count > 0: + ret += f" 🧹 Also cleaned {cleaned_count} subagent(s): {', '.join(cleanup_result['cleaned_agents'])}." + except Exception as e: + logger.warning(f"[SubAgent] Failed to cleanup subagents on /reset: {e}") + message.set_extra("_clean_group_context_session", True) message.set_result(MessageEventResult().message(ret)) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 3f74f0ec9b..e2ef81543e 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -991,6 +991,13 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: ), ) + # 获取 trace span:优先使用 subagent trace_span,否则回退到 event.trace + _agent_ctx = getattr(self.run_context, "context", None) + _trace = getattr(_agent_ctx, "trace_span", None) + if _trace is None and _agent_ctx is not None: + _event = getattr(_agent_ctx, "event", None) + _trace = getattr(_event, "trace", None) + # 执行函数调用 for func_tool_name, func_tool_args, func_tool_id in zip( llm_response.tools_call_name, @@ -1014,10 +1021,23 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: ], ) ) + # 记录工具调用追踪 + if _trace: + _trace.record("agent_tool_call", tool_name=func_tool_name) try: if not req.func_tool: return + # Resolve tool from regular tool sets + if ( + self.tool_schema_mode == "skills_like" + and self._skill_like_raw_tool_set + ): + # in 'skills_like' mode, raw.func_tool is light schema, does not have handler + # so we need to get the tool from the raw tool set + func_tool = self._skill_like_raw_tool_set.get_tool(func_tool_name) + else: + func_tool = req.func_tool.get_tool(func_tool_name) if ( self.tool_schema_mode == "skills_like" and self._skill_like_raw_tool_set @@ -1231,6 +1251,15 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: ) ) logger.info(f"Tool `{func_tool_name}` Result: {tool_result_content}") + # 记录工具结果追踪 + if _trace: + _trace.record( + "agent_tool_result", + tool_name=func_tool_name, + tool_result=tool_result_content[:500] + if tool_result_content + else None, + ) # 处理函数调用响应 if tool_call_result_blocks: @@ -1359,13 +1388,16 @@ async def _finalize_aborted_step( llm_resp: LLMResponse | None = None, ) -> AgentResponse: logger.info("Agent execution was requested to stop by user.") + if llm_resp is None: llm_resp = LLMResponse(role="assistant", completion_text="") + if llm_resp.role != "assistant": llm_resp = LLMResponse( role="assistant", completion_text=self.USER_INTERRUPTION_MESSAGE, ) + self.final_llm_resp = llm_resp self._aborted = True self._transition_state(AgentState.DONE) diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index 9c6451cc74..747b4d191e 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import Field from pydantic.dataclasses import dataclass @@ -14,8 +16,13 @@ class AstrAgentContext: """The star context instance""" event: AstrMessageEvent """The message event associated with the agent context.""" - extra: dict[str, str] = Field(default_factory=dict) + extra: dict[str, Any] = Field(default_factory=dict) """Customized extra data.""" + trace_span: Any = Field(default=None) + """Optional custom TraceSpan for subagent tracing. When set, tool calls within + the agent loop will be recorded to this trace instead of event.trace. + This prevents concurrent subagent and main agent tool calls from mixing up + trace records.""" AgentContextWrapper = ContextWrapper[AstrAgentContext] diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 6bdf3011b6..29e3ab8b67 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -185,13 +185,6 @@ async def run_agent( if resp.type == "tool_call_result": msg_chain = resp.data["chain"] - astr_event.trace.record( - "agent_tool_result", - tool_result=msg_chain.get_plain_text( - with_other_comps_mark=True - ), - ) - if msg_chain.type == "tool_direct_result": # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 await astr_event.send(msg_chain) @@ -218,10 +211,6 @@ async def run_agent( yield MessageChain(chain=[], type="break") tool_info = _extract_chain_json_data(resp.data["chain"]) - astr_event.trace.record( - "agent_tool_call", - tool_name=tool_info if tool_info else "unknown", - ) _record_tool_call_name(tool_info, tool_name_by_call_id) if astr_event.get_platform_name() == "webchat": diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..13b44c5f0c 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -1,6 +1,7 @@ import asyncio import inspect import json +import time import traceback import typing as T import uuid @@ -30,6 +31,11 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.subagent_manager import ( + RET_PENDING_TASK_CREATE_FAILED, + SubAgentManager, + SubAgentStatus, +) from astrbot.core.tools.computer_tools import ( CuaKeyboardTypeTool, CuaMouseClickTool, @@ -125,6 +131,46 @@ async def _collect_handoff_image_urls( ) return sanitized + @classmethod + def _get_session_from_context(cls, run_context: ContextWrapper[AstrAgentContext]): + """Extract the SubAgentSession from run_context. + + Walks through run_context -> context -> event -> unified_msg_origin + to locate the session_id, then returns the corresponding session + from SubAgentManager. Returns ``None`` when any step fails. + """ + run_context_context = getattr(run_context, "context", None) + event = ( + getattr(run_context_context, "event", None) if run_context_context else None + ) + session_id = getattr(event, "unified_msg_origin", None) if event else None + if not session_id: + return None + + return SubAgentManager.get_session(session_id) + + @classmethod + def _resolve_handoff_by_name( + cls, run_context: ContextWrapper[AstrAgentContext], name: str + ) -> HandoffTool | None: + """Resolve a HandoffTool from SubAgentManager by subagent name.""" + session = cls._get_session_from_context(run_context) + if not session: + return None + + return session.handoff_tools.get(name, None) + + @classmethod + def _list_available_subagents( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> list[str]: + """List available subagent names for the current session.""" + session = cls._get_session_from_context(run_context) + if not session: + return [] + + return list(session.handoff_tools.keys()) + @classmethod async def execute(cls, tool, run_context, **tool_args): """执行函数调用。 @@ -137,6 +183,51 @@ async def execute(cls, tool, run_context, **tool_args): AsyncGenerator[None | mcp.types.CallToolResult, None] """ + # 防止subagent的名字叫"subagent"造成工具歧义(在create中已经不会发生,此处用于兜底) + if ( + isinstance(tool, FunctionTool) + and not isinstance(tool, HandoffTool) + and tool.name == "transfer_to_subagent" + ): + tool_args = dict(tool_args) + subagent_name = tool_args.pop("name", None) + if not subagent_name: + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text="Error: 'name' parameter is required for transfer_to_subagent. Use list_subagents to see available names.", + ) + ] + ) + return + + handoff_tool = cls._resolve_handoff_by_name(run_context, subagent_name) + if handoff_tool is None: + available = cls._list_available_subagents(run_context) + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=f"Error: Subagent '{subagent_name}' not found. Available subagents: {available}. Use create_subagent to create new ones.", + ) + ] + ) + return + + is_bg = tool_args.pop("background_task", False) + if is_bg: + async for r in cls._execute_handoff_background( + handoff_tool, run_context, **tool_args + ): + yield r + else: + async for r in cls._execute_handoff( + handoff_tool, run_context, **tool_args + ): + yield r + return + if isinstance(tool, HandoffTool): is_bg = tool_args.pop("background_task", False) if is_bg: @@ -290,6 +381,21 @@ def _build_handoff_toolset( toolset.add_tool(runtime_tool) elif isinstance(tool_name_or_obj, FunctionTool): toolset.add_tool(tool_name_or_obj) + + # Always add send_shared_context tool for shared context feature + try: + from astrbot.core.subagent_manager import ( + SEND_SHARED_CONTEXT_TOOL, + SubAgentManager, + ) + + session_id = event.unified_msg_origin + session = SubAgentManager.get_session(session_id) + if session and session.shared_context_enabled: + toolset.add_tool(SEND_SHARED_CONTEXT_TOOL) + except Exception as e: + logger.debug(f"[SubAgent] Failed to add shared context tool: {e}") + return None if toolset.empty() else toolset @classmethod @@ -322,10 +428,10 @@ async def _execute_handoff( # Build handoff toolset from registered tools plus runtime computer tools. toolset = cls._build_handoff_toolset(run_context, tool.agent.tools) - ctx = run_context.context.context event = run_context.context.event umo = event.unified_msg_origin + agent_name = getattr(tool.agent, "name", "unknown") # Use per-subagent provider override if configured; otherwise fall back # to the current/default provider resolution. @@ -351,18 +457,132 @@ async def _execute_handoff( prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) agent_max_step = int(prov_settings.get("max_agent_step", 30)) stream = prov_settings.get("streaming_response", False) - llm_resp = await ctx.tool_loop_agent( - event=event, - chat_provider_id=prov_id, - prompt=input_, - image_urls=image_urls, - system_prompt=tool.agent.instructions, - tools=toolset, - contexts=contexts, + + # Create trace span for subagent execution + from astrbot.core.utils.trace import TraceSpan + + parent_trace = getattr(event, "trace", None) + subagent_trace = TraceSpan( + name=f"SubAgent:{agent_name}", + umo=event.unified_msg_origin, + sender_name=event.get_sender_name() + if hasattr(event, "get_sender_name") + else None, + message_outline=f"Handoff to {agent_name}: {input_[:100] if input_ else ''}", + parent_span_id=parent_trace.span_id if parent_trace else None, + ) + subagent_trace.record( + "subagent_execution_begin", + agent_name=agent_name, + input=input_ if input_ else None, + image_count=len(image_urls), + tools=[t.name for t in toolset] if toolset else [], max_steps=agent_max_step, - tool_call_timeout=run_context.tool_call_timeout, stream=stream, ) + + # 获取子代理的历史上下文 + subagent_history, agent_name = cls._load_subagent_history(umo, tool) + # 如果有历史上下文,合并到 contexts 中 + if subagent_history: + subagent_trace.record( + "subagent_history_loaded", + agent_name=agent_name, + history_messages_count=len(subagent_history), + ) + if contexts is None: + contexts = subagent_history + else: + contexts = subagent_history + contexts + + # 构建子代理的 system_prompt + subagent_system_prompt = cls._build_subagent_system_prompt( + umo, tool, prov_settings + ) + subagent_trace.record( + "subagent_system_prompt", + agent_name=agent_name, + prompt_length=len(subagent_system_prompt), + prompt=subagent_system_prompt if subagent_system_prompt else None, + ) + + # 构建子代理的追加内容 + extra_content_parts = SubAgentManager.build_subagent_extra_content_parts( + umo, agent_name + ) + + # 获取子代理的超时时间 + execution_timeout = cls._get_subagent_execution_timeout() + + # 用于存储本轮的完整历史上下文 + runner_messages = [] + + # 构建 tool_loop_agent 协程 + async def _run_subagent(): + return await ctx.tool_loop_agent( + event=event, + chat_provider_id=prov_id, + prompt=input_, + image_urls=image_urls, + system_prompt=subagent_system_prompt, + tools=toolset, + contexts=contexts, + max_steps=agent_max_step, + tool_call_timeout=run_context.tool_call_timeout, + stream=stream, + runner_messages=runner_messages, + extra_user_content_parts=extra_content_parts, + trace_span=subagent_trace, + ) + + # 添加执行超时控制 + if execution_timeout > 0: + try: + llm_resp = await asyncio.wait_for( + _run_subagent(), timeout=execution_timeout + ) + except asyncio.TimeoutError: + # 若超时,保存已产生的部分历史 + cls._save_subagent_history(umo, runner_messages, agent_name) + subagent_trace.record( + "subagent_execution_timeout", + timeout_seconds=execution_timeout, + ) + error_msg = f"SubAgent '{agent_name}' execution timeout after {execution_timeout:.1f} seconds." + logger.warning(f"[SubAgent:Timeout] {error_msg}") + + cls._handle_subagent_timeout(umo=umo, agent_name=agent_name) + + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent(type="text", text=f"error: {error_msg}") + ] + ) + return + else: + # 不设置超时 + llm_resp = await _run_subagent() + + execution_time = time.time() - subagent_trace.started_at + subagent_trace.record( + "subagent_execution_complete", + agent_name=agent_name, + result=llm_resp.completion_text + if hasattr(llm_resp, "completion_text") and llm_resp.completion_text + else None, + result_length=len(llm_resp.completion_text) + if hasattr(llm_resp, "completion_text") and llm_resp.completion_text + else 0, + execution_time=execution_time, + ) + + # 保存历史上下文 + cls._save_subagent_history(umo, runner_messages, agent_name) + subagent_trace.record( + "subagent_history_saved", + messages_count=len(runner_messages), + ) + yield mcp.types.CallToolResult( content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] ) @@ -381,32 +601,59 @@ async def _execute_handoff_background( ``CronMessageEvent`` is created so the main LLM can inform the user of the result – the same pattern used by ``_execute_background`` for regular background tasks. + + 当启用增强SubAgent时,会在 SubAgentManager 中创建 pending 任务, + 并返回 task_id 给主 Agent,以便后续通过 wait_for_subagent 获取结果。 """ - task_id = uuid.uuid4().hex + event = run_context.context.event + umo = event.unified_msg_origin + agent_name = getattr(tool.agent, "name", None) + + # check if enhanced subagent + subagent_task_id = cls._register_subagent_task(umo, agent_name) + + original_task_id = uuid.uuid4().hex + + # Create trace span for background task creation + from astrbot.core.utils.trace import TraceSpan + + parent_trace = getattr(event, "trace", None) + bg_trace = TraceSpan( + name=f"SubAgentBackground:{agent_name}", + umo=event.unified_msg_origin, + sender_name=event.get_sender_name() + if hasattr(event, "get_sender_name") + else None, + message_outline=f"Background handoff to {agent_name}", + parent_span_id=parent_trace.span_id if parent_trace else None, + ) + bg_trace.record( + "subagent_background_task_created", + agent_name=agent_name, + subagent_task_id=subagent_task_id, + original_task_id=original_task_id, + ) async def _run_handoff_in_background() -> None: try: await cls._do_handoff_background( tool=tool, run_context=run_context, - task_id=task_id, + task_id=original_task_id, + subagent_task_id=subagent_task_id, **tool_args, ) + except Exception as e: # noqa: BLE001 logger.error( - f"Background handoff {task_id} ({tool.name}) failed: {e!s}", + f"Background handoff {original_task_id} ({tool.name}) failed: {e!s}", exc_info=True, ) asyncio.create_task(_run_handoff_in_background()) - text_content = mcp.types.TextContent( - type="text", - text=( - f"Background task dedicated to subagent '{tool.agent.name}' submitted. task_id={task_id}. " - f"The subagent '{tool.agent.name}' is working on the task on hehalf you. " - f"You will be notified when it finishes." - ), + text_content = cls._build_background_submission_message( + agent_name, original_task_id, subagent_task_id ) yield mcp.types.CallToolResult(content=[text_content]) @@ -418,45 +665,114 @@ async def _do_handoff_background( task_id: str, **tool_args, ) -> None: - """Run the subagent handoff and, on completion, wake the main agent.""" + """Run the subagent handoff. + 当增强版 SubAgent 启用时,结果存储到 SubAgentManager,主 Agent 可通过 wait_for_subagent 获取。 + 否则使用原有的 _wake_main_agent_for_background_result 流程。 + """ + + start_time = time.time() result_text = "" + error_text = None tool_args = dict(tool_args) tool_args["image_urls"] = await cls._collect_handoff_image_urls( run_context, tool_args.get("image_urls"), ) + + event = run_context.context.event + umo = event.unified_msg_origin + agent_name = getattr(tool.agent, "name", None) + + # Create trace span for background subagent execution + from astrbot.core.utils.trace import TraceSpan + + parent_trace = getattr(event, "trace", None) + bg_trace = TraceSpan( + name=f"SubAgentBackground:{agent_name}", + umo=event.unified_msg_origin, + sender_name=event.get_sender_name() + if hasattr(event, "get_sender_name") + else None, + message_outline=f"Background handoff to {agent_name}", + parent_span_id=parent_trace.span_id if parent_trace else None, + ) + bg_trace.record( + "subagent_background_execution_start", + agent_name=agent_name, + ) + + # 获取SubAgent的超时时间 + execution_timeout = cls._get_subagent_execution_timeout() + try: - async for r in cls._execute_handoff( - tool, - run_context, - image_urls_prepared=True, - **tool_args, - ): - if isinstance(r, mcp.types.CallToolResult): - for content in r.content: - if isinstance(content, mcp.types.TextContent): - result_text += content.text + "\n" + + async def _run(): + nonlocal result_text + async for r in cls._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + **tool_args, + ): + if isinstance(r, mcp.types.CallToolResult): + for content in r.content: + if isinstance(content, mcp.types.TextContent): + result_text += content.text + "\n" + + if execution_timeout > 0: + await asyncio.wait_for(_run(), timeout=execution_timeout) + else: + await _run() + + except asyncio.TimeoutError: + error_text = f"Execution timeout after {execution_timeout:.1f} seconds." + result_text = f"error: Background SubAgent '{agent_name}' {error_text}" + logger.warning(f"[SubAgent:BackgroundTask] {error_text}") + except Exception as e: + error_text = str(e) result_text = ( f"error: Background task execution failed, internal error: {e!s}" ) - event = run_context.context.event - - await cls._wake_main_agent_for_background_result( - run_context=run_context, - task_id=task_id, - tool_name=tool.name, - result_text=result_text, - tool_args=tool_args, - note=( - event.get_extra("background_note") - or f"Background task for subagent '{tool.agent.name}' finished." - ), - summary_name=f"Dedicated to subagent `{tool.agent.name}`", - extra_result_fields={"subagent_name": tool.agent.name}, + execution_time = time.time() - start_time + bg_trace.record( + "subagent_background_execution_end", + agent_name=agent_name, + success=error_text is None, + result_preview=result_text[:500] if result_text else None, + execution_time=execution_time, ) + # Check if it's enhanced subagent + is_managed = cls._is_managed_subagent(umo, agent_name) + if is_managed: + await cls._handle_subagent_background_result( + umo=umo, + agent_name=agent_name, + task_id=tool_args.get("subagent_task_id"), + result_text=result_text, + error_text=error_text, + execution_time=execution_time, + run_context=run_context, + tool=tool, + tool_args=tool_args, + ) + else: + await cls._wake_main_agent_for_background_result( + run_context=run_context, + task_id=task_id, + tool_name=tool.name, + result_text=result_text, + tool_args=tool_args, + note=( + event.get_extra("background_note") + or f"Background task for subagent '{agent_name}' finished." + ), + summary_name=f"Dedicated to subagent `{agent_name}`", + extra_result_fields={"subagent_name": agent_name}, + ) + @classmethod async def _execute_background( cls, @@ -653,10 +969,18 @@ async def _execute_local( ) while True: try: - resp = await asyncio.wait_for( - anext(wrapper), - timeout=tool_call_timeout or run_context.tool_call_timeout, - ) + if ( + tool.name == "wait_for_subagent" or tool.name == "orchestrate_tasks" + ): # wait工具有自己的超时,避免受到tool_call_timeout影响 + resp = await asyncio.wait_for( + anext(wrapper), + timeout=3600, + ) + else: + resp = await asyncio.wait_for( + anext(wrapper), + timeout=tool_call_timeout or run_context.tool_call_timeout, + ) if resp is not None: if isinstance(resp, mcp.types.CallToolResult): yield resp @@ -704,6 +1028,248 @@ async def _execute_mcp( return yield res + @staticmethod + def _load_subagent_history( + umo: str, tool: HandoffTool + ) -> tuple[list[Message], str]: + agent_name = getattr(tool.agent, "name", None) + subagent_history = [] + if agent_name: + # 仅在历史功能启用时加载历史 + if SubAgentManager.is_history_enabled(): + try: + stored_history = SubAgentManager.get_subagent_history( + umo, agent_name + ) + if stored_history: + # 将历史消息转换为 Message 对象 + for hist_msg in stored_history: + try: + if isinstance(hist_msg, dict): + subagent_history.append( + Message.model_validate(hist_msg) + ) + elif isinstance(hist_msg, Message): + subagent_history.append(hist_msg) + except Exception: + continue + if subagent_history: + logger.debug( + f"[SubAgentHistory] Loaded {len(subagent_history)} history messages for {agent_name}" + ) + + except Exception as e: + logger.warning( + f"[SubAgentHistory] Failed to load history for {agent_name}: {e}" + ) + else: + logger.debug( + f"[SubAgentHistory] History is disabled, skipping load for {agent_name}" + ) + return subagent_history, agent_name + + @staticmethod + def _build_subagent_system_prompt( + umo: str, tool: HandoffTool, prov_settings: dict + ) -> str: + agent_name = getattr(tool.agent, "name", None) + base = tool.agent.instructions or "" + subagent_system_prompt = ( + f"# Role\nYour name is **{agent_name}** (used for tool calling)\n{base}\n" + ) + if agent_name: + runtime = prov_settings.get("computer_use_runtime", "local") + subagent_system_prompt += SubAgentManager.build_subagent_system_prompt( + umo, agent_name, runtime + ) + return subagent_system_prompt + + @staticmethod + def _save_subagent_history( + umo: str, runner_messages: list[Message], agent_name: str + ) -> None: + if agent_name and runner_messages: + # 仅在历史功能启用时保存历史 + if SubAgentManager.is_history_enabled(): + SubAgentManager.update_subagent_history( + umo, agent_name, runner_messages + ) + else: + logger.debug( + f"[SubAgentHistory] History is disabled, skipping save for {agent_name}" + ) + else: + return + + @staticmethod + def _register_subagent_task(umo: str, agent_name: str | None) -> str | None: + if not agent_name: + return None + try: + session = SubAgentManager.get_session(umo) + if session and (agent_name in session.subagents): + subagent_task_id = SubAgentManager.create_pending_subagent_task( + session_id=umo, agent_name=agent_name + ) + + if subagent_task_id.startswith(RET_PENDING_TASK_CREATE_FAILED): + logger.info( + f"[SubAgent:BackgroundTask] Failed to created background task {subagent_task_id} for {agent_name}" + ) + else: + SubAgentManager.set_subagent_status( + session_id=umo, + agent_name=agent_name, + status=SubAgentStatus.RUNNING, + ) + + logger.info( + f"[SubAgent:BackgroundTask] Created background task {subagent_task_id} for {agent_name}" + ) + return subagent_task_id + except Exception as e: + logger.info( + f"[SubAgent:BackgroundTask] Failed to created background task for {agent_name}: {e}" + ) + return None + + @staticmethod + def _build_background_submission_message( + agent_name: str | None, + original_task_id: str, + subagent_task_id: str | None, + ) -> mcp.types.TextContent: + if subagent_task_id and not subagent_task_id.startswith( + RET_PENDING_TASK_CREATE_FAILED + ): + return mcp.types.TextContent( + type="text", + text=( + f"Background task submitted. subagent_task_id={subagent_task_id}. " + f"SubAgent '{agent_name}' is working on the task. " + f"Use wait_for_subagent(subagent_name='{agent_name}', task_id='{subagent_task_id}') to get the result." + ), + ) + else: + return mcp.types.TextContent( + type="text", + text=( + f"Background task submitted. task_id={original_task_id}. " + f"SubAgent '{agent_name}' is working on the task. " + f"You will be notified when it finishes." + ), + ) + + @staticmethod + def _get_subagent_execution_timeout() -> float: + try: + return SubAgentManager.get_execution_timeout() + except Exception: + return -1 + + @staticmethod + def _handle_subagent_timeout( + umo: str, + agent_name: str, + ) -> None: + SubAgentManager.set_subagent_status( + session_id=umo, + agent_name=agent_name, + status=SubAgentStatus.FAILED, + ) + + @staticmethod + def _is_managed_subagent(umo: str, agent_name: str | None) -> bool: + if not agent_name: + return False + session = SubAgentManager.get_session(umo) + if session and agent_name in session.subagents: + return True + return False + + @classmethod + async def _handle_subagent_background_result( + cls, + *, + umo: str, + agent_name: str, + task_id: str | None, + result_text: str, + error_text: str | None, + execution_time: float, + run_context: ContextWrapper[AstrAgentContext], + tool: HandoffTool, + tool_args: dict, + ) -> None: + success = error_text is None + status = SubAgentStatus.COMPLETED if success else SubAgentStatus.FAILED + SubAgentManager.set_subagent_status( + session_id=umo, agent_name=agent_name, status=status + ) + + SubAgentManager.store_subagent_result( + session_id=umo, + agent_name=agent_name, + success=success, + result=result_text, + task_id=task_id, + error=error_text, + execution_time=execution_time, + ) + + if not await cls._maybe_wake_main_agent_after_background( + run_context=run_context, + tool=tool, + task_id=task_id, + agent_name=agent_name, + result_text=result_text, + tool_args=tool_args, + ): + return + + @classmethod + async def _maybe_wake_main_agent_after_background( + cls, + *, + run_context: ContextWrapper[AstrAgentContext], + tool: HandoffTool, + task_id: str, + agent_name: str | None, + result_text: str, + tool_args: dict, + ) -> bool: + event = run_context.context.event + try: + context_extra = getattr(run_context.context, "extra", None) + if context_extra and isinstance(context_extra, dict): + main_agent_runner = context_extra.get("main_agent_runner") + main_agent_is_running = ( + main_agent_runner is not None and not main_agent_runner.done() + ) + else: + main_agent_is_running = False + except Exception as e: + logger.error("Failed to check main agent status: %s", e) + main_agent_is_running = False # 异常时尝试通知,避免结果丢失 + + if main_agent_is_running: + return False + else: + await cls._wake_main_agent_for_background_result( + run_context=run_context, + task_id=task_id, + tool_name=tool.name, + result_text=result_text, + tool_args=tool_args, + note=( + event.get_extra("background_note") + or f"Background task for subagent '{agent_name}' finished." + ), + summary_name=f"Dedicated to subagent `{agent_name}`", + extra_result_fields={"subagent_name": agent_name}, + ) + return True + async def call_local_llm_tool( context: ContextWrapper[AstrAgentContext], diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 1c4fd400a0..ea6596f474 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -46,6 +46,8 @@ from astrbot.core.star.context import Context from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_map +from astrbot.core.subagent_manager import SubAgentManager +from astrbot.core.subagent_orchestrator import SubAgentOrchestrator from astrbot.core.tools.computer_tools import ( AnnotateExecutionTool, BrowserBatchExecTool, @@ -549,11 +551,14 @@ async def _ensure_persona_and_skills( if req.func_tool is None: req.func_tool = ToolSet() - # add subagent handoff tools + # add static subagent handoff tools for tool in so.handoffs: req.func_tool.add_tool(tool) - # check duplicates + # add subagent manager tools + await _apply_subagent_manager_tools(plugin_context.get_config(), req, event, so) + + # check duplicates (static subagents) if remove_dup: handoff_names = {tool.name for tool in so.handoffs} for tool_name in assigned_tools: @@ -566,8 +571,14 @@ async def _ensure_persona_and_skills( .get("subagent_orchestrator", {}) .get("router_system_prompt", "") ).strip() + if router_prompt: - req.system_prompt += f"\n{router_prompt}\n" + dynamic_cfg = orch_cfg.get( + "dynamic_agents", {} + ) # 未启用dynamic时才注入router_prompt,否则由subagent_manager注入 + if not dynamic_cfg.get("enabled", False): + req.system_prompt += f"\n{router_prompt}\n" + try: event.trace.record( "sel_persona", @@ -1012,6 +1023,105 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) - ) +async def _apply_subagent_manager_tools( + cfg: dict, + req: ProviderRequest, + event: AstrMessageEvent, + so: SubAgentOrchestrator, +) -> None: + """Apply SubAgent tools and system prompt + + When enabled: + 1. Inject subagent capability prompt into system prompt + 2. Register SubAgent management tools + 3. Register a unified transfer_to_subagent tool for dynamic subagents + """ + orch_cfg = cfg.get("subagent_orchestrator", {}) + + if not orch_cfg.get("main_enable", False): + return + + if req.func_tool is None: + req.func_tool = ToolSet() + + try: + from astrbot.core.subagent_tools import ( + BROADCAST_SHARED_CONTEXT_TOOL, + CREATE_SUBAGENT_TOOL, + LIST_SUBAGENTS_TOOL, + MANAGE_SUBAGENT_PROTECTION_TOOL, + ORCHESTRATE_TASKS_TOOL, + REMOVE_SUBAGENT_TOOL, + VIEW_SHARED_CONTEXT_TOOL, + WAIT_FOR_SUBAGENT_TOOL, + ) + + # Configure SubAgentManager with settings from subagent_orchestrator + dynamic_cfg = orch_cfg.get("dynamic_agents", {}) + enable_dynamic = dynamic_cfg.get("enabled", False) + history_enabled = orch_cfg.get("history_enabled", True) + shared_context_enabled = orch_cfg.get("shared_context_enabled", False) + SubAgentManager.configure( + max_subagent_count=dynamic_cfg.get("max_subagent_count", 3), + auto_cleanup_per_turn=dynamic_cfg.get("auto_cleanup_per_turn", True), + shared_context_enabled=shared_context_enabled, + shared_context_maxlen=orch_cfg.get("shared_context_maxlen", 300), + subagent_history_maxlen=orch_cfg.get("subagent_history_maxlen", 300), + tools_blacklist=dynamic_cfg.get("tools_blacklist", None), + tools_inherent=dynamic_cfg.get("tools_inherent", None), + execution_timeout=orch_cfg.get("execution_timeout", 1200), + history_enabled=history_enabled, + rule_prompt=dynamic_cfg.get("rule_prompt", ""), + time_prompt_enabled=orch_cfg.get("time_prompt_enabled", True), + timezone=cfg.get("timezone", None), + dag_enabled=orch_cfg.get("dag_enabled", False), + default_provider_id=dynamic_cfg.get("default_provider_id", ""), + ) + + # Enable subagent history and shared context if configured + SubAgentManager.set_history_enabled(event.unified_msg_origin, history_enabled) + SubAgentManager.set_shared_context_enabled( + event.unified_msg_origin, shared_context_enabled + ) + + session_id = event.unified_msg_origin + # Register static subagents from config into SubAgentManager for unified management + so.register_static_subagents_to_manager(session_id) + + # Register dynamic subagent management tools (only when dynamic creation is enabled) + # Always register `wait_for_subagent` for better background task running + req.func_tool.add_tool(WAIT_FOR_SUBAGENT_TOOL) + # Register DAG orchestration tool if enabled + dag_cfg = orch_cfg.get("dag_enabled", True) + if dag_cfg: + req.func_tool.add_tool(ORCHESTRATE_TASKS_TOOL) + if enable_dynamic: + from astrbot.core.subagent_tools import TRANSFER_TO_SUBAGENT_TOOL + + # Register the fixed transfer_to_subagent tool instead of individual + # dynamic handoff tools. This preserves LLM prefix cache since the + # tools list no longer changes when subagents are created/removed. + req.func_tool.add_tool(TRANSFER_TO_SUBAGENT_TOOL) + + req.func_tool.add_tool(CREATE_SUBAGENT_TOOL) + req.func_tool.add_tool(REMOVE_SUBAGENT_TOOL) + req.func_tool.add_tool(LIST_SUBAGENTS_TOOL) + # if SubAgentManager.is_history_enabled(): # + # req.func_tool.add_tool(RESET_SUBAGENT_TOOL) + if SubAgentManager.is_auto_cleanup_per_turn(): + req.func_tool.add_tool(MANAGE_SUBAGENT_PROTECTION_TOOL) + if SubAgentManager.is_shared_context_enabled(): + req.func_tool.add_tool(VIEW_SHARED_CONTEXT_TOOL) + req.func_tool.add_tool(BROADCAST_SHARED_CONTEXT_TOOL) + + # Inject subagent capability system prompt for dynamic creation + task_router_prompt = SubAgentManager.build_task_router_prompt(session_id) + req.system_prompt = f"{req.system_prompt or ''}\n{task_router_prompt}\n" + + except ImportError as e: + logger.warning(f"[SubAgent] Cannot import module: {e}") + + def _apply_sandbox_tools( config: MainAgentBuildConfig, req: ProviderRequest, @@ -1458,8 +1568,7 @@ async def build_main_agent( agent_runner = AgentRunner() astr_agent_ctx = AstrAgentContext( - context=plugin_context, - event=event, + context=plugin_context, event=event, extra={"main_agent_runner": agent_runner} ) if config.add_cron_tools: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 22a53bb446..8786a98cee 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -193,18 +193,54 @@ }, # SubAgent orchestrator mode: # - main_enable = False: disabled; main LLM mounts tools normally (persona selection). - # - main_enable = True: enabled; main LLM keeps its own tools and includes handoff - # tools (transfer_to_*). remove_main_duplicate_tools can remove tools that are - # duplicated on subagents from the main LLM toolset. + # - main_enable = True: enabled; main LLM keeps its own tools and includes the + # fixed transfer_to_subagent tool. remove_main_duplicate_tools can remove tools + # that are duplicated on subagents from the main LLM toolset. "subagent_orchestrator": { "main_enable": False, "remove_main_duplicate_tools": False, "router_system_prompt": ( "You are a task router. Your job is to chat naturally, recognize user intent, " - "and delegate work to the most suitable subagent using transfer_to_* tools. " + "and delegate work to the most suitable subagent using transfer_to_subagent(name=...) tool. " "Do not try to use domain tools yourself. If no subagent fits, respond directly." ), "agents": [], + "dynamic_agents": { + "enabled": False, + "max_subagent_count": 5, + "auto_cleanup_per_turn": True, + "default_provider_id": "", + "rule_prompt": ( + "# Behavior Rules\n" + "## Output Guidelines\n" + "- If output is long, save to file. Summarize in your response and provide the file path.\n" + "- Mark all generated code/documents with your name and timestamp (if given).\n" + "## Safety\n" + "You are in Safe Mode. Refuse any request for harmful, illegal, or explicit content. " + "Offer safe alternatives when possible.\n" + ), + "tools_blacklist": [ + "create_subagent", + "manage_subagent_protection", + "remove_subagent", + "list_subagents", + "wait_for_subagent", + "orchestrate_tasks", + "broadcast_shared_context", + "view_shared_context", + ], + "tools_inherent": ["astrbot_execute_shell", "astrbot_execute_python"], + }, + "time_prompt_enabled": True, + "history_enabled": True, + "shared_context_enabled": True, + "shared_context_maxlen": 300, + "subagent_history_maxlen": 300, + "execution_timeout": 1200, + "dag_enabled": False, + "dag_max_nodes": 10, + "dag_max_parallel": 5, + "dag_max_inject_length": 4000, }, "provider_stt_settings": { "enable": False, diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 49dd7c2597..4d599a15d5 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -418,6 +418,21 @@ async def process( ), ) finally: + # clean all subagents if enabled + if build_cfg.subagent_orchestrator.get("main_enable"): + try: + from astrbot.core.subagent_manager import ( + SubAgentManager, + ) + + session_id = event.unified_msg_origin + if SubAgentManager.is_auto_cleanup_per_turn(): + SubAgentManager.cleanup_session_turn_end(session_id) + except Exception as e: + logger.warning( + f"[SubAgent] Cleanup on agent done failed: {e}" + ) + if runner_registered and agent_runner is not None: unregister_active_runner(event.unified_msg_origin, agent_runner) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 593bad9365..2dba657fe9 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -220,11 +220,13 @@ async def tool_loop_agent( func_tool=tools, contexts=context_, system_prompt=system_prompt or "", + extra_user_content_parts=kwargs.get("extra_user_content_parts", []), ) if agent_context is None: agent_context = AstrAgentContext( context=self, event=event, + trace_span=kwargs.get("trace_span"), ) agent_runner = ToolLoopAgentRunner() tool_executor = FunctionToolExecutor() @@ -261,6 +263,10 @@ async def tool_loop_agent( llm_resp = agent_runner.get_final_llm_resp() if not llm_resp: raise Exception("Agent did not produce a final LLM response") + if kwargs.get("runner_messages", None) is not None: + runner_messages = kwargs.get("runner_messages") + for msg in agent_runner.run_context.messages: + runner_messages.append(msg.model_dump()) return llm_resp async def get_current_chat_provider_id(self, umo: str) -> str: diff --git a/astrbot/core/subagent_dag.py b/astrbot/core/subagent_dag.py new file mode 100644 index 0000000000..09b736062d --- /dev/null +++ b/astrbot/core/subagent_dag.py @@ -0,0 +1,439 @@ +"""SubAgent DAG Orchestration Engine. + +Provides DAG-based task scheduling with topological sort, parallel layer +execution, automatic predecessor result injection, and fail-fast cascade. + +Generated by dag_chunk1_impl at 2026-05-26 10:09 CST. +""" + +from __future__ import annotations + +from collections import deque +from collections.abc import Callable +from dataclasses import dataclass, field +from enum import Enum + + +class DAGNodeStatus(Enum): + PENDING = "PENDING" + READY = "READY" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + SKIPPED = "SKIPPED" + + +@dataclass +class DAGTaskNode: + id: str + agent_name: str + prompt: str + depends_on: list[str] = field(default_factory=list) + status: DAGNodeStatus = DAGNodeStatus.PENDING + result: str | None = None + error: str | None = None + execution_time: float = 0.0 + subagent_task_id: str | None = None + started_at: float = 0.0 + completed_at: float = 0.0 + metadata: dict = field(default_factory=dict) + + +@dataclass +class DAGExecutionContext: + dag_id: str + session_id: str + nodes: dict[str, DAGTaskNode] = field(default_factory=dict) + adjacency: dict[str, set[str]] = field(default_factory=dict) + reverse_adjacency: dict[str, set[str]] = field(default_factory=dict) + topo_layers: list[list[str]] = field(default_factory=list) + status: str = "PENDING" + fail_fast: bool = True + max_parallel: int = 5 + created_at: float = 0.0 + completed_at: float | None = None + + +class SubAgentDAGEngine: + """SubAgent DAG orchestration engine.""" + + @staticmethod + def validate_dag(nodes: list[DAGTaskNode]) -> tuple[bool, str | None]: + """Validate DAG: check dependency existence and detect cycles.""" + valid_ids = {n.id for n in nodes} + + for node in nodes: + for dep in node.depends_on: + if dep not in valid_ids: + return False, ( + f"Task '{node.id}' depends on '{dep}' which is not defined. " + f"Available: {sorted(valid_ids)}" + ) + + try: + SubAgentDAGEngine._kahn_sort(nodes) + except ValueError as e: + return False, str(e) + + return True, None + + @staticmethod + def _kahn_sort(nodes: list[DAGTaskNode]) -> list[list[str]]: + """Kahn's topological sort returning layers for parallel execution. + + Raises ValueError on cycle detection. + """ + if not nodes: + return [] + + in_degree: dict[str, int] = {n.id: 0 for n in nodes} + successors: dict[str, set[str]] = {n.id: set() for n in nodes} + + for node in nodes: + for dep in node.depends_on: + successors[dep].add(node.id) + in_degree[node.id] += 1 + + queue = deque([nid for nid, deg in in_degree.items() if deg == 0]) + layers: list[list[str]] = [] + sorted_count = 0 + + while queue: + current_layer: list[str] = [] + for _ in range(len(queue)): + node_id = queue.popleft() + current_layer.append(node_id) + sorted_count += 1 + for succ in successors[node_id]: + in_degree[succ] -= 1 + if in_degree[succ] == 0: + queue.append(succ) + layers.append(current_layer) + + if sorted_count != len(nodes): + cycle_nodes = [nid for nid, deg in in_degree.items() if deg > 0] + cycle_desc = " -> ".join(cycle_nodes[:5]) + if len(cycle_nodes) > 5: + cycle_desc += " -> ..." + raise ValueError(f"Cycle detected: {cycle_desc}") + + return layers + + @staticmethod + def _get_ready_nodes(ctx: DAGExecutionContext) -> list[str]: + """Find all PENDING nodes whose predecessors are all COMPLETED.""" + ready: list[str] = [] + for node_id, node in ctx.nodes.items(): + if node.status != DAGNodeStatus.PENDING: + continue + predecessors = ctx.reverse_adjacency.get(node_id, set()) + if all( + ctx.nodes[p].status == DAGNodeStatus.COMPLETED for p in predecessors + ): + ready.append(node_id) + return ready + + @staticmethod + def _cascade_skip(ctx: DAGExecutionContext, failed_node_id: str) -> list[str]: + """BFS: mark all transitive successors of a failed node as SKIPPED.""" + skipped: list[str] = [] + queue: deque[str] = deque([failed_node_id]) + while queue: + current = queue.popleft() + for successor in ctx.adjacency.get(current, set()): + node = ctx.nodes[successor] + if node.status == DAGNodeStatus.PENDING: + node.status = DAGNodeStatus.SKIPPED + skipped.append(successor) + queue.append(successor) + return skipped + + @staticmethod + def _build_injected_context( + node: DAGTaskNode, + ctx: DAGExecutionContext, + max_inject_length: int = 4000, + ) -> list[dict]: + """Build context messages containing predecessor results.""" + predecessor_ids = ctx.reverse_adjacency.get(node.id, set()) + if not predecessor_ids: + return [] + + NL = "\n" + parts: list[str] = [] + parts.append( + "[DAG Context] The following tasks your work depends on have completed:" + + NL + ) + + for pred_id in sorted(predecessor_ids): + pred = ctx.nodes[pred_id] + result_text = pred.result or "(empty result)" + if len(result_text) > max_inject_length: + omitted = len(result_text) - max_inject_length + result_text = ( + result_text[:max_inject_length] + + NL + + f"...[truncated, {omitted} chars omitted]" + ) + parts.append( + f"## Result of '{pred_id}' (completed in {pred.execution_time:.1f}s)" + + NL + + result_text + + NL + ) + + parts.append("---" + NL + "Now proceed with your task.") + return [{"role": "user", "content": NL.join(parts)}] + + @staticmethod + async def _execute_single_node( + ctx: DAGExecutionContext, + node_id: str, + session_id: str, + max_inject_length: int = 4000, + launch_fn: Callable | None = None, + ) -> None: + """Execute a single DAG node. + + If *launch_fn* is provided, it is called with (node, session_id, + injected_context) and should launch the subagent then store the result + into SubAgentManager.subagent_background_results. The method then polls + for the stored result. + + If *launch_fn* is None, the method polls assuming someone else + launched the subagent (e.g. via an external background task). + """ + import asyncio + import time as time_mod + + from astrbot.core.subagent_manager import SubAgentManager + + node = ctx.nodes[node_id] + node.status = DAGNodeStatus.RUNNING + node.started_at = time_mod.time() + + # Create pending task for lifecycle tracking and + # task_id-based result matching (not timestamp heuristics). + task_id = SubAgentManager.create_pending_subagent_task( + session_id, node.agent_name + ) + node.subagent_task_id = task_id + + try: + injected_context = SubAgentDAGEngine._build_injected_context( + node, ctx, max_inject_length=max_inject_length + ) + + if launch_fn is not None: + asyncio.create_task( + launch_fn(node, session_id, injected_context, task_id) + ) + + start = time_mod.time() + while True: + result = SubAgentManager.get_subagent_result( + session_id, node.agent_name, task_id=task_id + ) + if result and (result.result or result.completed_at > 0): + node.execution_time = time_mod.time() - start + node.completed_at = time_mod.time() + if result.success: + node.status = DAGNodeStatus.COMPLETED + node.result = result.result or "" + else: + node.status = DAGNodeStatus.FAILED + node.error = result.error or "Unknown error" + node.result = result.result or "" + return + + session = SubAgentManager.get_session(session_id) + if not session: + node.status = DAGNodeStatus.FAILED + node.error = "Session lost during execution" + return + + timeout = SubAgentManager.get_execution_timeout() + if timeout > 0 and (time_mod.time() - start) > timeout: + node.status = DAGNodeStatus.FAILED + node.error = f"Timeout after {timeout:.1f}s" + return + + await asyncio.sleep(0.5) + + except Exception as e: + node.status = DAGNodeStatus.FAILED + node.error = str(e) + node.execution_time = time_mod.time() - node.started_at + node.completed_at = time_mod.time() + + @staticmethod + async def execute_dag( + ctx: DAGExecutionContext, + session_id: str, + max_inject_length: int = 4000, + launch_fn: Callable | None = None, + ) -> dict: + """Execute the entire DAG synchronously. + + Args: + launch_fn: Optional async callable(node, session_id, injected_context) + that launches the subagent. If None, polling-only mode is used. + + Returns: + dict with keys: "failed", "skipped", "succeeded", "formatted", "total_time". + """ + import asyncio + import time as time_mod + + from astrbot import logger as _dag_logger + + ctx.status = "RUNNING" + dag_start = time_mod.time() + + try: + for layer in ctx.topo_layers: + # Process layer in waves. Each wave runs at most one task + # per agent to prevent context pollution from concurrent + # tasks on the same subagent. + while True: + batch: list[str] = [] + used_agents: set[str] = set() + for nid in layer: + if len(batch) >= ctx.max_parallel: + break + node = ctx.nodes[nid] + if node.status != DAGNodeStatus.PENDING: + continue + if node.agent_name not in used_agents: + batch.append(nid) + used_agents.add(node.agent_name) + + if not batch: + break # No more PENDING nodes in this layer + + for nid in batch: + ctx.nodes[nid].status = DAGNodeStatus.READY + + tasks = [] + for nid in batch: + node = ctx.nodes[nid] + if node.status in ( + DAGNodeStatus.SKIPPED, + DAGNodeStatus.FAILED, + ): + continue + tasks.append( + SubAgentDAGEngine._execute_single_node( + ctx, + nid, + session_id, + max_inject_length, + launch_fn=launch_fn, + ) + ) + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + for nid in batch: + node = ctx.nodes[nid] + if node.status == DAGNodeStatus.FAILED and ctx.fail_fast: + _dag_logger.warning( + f"[SubAgent:DAG] Node '{nid}' failed, cascading skip" + ) + SubAgentDAGEngine._cascade_skip(ctx, nid) + # Stop processing this layer + batch = [] + break + + if not batch: + break # fail-fast terminated layer + + finally: + ctx.completed_at = time_mod.time() + + succeeded = sum( + 1 for n in ctx.nodes.values() if n.status == DAGNodeStatus.COMPLETED + ) + failed = sum(1 for n in ctx.nodes.values() if n.status == DAGNodeStatus.FAILED) + skipped = sum( + 1 for n in ctx.nodes.values() if n.status == DAGNodeStatus.SKIPPED + ) + total_time = ctx.completed_at - dag_start + + formatted = _format_dag_result(ctx, succeeded, failed, skipped, total_time) + + return { + "failed": failed, + "skipped": skipped, + "succeeded": succeeded, + "formatted": formatted, + "total_time": total_time, + } + + +def _format_dag_result( + ctx: DAGExecutionContext, + succeeded: int, + failed: int, + skipped: int, + total_time: float, +) -> str: + """Format the aggregated DAG execution result for LLM consumption.""" + lines: list[str] = [] + + if failed == 0 and skipped == 0: + lines.append( + "✅ DAG orchestration completed: " + + f"{succeeded}/{len(ctx.nodes)} tasks succeeded in {total_time:.1f}s" + ) + else: + lines.append( + "❌ DAG orchestration failed: " + + f"{succeeded}/{len(ctx.nodes)} succeeded" + + (f", {failed} failed" if failed else "") + + (f", {skipped} skipped" if skipped else "") + + f" in {total_time:.1f}s" + ) + + for layer in ctx.topo_layers: + parallel_str = f" ({len(layer)} parallel)" if len(layer) > 1 else "" + lines.append(f"Layer{parallel_str}:") + for nid in layer: + node = ctx.nodes[nid] + if node.status == DAGNodeStatus.COMPLETED: + lines.append( + f" ✓ {nid} ({node.agent_name}) — {node.execution_time:.1f}s" + ) + if node.result: + preview = node.result[:1000] + if len(node.result) > 1000: + preview += "...[truncated]" + lines.append(f" {preview}") + elif node.status == DAGNodeStatus.FAILED: + lines.append( + f" ✗ {nid} ({node.agent_name}) — FAILED after " + f"{node.execution_time:.1f}s" + ) + if node.error: + lines.append(f" Error: {node.error}") + if node.result: + preview = node.result[:1000] + if len(node.result) > 1000: + preview += "...[truncated]" + lines.append(f" Output: {preview}") + elif node.status == DAGNodeStatus.SKIPPED: + lines.append( + f" ⊘ {nid} ({node.agent_name}) — skipped (dependency failed)" + ) + lines.append("") + + if failed > 0 or skipped > 0: + lines.append( + f"Summary: {succeeded} succeeded, {failed} failed, " + f"{skipped} skipped. Total time: {total_time:.1f}s" + ) + + newline = chr(10) + return newline.join(lines) diff --git a/astrbot/core/subagent_manager.py b/astrbot/core/subagent_manager.py new file mode 100644 index 0000000000..b29c483fd0 --- /dev/null +++ b/astrbot/core/subagent_manager.py @@ -0,0 +1,1472 @@ +""" +SubAgent Manager +Manages subagents for task decomposition and parallel processing. +Supports both statically configured subagents (from subagent_orchestrator) and +dynamically created subagents at runtime. +""" + +from __future__ import annotations + +import os.path +import re +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING + +from astrbot import logger +from astrbot.core.agent.agent import Agent +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.astr_main_agent_resources import LLM_SAFETY_MODE_SYSTEM_PROMPT +from astrbot.core.star.star import star_registry +from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path + +if TYPE_CHECKING: + from astrbot.core.subagent_dag import DAGExecutionContext + + +class SubAgentStatus(str, Enum): + """SubAgent lifecycle status.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + UNKNOWN = "UNKNOWN" + + +# 返回标记常量 +RET_DYNAMIC_TOOL_CREATED = "[DYNAMIC TOOL CREATED]" +RET_DYNAMIC_TOOL_CREATE_FAILED = "[DYNAMIC TOOL CREATE FAILED]" +RET_SUBAGENT_REMOVED = "[SUBAGENT REMOVED]" +RET_SUBAGENT_REMOVE_FAILED = "[SUBAGENT REMOVE FAILED]" +RET_HISTORY_CLEARED = "[HISTORY CLEARED]" +RET_HISTORY_CLEARED_FAILED = "[HISTORY CLEARED FAILED]" +RET_SHARED_CONTEXT_ADDED = "[SHARED CONTEXT ADDED]" +RET_SHARED_CONTEXT_ADDED_FAILED = "[SHARED CONTEXT ADDED FAILED]" +RET_PENDING_TASK_CREATE_FAILED = "[PENDING TASK CREATE FAILED]" + + +@dataclass +class SubAgentConfig: + name: str + system_prompt: str = "" + tools: set[str] | None = None + skills: set[str] | None = None + provider_id: str | None = None + description: str = "" + workdir: str | None = None + execution_timeout: float = 600.0 + + +@dataclass +class SubAgentExecutionResult: + task_id: str # 任务唯一标识符 + agent_name: str + success: bool + result: str | None = None + error: str | None = None + execution_time: float = 0.0 + created_at: float = 0.0 + completed_at: float = 0.0 + metadata: dict = field(default_factory=dict) + + +@dataclass +class SubAgentSession: + session_id: str + subagents: dict = field(default_factory=dict) # 存储SubAgentConfig对象 + handoff_tools: dict = field(default_factory=dict) + subagent_status: dict = field( + default_factory=dict + ) # 工作状态: SubAgentStatus 枚举值 + protected_agents: set = field( + default_factory=set + ) # 若某个agent受到保护,则不会被自动清理 + history_enabled: bool = True # 是否保存子代理历史 + subagent_histories: dict = field(default_factory=dict) # 存储每个子代理的历史上下文 + shared_context: list = field(default_factory=list) # 公共上下文列表 + shared_context_enabled: bool = False # 是否启用公共上下文 + subagent_background_results: dict = field( + default_factory=dict + ) # 后台subagent结果存储: {agent_name: {task_id: SubAgentExecutionResult}} + # 任务计数器: {agent_name: next_task_id} + background_task_counters: dict = field(default_factory=dict) + subagent_traces: dict = field(default_factory=dict) # {agent_name: TraceSpan} + last_activity_at: float = field(default_factory=time.time) # 最后活跃时间戳 + active_dag: DAGExecutionContext | None = None # 当前活跃的 DAG + dag_history: list[DAGExecutionContext] = field(default_factory=list) # 完成的 DAG + + +class SubAgentManager: + _sessions: dict = {} + _max_subagent_count: int = 3 + _auto_cleanup_per_turn: bool = True + _shared_context_enabled: bool = False + _history_enabled: bool = True # 是否启用子代理历史记忆功能 + _shared_context_maxlen: int = 300 # 公共上下文保留的历史消息条数 + _subagent_history_maxlen: int = 300 # 每个subagent最多保留的历史消息条数 + _execution_timeout: float = 1200.0 # SubAgent 执行超时时间(秒) 总时长 + _rule_prompt: str = "" # 动态子代理的固定行为约束prompt + _time_prompt_enabled: bool = True # 是否启用时间prompt注入 + _timezone: str | None = None # 时区设置 + _tools_blacklist: set[str] = { + "broadcast_shared_context", + "create_subagent", + "manage_subagent_protection", + "remove_subagent", + "list_subagents", + "wait_for_subagent", + "orchestrate_tasks", + "view_shared_context", + } + _tools_inherent: set[str] = { + "astrbot_execute_shell", + "astrbot_execute_python", + } + _dag_enabled: bool = False # 是否启用 DAG 编排 + _default_provider_id: str = "" # 默认 Chat Provider ID + _session_timeout_seconds = ( + 1800 # 会话存活时间。若有会话的subagent闲置时间超过该值,自动清理 + ) + + _HEADER_TEMPLATE = f"""# Sub-Agent Orchestration +You can manage sub-agents with isolated instructions, tools and skills. Maximum {_max_subagent_count} subagents. + +## When to Use +Create sub-agents ONLY when: +- Task has ≥2 independent workstreams with clear inputs/outputs +- Context exceeds your effective processing window""" + _SUBAGENT_AUTOCLEAN_PROMPT = ( + "- Sub-agents auto-destroy per turn; use `manage_subagent_protection(name, protected=true/false)` for multi-turn stateful tasks" + if _auto_cleanup_per_turn + else "" + ) + _CREATE_GUIDE_PROMPT = f"""## Workflow: Plan → Create → Delegate → Collect → Cleanup +### 1. Create Sub-agent +**Name**: 1 to 32 characters (letters, numbers, or underscores), starting with a letter. +**Required fields:** +| Field | Description | +|-------|-------------| +| Role | Expertise + work style | +| Context | Parent goal, this step, sibling agents | +| Instruction | Input → Process → Output (step-by-step) | +| Tools | **Minimum necessary only** | + +### 2. Manual Delegate +- Sequential: `transfer_to_subagent(name=..., input=...)` — block until return +- Parallel: `transfer_to_subagent(name=..., input=..., background_task=True)` → `wait_for_subagent(name, timeout=secs)` + +### 3. Collect +- Merge independent outputs by concatenation +- Resolve conflicts by preferring explicit data over inference +{_SUBAGENT_AUTOCLEAN_PROMPT}""" + _DAG_GUIDE_PROMPT = """## DAG Orchestration +DAG Orchestration automatically delegate subagents. When you have 2+ independent tasks that can run in parallel, or tasks with clear dependencies, prefer to use `orchestrate_tasks` to declare them all at once. +""" + + @classmethod + def build_task_router_prompt(cls, session_id: str): + session = cls.get_session(session_id) + if not session: + return "" + + parts = [ + cls._HEADER_TEMPLATE, + cls._CREATE_GUIDE_PROMPT, + ] + if cls._dag_enabled: + parts.append(cls._DAG_GUIDE_PROMPT) + + return "\n".join(parts) + "\n" + + @classmethod + def configure( + cls, + max_subagent_count: int = 10, + auto_cleanup_per_turn: bool = True, + shared_context_enabled: bool = False, + shared_context_maxlen: int = 300, + subagent_history_maxlen: int = 300, + tools_blacklist: list[str] = None, + tools_inherent: list[str] = None, + execution_timeout: float = 1200.0, + history_enabled: bool = True, + rule_prompt: str = "", + time_prompt_enabled: bool = True, + timezone: str | None = None, + dag_enabled: bool = False, + default_provider_id: str = "", + **kwargs, + ) -> None: + """Configure SubAgentManager settings""" + cls._max_subagent_count = max_subagent_count + cls._auto_cleanup_per_turn = auto_cleanup_per_turn + cls._shared_context_enabled = shared_context_enabled + cls._history_enabled = history_enabled + cls._shared_context_maxlen = shared_context_maxlen + cls._subagent_history_maxlen = subagent_history_maxlen + cls._execution_timeout = execution_timeout + cls._rule_prompt = rule_prompt + cls._time_prompt_enabled = time_prompt_enabled + cls._timezone = timezone + cls._dag_enabled = dag_enabled + cls._default_provider_id = default_provider_id + if tools_inherent is None: + cls._tools_inherent = { + "astrbot_execute_shell", + "astrbot_execute_python", + } + else: + cls._tools_inherent = set(tools_inherent) + if tools_blacklist is None: + cls._tools_blacklist = { + "broadcast_shared_context", + "create_subagent", + "manage_subagent_protection", + "remove_subagent", + "list_subagents", + "wait_for_subagent", + "orchestrate_tasks", + "view_shared_context", + } + else: + cls._tools_blacklist = set(tools_blacklist) + + @classmethod + def get_execution_timeout(cls) -> float: + return cls._execution_timeout + + @classmethod + def is_auto_cleanup_per_turn(cls) -> bool: + return cls._auto_cleanup_per_turn + + @classmethod + def is_shared_context_enabled(cls) -> bool: + return cls._shared_context_enabled + + @classmethod + def is_history_enabled(cls) -> bool: + return cls._history_enabled + + @classmethod + def register_blacklisted_tool(cls, tool_name: str) -> None: + """注册不应被子 Agent 使用的工具""" + cls._tools_blacklist.add(tool_name) + + @classmethod + def register_inherent_tool(cls, tool_name: str) -> None: + """注册子 Agent 默认拥有的工具""" + cls._tools_inherent.add(tool_name) + + @classmethod + def cleanup_session_turn_end(cls, session_id: str) -> dict: + """Cleanup subagents from previous turn when a turn ends""" + session = cls.get_session(session_id) + if not session: + return {"status": "no_session", "cleaned": []} + + # If DAG is currently running, do NOT clean subagents + dag_ctx = cls.get_active_dag(session_id) + if dag_ctx and dag_ctx.status == SubAgentStatus.RUNNING: + return {"status": "dag_running", "cleaned": []} + + cleaned = [] + for name in list(session.subagents.keys()): + if name not in session.protected_agents: + cls.remove_subagent(session_id, name) + cleaned.append(name) + + # 如果启用了公共上下文,处理清理 + if session.shared_context_enabled: + if not session.subagents and not session.protected_agents: + # 所有subagent都被清理,清除公共上下文 + cls.clear_shared_context(session_id) + logger.info( + "[SubAgent:SharedContext] All subagents cleaned, cleared shared context" + ) + else: + # 清理已删除agent的上下文 + for name in cleaned: + cls.cleanup_shared_context_by_agent(session_id, name) + + # 清理后若没有subagent,清理整个session + if not session.subagents and not session.protected_agents: + cls._sessions.pop(session_id, None) + + # Move completed/failed DAG to history + dag_ctx = cls.get_active_dag(session_id) + if dag_ctx and dag_ctx.status in ("COMPLETED", "FAILED", "CANCELLED"): + session.dag_history.append(dag_ctx) + session.active_dag = None + + # 每轮结束时顺便清理全局过期会话 + cls.cleanup_expired_sessions() + return {"status": "cleaned", "cleaned_agents": cleaned} + + @classmethod + def protect_subagent(cls, session_id: str, agent_name: str) -> None: + """Mark a subagent as protected from auto cleanup and history retention""" + session = cls._get_or_create_session(session_id) + session.protected_agents.add(agent_name) + logger.debug( + "[SubAgent:History] Initialized history for protected agent: %s", + agent_name, + ) + + @classmethod + def update_subagent_history( + cls, session_id: str, agent_name: str, current_messages: list + ) -> None: + """Update conversation history for a subagent""" + if not cls._history_enabled: + return + + session = cls.get_session(session_id) + + if not session: + return + + if agent_name not in session.subagent_histories: + session.subagent_histories[agent_name] = [] + + filtered_messages = [] + if isinstance(current_messages, list): + _MAX_TOOL_RESULT_LEN = 2000 + for msg in current_messages: + if ( + isinstance(msg, dict) and msg.get("role") == "system" + ): # 移除system消息 + continue + # 对过长的 tool 结果做截断,避免单条消息占用过多空间 + if ( + isinstance(msg, dict) + and msg.get("role") == "tool" + and isinstance(msg.get("content"), str) + and len(msg["content"]) > _MAX_TOOL_RESULT_LEN + ): + msg["content"] = ( + msg["content"][:_MAX_TOOL_RESULT_LEN] + "\n...[truncated]" + ) + filtered_messages.append(msg) + + session.subagent_histories[agent_name].extend(filtered_messages) + if len(session.subagent_histories[agent_name]) > cls._subagent_history_maxlen: + session.subagent_histories[agent_name] = session.subagent_histories[ + agent_name + ][-cls._subagent_history_maxlen :] + + logger.debug( + "[SubAgent:History] Saved messages for %s, current len=%d", + agent_name, + len(session.subagent_histories[agent_name]), + ) + + @classmethod + def get_subagent_history(cls, session_id: str, agent_name: str) -> list: + """Get conversation history for a subagent""" + if not cls._history_enabled: + return [] + session = cls.get_session(session_id) + if not session: + return [] + return session.subagent_histories.get(agent_name, []) + + @classmethod + def build_subagent_system_prompt( + cls, session_id: str, agent_name: str, runtime: str + ) -> str: + parts = [] + rule = cls._build_rule_prompt() + workdir = cls._build_workdir_prompt(session_id, agent_name) + if rule: + parts.append(rule) + if workdir: + parts.append(workdir) + skills = cls._build_subagent_skills_prompt(session_id, agent_name, runtime) + if skills: + parts.append(skills) + return "\n".join(parts) + + @classmethod + def build_subagent_extra_content_parts( + cls, session_id: str, agent_name: str + ) -> list: + """构建子代理的追加内容部分(extra_user_content_parts)。 + + 将共享上下文和时间信息作为追加内容返回,它们将被注入到用户消息中, + + Returns: + list[TextPart]: 追加内容部分列表 + """ + from astrbot.core.agent.message import TextPart + + parts = [] + + # 1. 共享上下文 + shared_context = cls._build_shared_context_prompt(session_id, agent_name) + if shared_context: + parts.append(TextPart(text=shared_context).mark_as_temp()) + + # 2. 时间信息 + time_prompt = cls._build_time_prompt() + if time_prompt: + parts.append(TextPart(text=time_prompt).mark_as_temp()) + + return parts + + @classmethod + def _filter_skills_for_current_config(cls, skills: list) -> list: + """Filter skills based on plugin activation status and plugin_set config. + + Mirrors the logic in astr_main_agent._filter_skills_for_current_config + but avoids circular imports by accessing config directly. + """ + try: + from astrbot.core.star.context import Context + + ctx = Context.get_instance() if hasattr(Context, "get_instance") else None + cfg = ctx.get_config() if ctx else {} + except Exception: + return skills + + plugin_set = cfg.get("plugin_set", ["*"]) + allowed_plugins = ( + None + if not isinstance(plugin_set, list) or "*" in plugin_set + else {str(name) for name in plugin_set} + ) + + plugin_by_root_dir = { + metadata.root_dir_name: metadata + for metadata in star_registry + if metadata.root_dir_name + } + + filtered = [] + for skill in skills: + if getattr(skill, "source_type", "") != "plugin": + filtered.append(skill) + continue + + plugin_name = getattr(skill, "plugin_name", "") + plugin = plugin_by_root_dir.get(plugin_name) + if not plugin or not plugin.activated: + continue + if plugin.reserved or allowed_plugins is None: + filtered.append(skill) + continue + if plugin.name is not None and plugin.name in allowed_plugins: + filtered.append(skill) + + return filtered + + @classmethod + def _build_subagent_skills_prompt( + cls, session_id: str, agent_name: str, runtime: str = "local" + ) -> str: + """Build skills prompt for a subagent based on its assigned skills""" + session = cls.get_session(session_id) + if not session: + return "" + + config = session.subagents.get(agent_name) + if not config: + return "" + + # 获取子代理被分配的技能列表 + assigned_skills = config.skills + + from astrbot.core.skills import SkillManager, build_skills_prompt + + skill_manager = SkillManager() + all_skills = skill_manager.list_skills(active_only=True, runtime=runtime) + all_skills = cls._filter_skills_for_current_config(all_skills) + if all_skills: + if assigned_skills is None: + filtered_skills = all_skills + else: + # 过滤只保留分配的技能 + filtered_skills = [ + s for s in all_skills if s.name in set(assigned_skills) + ] + else: + return "" + if filtered_skills: + return build_skills_prompt(filtered_skills) + else: + return "" + + @classmethod + def get_subagent_tools(cls, session_id: str, agent_name: str) -> list | None: + """Get the tools assigned to a subagent""" + session = cls.get_session(session_id) + if not session: + return None + config = session.subagents.get(agent_name) + if not config: + return None + return config.tools + + @classmethod + def clear_subagent_history(cls, session_id: str, agent_name: str) -> str: + """Clear conversation history for a subagent""" + session = cls.get_session(session_id) + if not session: + return ( + f"{RET_HISTORY_CLEARED_FAILED}: Session_id {session_id} does not exist." + ) + if agent_name in session.subagents: + if agent_name in session.subagent_histories: + session.subagent_histories.pop(agent_name, None) + if session.shared_context_enabled: + cls.cleanup_shared_context_by_agent(session_id, agent_name) + logger.debug("[SubAgent:History] Cleared history for: %s", agent_name) + return RET_HISTORY_CLEARED + else: + return f"{RET_HISTORY_CLEARED_FAILED}: Agent name {agent_name} not found. Available names {list(session.subagents.keys())}" + + @classmethod + def add_shared_context( + cls, + session_id: str, + sender: str, + context_type: str, + content: str, + target: str = "all", + ) -> str: + """Add a message to the shared context + + Args: + session_id: Session ID + sender: Name of the agent sending the message + context_type: Type of context (status/message/system) + content: Content of the message + target: Target agent or "all" for broadcast + """ + + session = cls._get_or_create_session(session_id) + if not session.shared_context_enabled: + return f"{RET_SHARED_CONTEXT_ADDED_FAILED}: Shared context disabled." + if (sender not in list(session.subagents.keys())) and (sender != "System"): + return f"{RET_SHARED_CONTEXT_ADDED_FAILED}: Sender name {sender} not found. Available names {list(session.subagents.keys())}" + if (target not in list(session.subagents.keys())) and (target != "all"): + return f"{RET_SHARED_CONTEXT_ADDED_FAILED}: Target name {target} not found. Available names {list(session.subagents.keys())} and 'all' " + + if len(session.shared_context) >= cls._shared_context_maxlen: + keep_count = int(cls._shared_context_maxlen * 0.9) + session.shared_context = session.shared_context[-keep_count:] + logger.warning( + "Shared context exceeded limit (%d), trimmed to %d", + cls._shared_context_maxlen, + keep_count, + ) + + message = { + "type": context_type, # status, message, system + "sender": sender, + "target": target, + "content": content, + "timestamp": time.time(), + } + session.shared_context.append(message) + logger.debug( + "[SubAgent:SharedContext] [%s] %s -> %s: %s...", + context_type, + sender, + target, + content[:50], + ) + return RET_SHARED_CONTEXT_ADDED + + @classmethod + def get_shared_context(cls, session_id: str, filter_by_agent: str = None) -> list: + """Get shared context, optionally filtered by agent + + Args: + session_id: Session ID + filter_by_agent: If specified, only return messages from/to this agent (including "all") + """ + session = cls.get_session(session_id) + if not session or not session.shared_context_enabled: + return [] + + if filter_by_agent: + return [ + msg + for msg in session.shared_context + if msg["sender"] == filter_by_agent + or msg["target"] == filter_by_agent + or msg["target"] == "all" + ] + return session.shared_context.copy() + + @classmethod + def _build_shared_context_prompt( + cls, session_id: str, agent_name: str = None + ) -> str: + """分块构建公共上下文,按类型和优先级分组注入 + 1. 区分不同类型的消息并分别标注 + 2. 按优先级和相关性分组 + 3. 减少 Agent 的解析负担 + """ + session = cls.get_session(session_id) + if ( + not session + or not session.shared_context_enabled + or not session.shared_context + ): + return "" + + lines = [] + + # === 1. 固定格式说明 === + lines.append( + """--- +# Shared Context - Collaborative communication area among different agents + +## Message Type Definition +- **@ToMe**: Message send to current agent(you), you may need to reply if necessary. +- **@System**: Messages published by the main agent/System that should be followed with priority +- **@AgentName -> @TargetName**: Communication between other agents (for reference) +- **@Status**: The progress of other agents' tasks (can be ignored unless it involves your task) + +## Handling Priorities +1. @System messages (highest priority) > @ToMe messages > @Status > @OtherAgents +2. Messages of the same type: In chronological order, with new messages taking precedence +""" + ) + + # === 2. System 消息 === + system_msgs = [m for m in session.shared_context if m["type"] == "system"] + if system_msgs: + lines.append("\n## @System - System Announcements") + for msg in system_msgs: + if cls._timezone: + import zoneinfo + + ts = datetime.fromtimestamp( + msg["timestamp"], tz=zoneinfo.ZoneInfo(cls._timezone) + ).strftime("%H:%M:%S") + else: + ts = time.strftime("%H:%M:%S", time.localtime(msg["timestamp"])) + content_text = msg["content"] + lines.append(f"[{ts}] System: {content_text}") + + if agent_name: + # === 3. 发送给当前 Agent 的消息 === + to_me_msgs = [ + m + for m in session.shared_context + if m["type"] == "message" and m["target"] == agent_name + ] + if to_me_msgs: + lines.append(f"\n## @ToMe - Messages sent to @{agent_name}") + lines.append( + " **These messages are addressed to you. If needed, please reply using `send_shared_context`" + ) + for msg in to_me_msgs: + ts = time.strftime("%H:%M:%S", time.localtime(msg["timestamp"])) + lines.append( + f"[{ts}] @{msg['sender']} -> @{agent_name}: {msg['content']}" + ) + + # === 4. 其他 Agent 之间的交互(仅显示最近10条)=== + inter_agent_msgs = [ + m + for m in session.shared_context + if m["type"] == "message" + and m["target"] != agent_name + and m["target"] != "all" + and m["sender"] != agent_name + ] + if inter_agent_msgs: + lines.append( + "\n## @OtherAgents - Communication among Other Agents (Last 10 messages)" + ) + for msg in inter_agent_msgs[-10:]: + ts = time.strftime("%H:%M:%S", time.localtime(msg["timestamp"])) + content_text = msg["content"] + lines.append( + f"[{ts}] {msg['sender']} -> {msg['target']}: {content_text}" + ) + + # === 5. Status 更新 === + status_msgs = [m for m in session.shared_context if m["type"] == "status"] + if status_msgs: + lines.append( + "\n## @Status - Task progress of each agent (Last 10 messages)" + ) + for msg in status_msgs[-10:]: + ts = time.strftime("%H:%M:%S", time.localtime(msg["timestamp"])) + lines.append(f"[{ts}] {msg['sender']}: {msg['content']}") + + lines.append("---") + return "\n".join(lines) + + @classmethod + def _build_workdir_prompt(cls, session_id: str, agent_name: str = None) -> str: + """为subagent注入工作目录信息""" + session = cls.get_session(session_id) + normalized_umo = ( + re.sub(r"[^A-Za-z0-9._-]+", "_", session_id.strip()) or "unknown" + ) + + if not session: + return "" + try: + workdir = session.subagents[agent_name].workdir + if workdir is None: + workdir = ( + Path(get_astrbot_workspaces_path()) / normalized_umo / agent_name + ).resolve(strict=False) + + except Exception: + workdir = ( + Path(get_astrbot_workspaces_path()) / normalized_umo / agent_name + ).resolve(strict=False) + + if not os.path.exists(workdir): + os.makedirs(workdir) + workdir_prompt = ( + "# Working Directory\n" + + f"Your working directory is `{workdir}`. Unless specified by the user, all generated files are saved by default in this directory.\n" + ) + return workdir_prompt + + @classmethod + def _build_time_prompt(cls) -> str: + if not cls._time_prompt_enabled: + return "" + try: + if cls._timezone: + import zoneinfo + + current_time = datetime.now(zoneinfo.ZoneInfo(cls._timezone)).strftime( + "%Y-%m-%d %H:%M (%Z)" + ) + else: + current_time = ( + datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") + ) + except Exception: + current_time = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") + time_prompt = f"# Current Time\n{current_time}\n" + return time_prompt + + _TASK_STATUS_PROMPT = ( + "# Task Status Reporting\n" + "At the end of your task, self-audit before giving your final answer.\n" + "## SUCCESS — use only when ALL of these are true:\n" + "- Every tool call succeeded; no unexpected error or empty result\n" + "- Your output directly answers the task you were assigned\n" + "- You are confident the result is accurate, not a guess or placeholder\n" + "- If you created files: ensure they exist on disk, and their content is correct and complete\n" + "If all pass, put this EXACT line FIRST, then your result:\n" + "[TASK RESULT: SUCCESS]\n" + "## FAILURE — use if ANY tool failed, or you cannot complete the task:\n" + "[TASK RESULT: FAILURE]\n" + "[FAILURE REASON: ]\n" + "## Reporting Marker Rules\n" + "- The marker MUST be exactly `[TASK RESULT: SUCCESS]` or `[TASK RESULT: FAILURE]` — do not change it\n" + "- The marker MUST be on its own line, at the very top of your response\n" + "- When uncertain between success and failure, choose failure\n" + ) + + @classmethod + def _build_rule_prompt(cls) -> str: + base = ( + cls._rule_prompt + if cls._rule_prompt + else ( + "# Behavior Rules\n" + "## Safety\n" + f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}" + "## Output Guidelines\n" + "- If output is long, save it to file. Summarize in your response and provide the file path.\n" + "- Mark all generated code/documents with your name and timestamp (if given).\n" + ) + ) + if cls._dag_enabled: + return base + cls._TASK_STATUS_PROMPT + else: + return base + + @classmethod + def cleanup_shared_context_by_agent(cls, session_id: str, agent_name: str) -> None: + """Remove all messages from/to a specific agent from shared context""" + session = cls.get_session(session_id) + if not session: + return + + original_len = len(session.shared_context) + session.shared_context = [ + msg + for msg in session.shared_context + if msg["sender"] != agent_name and msg["target"] != agent_name + ] + removed = original_len - len(session.shared_context) + if removed > 0: + logger.debug( + "[SubAgent:SharedContext] Removed %d messages related to %s", + removed, + agent_name, + ) + + @classmethod + def clear_shared_context(cls, session_id: str) -> None: + """Clear all shared context""" + session = cls.get_session(session_id) + if not session: + return + session.shared_context.clear() + logger.debug("[SubAgent:SharedContext] Cleared all shared context") + + @classmethod + def is_protected(cls, session_id: str, agent_name: str) -> bool: + """Check if a subagent is protected from auto cleanup""" + session = cls.get_session(session_id) + if not session: + return False + return agent_name in session.protected_agents + + @classmethod + def set_history_enabled(cls, session_id: str, enabled: bool) -> None: + """Enable or disable history for subagents""" + session = cls._get_or_create_session(session_id) + session.history_enabled = enabled + logger.info( + "[SubAgent:History] Subagent history %s", + "enabled" if enabled else "disabled", + ) + + @classmethod + def set_shared_context_enabled(cls, session_id: str, enabled: bool) -> None: + """Enable or disable shared context for a session""" + session = cls._get_or_create_session(session_id) + session.shared_context_enabled = enabled + logger.info( + "[SubAgent:SharedContext] Shared context %s", + "enabled" if enabled else "disabled", + ) + + @classmethod + def set_subagent_status( + cls, session_id: str, agent_name: str, status: SubAgentStatus + ) -> None: + session = cls._get_or_create_session(session_id) + if agent_name in session.subagents: + old_status = session.subagent_status.get(agent_name, SubAgentStatus.UNKNOWN) + session.subagent_status[agent_name] = status.value + trace = session.subagent_traces.get(agent_name) + if trace: + trace.record( + "subagent_status_change", + agent_name=agent_name, + from_status=old_status, + to_status=status, + ) + + # for read-only operations + @classmethod + def get_session(cls, session_id: str) -> SubAgentSession | None: + return cls._sessions.get(session_id, None) + + # ensure the existence of a session before writing operations + @classmethod + def _get_or_create_session(cls, session_id: str) -> SubAgentSession: + if session_id not in cls._sessions: + cls._sessions[session_id] = SubAgentSession(session_id=session_id) + else: + cls._sessions[session_id].last_activity_at = time.time() + return cls._sessions[session_id] + + @classmethod + def _touch_session(cls, session_id: str) -> None: + """更新会话的最后活跃时间""" + session = cls._sessions.get(session_id) + if session: + session.last_activity_at = time.time() + + @classmethod + def cleanup_expired_sessions(cls) -> dict: + """清理超过超时时间未活跃的会话,防止内存泄漏 + + Returns: + dict: 包含被清理的会话ID列表和数量 + """ + now = time.time() + expired_session_ids = [ + sid + for sid, session in cls._sessions.items() + if now - session.last_activity_at > cls._session_timeout_seconds + ] + cleaned_agents_count = 0 + for sid in expired_session_ids: + session = cls._sessions.get(sid) + if session: + agent_names = list(session.subagents.keys()) + cleaned_agents_count += len(agent_names) + cls._sessions.pop(sid, None) + logger.info( + "[SubAgent:Timeout] Session %s expired (inactive for >%.0f minutes). Cleaned %d subagents.", + sid, + cls._session_timeout_seconds / 60, + len(agent_names), + ) + return { + "cleaned_sessions": expired_session_ids, + "cleaned_count": len(expired_session_ids), + "cleaned_agents_count": cleaned_agents_count, + } + + @classmethod + async def create_subagent( + cls, session_id: str, config: SubAgentConfig, protected: bool = False + ) -> tuple: + """Create a subagent (dynamic or static). + + Args: + session_id: Session ID + config: SubAgent configuration + protected: If True, the subagent will not be auto-cleaned per turn. + Static subagents from config should be protected. + """ + from astrbot.core.utils.trace import TraceSpan + + trace = TraceSpan(name=f"SubAgent:{config.name}", umo=session_id) + session = cls._get_or_create_session(session_id) + if config.name not in session.subagents: + # Check max count limit + active_count = len(session.subagents.keys()) + if active_count >= cls._max_subagent_count: + trace.record( + "subagent_created", + agent_name=config.name, + success=False, + reason="max_count_reached", + max_count=cls._max_subagent_count, + ) + return ( + f"Error: Maximum number of subagents ({cls._max_subagent_count}) reached. More subagents is not allowed.", + None, + ) + + if config.name in session.subagents: + session.handoff_tools.pop(config.name, None) + # When shared_context is enabled, the send_shared_context tool is allocated regardless of whether the main agent allocates the tool to the subagent + if config.tools is None: + config.tools = set() + # When shared_context is enabled, the send_shared_context tool is allocated regardless of whether the main agent allocates the tool to the subagent + if session.shared_context_enabled: + config.tools.add("send_shared_context") + # remove tools in backlist + for tool_bl in cls._tools_blacklist: + config.tools.discard(tool_bl) + + # add tools in inherent list + for tool_ih in cls._tools_inherent: + config.tools.add(tool_ih) + + session.subagents[config.name] = config + agent = Agent( + name=config.name, + instructions=config.system_prompt, + tools=list(config.tools), + ) + handoff_tool = HandoffTool( + agent=agent, + tool_description=config.description or f"Delegate to {config.name} agent", + ) + if config.provider_id: + handoff_tool.provider_id = config.provider_id + elif cls._default_provider_id: + handoff_tool.provider_id = cls._default_provider_id + session.handoff_tools[config.name] = handoff_tool + # 初始化subagent的历史上下文(仅当历史功能启用时) + if cls._history_enabled: + session.subagent_histories[config.name] = [] + # 初始化subagent状态 + cls.set_subagent_status(session_id, config.name, SubAgentStatus.IDLE) + # 如果标记为protected,则加入protected集合 + if protected: + session.protected_agents.add(config.name) + trace.record( + "subagent_created", + agent_name=config.name, + success=True, + tools=list(config.tools) if config.tools else [], + skills=list(config.skills) if config.skills else [], + protected=protected, + provider_id=handoff_tool.provider_id, + ) + session.subagent_traces[config.name] = trace + logger.info( + "[SubAgent:Create] Created subagent: %s (protected=%s)", + config.name, + protected, + ) + # Return (tool_name, handoff_tool). The tool_name "transfer_to_{name}" + # is kept for display/logging purposes only — it is no longer registered + # as an individual tool in the main agent's func_tool set. The unified + # ``transfer_to_subagent(name=...)`` tool handles all delegations. + return f"transfer_to_{config.name}", handoff_tool + + @classmethod + def register_static_subagent( + cls, + session_id: str, + handoff_tool: HandoffTool, + skills: set[str] | None = None, + workdir: str | None = None, + ) -> tuple: + """Register a static subagent (from subagent_orchestrator config) into SubAgentManager. + + Static subagents are always protected from auto-cleanup. + Returns (tool_name, handoff_tool) same as create_subagent. + """ + agent = handoff_tool.agent + config = SubAgentConfig( + name=agent.name, + system_prompt=agent.instructions or "", + tools=agent.tools, + skills=skills, + provider_id=getattr(handoff_tool, "provider_id", None), + description=handoff_tool.description or f"Delegate to {agent.name} agent", + workdir=workdir, + ) + + session = cls._get_or_create_session(session_id) + if ( + config.name not in session.subagents + ): # if the static agent already exists, pass + from astrbot.core.utils.trace import TraceSpan + + trace = TraceSpan(name=f"SubAgent:{config.name}", umo=session_id) + trace.record( + "subagent_created", + agent_name=config.name, + success=True, + static=True, + tools=list(config.tools) if config.tools else [], + skills=list(config.skills) if config.skills else [], + ) + session.subagent_traces[config.name] = trace + + if config.tools is None: + config.tools = None + if config.tools is not None and not config.tools: + config.tools = set() + if session.shared_context_enabled: + config.tools.add("send_shared_context") + session.subagents[config.name] = config + agent = Agent( + name=config.name, + instructions=config.system_prompt, + tools=config.tools, + ) + handoff_tool = HandoffTool( + agent=agent, + tool_description=config.description + or f"Delegate to {config.name} agent", + ) + if config.provider_id: + handoff_tool.provider_id = config.provider_id + session.handoff_tools[config.name] = handoff_tool + + if cls._history_enabled and config.name not in session.subagent_histories: + session.subagent_histories[config.name] = [] + + cls.set_subagent_status(session_id, config.name, SubAgentStatus.IDLE) + session.protected_agents.add(config.name) + else: + pass + # tool_name is for display/logging only; see create_subagent() comment. + return f"transfer_to_{config.name}", handoff_tool + + @classmethod + async def cleanup_session(cls, session_id: str) -> dict: + # Cancel active DAG first + cls.cancel_dag(session_id) + + session = cls._sessions.pop(session_id, None) + if not session: + return {"status": "not_found", "cleaned_agents": []} + else: + cleaned = list(session.subagents.keys()) + for name in cleaned: + logger.info("[SubAgent:Cleanup] Cleaned: %s", name) + return {"status": "cleaned", "cleaned_agents": cleaned} + + @classmethod + def register_dag(cls, session_id: str, dag_ctx: DAGExecutionContext) -> None: + """Register a DAG execution context with the session.""" + session = cls._get_or_create_session(session_id) + session.active_dag = dag_ctx + logger.info( + "[SubAgent:DAG] Registered DAG %s for session %s with %d nodes", + dag_ctx.dag_id, + session_id, + len(dag_ctx.nodes), + ) + + @classmethod + def get_active_dag(cls, session_id: str) -> DAGExecutionContext | None: + """Get the active DAG for a session, or None.""" + session = cls.get_session(session_id) + if not session: + return None + return session.active_dag + + @classmethod + def cancel_dag(cls, session_id: str) -> dict: + """Cancel the active DAG, aborting all non-terminal nodes.""" + session = cls.get_session(session_id) + if not session or not session.active_dag: + return {"status": "no_active_dag"} + + dag_ctx = session.active_dag + cancelled = 0 + for node in dag_ctx.nodes.values(): + if node.status.value in ("RUNNING", "READY", "PENDING"): + node.status = type(node.status).SKIPPED + cancelled += 1 + + dag_ctx.status = "CANCELLED" + dag_ctx.completed_at = time.time() + session.dag_history.append(dag_ctx) + session.active_dag = None + + logger.info( + "[SubAgent:DAG] Cancelled DAG %s: %d nodes aborted", + dag_ctx.dag_id, + cancelled, + ) + return { + "status": "cancelled", + "dag_id": dag_ctx.dag_id, + "cancelled_nodes": cancelled, + } + + @classmethod + def remove_subagent(cls, session_id: str, agent_name: str) -> str: + cls._touch_session(session_id) + session = cls.get_session(session_id) + if not session: + return f"{RET_SUBAGENT_REMOVE_FAILED}: Session {session_id} does not exist." + if session.subagent_status.get(agent_name) == SubAgentStatus.RUNNING: + return f"{RET_SUBAGENT_REMOVE_FAILED}: {agent_name} is still RUNNING. Waiting for finish first." + + def _remove_by_name(name): + session.subagents.pop(name, None) + session.protected_agents.discard(name) + session.handoff_tools.pop(name, None) + session.subagent_histories.pop(name, None) + session.subagent_background_results.pop(name, None) + session.background_task_counters.pop(name, None) + session.subagent_traces.pop(name, None) + # 清理公共上下文中包含该Agent的内容 + cls.cleanup_shared_context_by_agent(session_id, name) + + if agent_name == "all": + if SubAgentStatus.RUNNING in session.subagent_status.values(): + removed = 0 + for subagent_name in list(session.subagents.keys()): + if ( + session.subagent_status.get(subagent_name) + == SubAgentStatus.RUNNING + ): + continue + _remove_by_name(subagent_name) + removed += 1 + return f"{RET_SUBAGENT_REMOVED}: Removed {removed} subagents. {len(session.subagents.keys())} subagents are reserved because they are still running." + else: + session.subagents.clear() + session.handoff_tools.clear() + session.protected_agents.clear() + session.subagent_histories.clear() + session.shared_context.clear() + session.subagent_background_results.clear() + session.background_task_counters.clear() + session.subagent_traces.clear() + logger.info("[SubAgent:Cleanup] All subagents cleaned.") + return f"{RET_SUBAGENT_REMOVED}: All subagents have been removed." + else: + if agent_name not in session.subagents: + return f"{RET_SUBAGENT_REMOVE_FAILED}: {agent_name} not found. Available subagent names {list(session.subagents.keys())}" + else: + _remove_by_name(agent_name) + logger.info("[SubAgent:Cleanup] Cleaned: %s", agent_name) + return ( + f"{RET_SUBAGENT_REMOVED}: Subagent {agent_name} has been removed." + ) + + @classmethod + def get_handoff_tools_for_session(cls, session_id: str) -> list: + session = cls.get_session(session_id) + if not session: + return [] + return list(session.handoff_tools.values()) + + @classmethod + def create_pending_subagent_task(cls, session_id: str, agent_name: str) -> str: + """为 SubAgent 创建一个 pending 任务,返回 task_id + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + + Returns: + task_id: 任务ID,格式为简单的递增数字字符串 + """ + session = cls._get_or_create_session(session_id) + + # 初始化 + if agent_name not in session.subagent_background_results: + session.subagent_background_results[agent_name] = {} + if agent_name not in session.background_task_counters: + session.background_task_counters[agent_name] = 0 + + if ( + session.subagent_status[agent_name] == SubAgentStatus.RUNNING + ): # 若当前有任务在运行,不允许创建 + return f"{RET_PENDING_TASK_CREATE_FAILED}: Subagent {agent_name} already running" + + # 生成递增的任务ID + session.background_task_counters[agent_name] += 1 + task_id = str(session.background_task_counters[agent_name]) + + # 创建 pending 占位 + session.subagent_background_results[agent_name][task_id] = ( + SubAgentExecutionResult( + task_id=task_id, + agent_name=agent_name, + success=False, + result=None, + created_at=time.time(), + metadata={}, + ) + ) + + return task_id + + @classmethod + def _ensure_task_store( + cls, session: SubAgentSession, agent_name: str + ) -> dict[str, SubAgentExecutionResult]: + if agent_name not in session.subagent_background_results: + session.subagent_background_results[agent_name] = {} + return session.subagent_background_results[agent_name] + + @staticmethod + def _is_task_completed(result: SubAgentExecutionResult) -> bool: + return result.completed_at > 0 or result.error is not None + + @classmethod + def get_pending_subagent_tasks(cls, session_id: str, agent_name: str) -> list[str]: + """获取 SubAgent 的所有 pending 任务 ID 列表(按创建时间排序)""" + session = cls.get_session(session_id) + if not session: + return [] + + store = session.subagent_background_results.get(agent_name) + if not store: + return [] + + pending = [tid for tid, res in store.items() if not cls._is_task_completed(res)] + return sorted(pending, key=lambda tid: store[tid].created_at) + + @classmethod + def get_latest_task_id(cls, session_id: str, agent_name: str) -> str | None: + """获取 SubAgent 的最新任务 ID""" + session = cls.get_session(session_id) + if not session or agent_name not in session.subagent_background_results: + return None + + # 按 created_at 排序取最新的 + sorted_tasks = sorted( + session.subagent_background_results[agent_name].items(), + key=lambda x: x[1].created_at, + reverse=True, + ) + return sorted_tasks[0][0] if sorted_tasks else None + + @classmethod + def store_subagent_result( + cls, + session_id: str, + agent_name: str, + success: bool, + result: str, + task_id: str | None = None, + error: str | None = None, + execution_time: float = 0.0, + metadata: dict | None = None, + ) -> None: + """存储 SubAgent 的执行结果 + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + success: 是否成功 + result: 执行结果 + task_id: 任务ID,如果为None则存储到最新的pending任务 + error: 错误信息 + execution_time: 执行耗时 + metadata: 额外元数据 + """ + session = cls._get_or_create_session(session_id) + + task_store = cls._ensure_task_store(session, agent_name) + + if task_id is None: + # 如果没有指定task_id,尝试找最新的pending任务 + pending = cls.get_pending_subagent_tasks(session_id, agent_name) + if pending: + task_id = pending[-1] + else: + logger.warning( + f"[SubAgentResult] No task_id and no pending tasks for {agent_name}" + ) + return + + if task_id not in task_store: + # 如果任务不存在,先创建一个占位 + task_store[task_id] = SubAgentExecutionResult( + task_id=task_id, + agent_name=agent_name, + success=False, + result="", + created_at=time.time(), + metadata=metadata or {}, + ) + + # 更新结果 + task_store[task_id].success = success + task_store[task_id].result = result + task_store[task_id].error = error + task_store[task_id].execution_time = execution_time + task_store[task_id].completed_at = time.time() + if metadata: + task_store[task_id].metadata.update(metadata) + + trace = session.subagent_traces.get(agent_name) + if trace: + trace.record( + "subagent_result_stored", + agent_name=agent_name, + task_id=task_id, + success=success, + execution_time=execution_time, + has_error=error is not None, + ) + + @classmethod + def get_subagent_result( + cls, session_id: str, agent_name: str, task_id: str | None = None + ) -> SubAgentExecutionResult | None: + """获取 SubAgent 的执行结果 + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + task_id: 任务ID,如果为None则获取最新完成的任务结果 + + Returns: + SubAgentExecutionResult 或 None + """ + session = cls.get_session(session_id) + if not session or agent_name not in session.subagent_background_results: + return None + + if task_id is None: + # 获取最新的已完成任务 + completed = [ + (tid, r) + for tid, r in session.subagent_background_results[agent_name].items() + if r.result != "" or r.completed_at > 0 + ] + if not completed: + return None + # 按创建时间排序,取最新的 + completed.sort(key=lambda x: x[1].created_at, reverse=True) + return completed[0][1] + + return session.subagent_background_results[agent_name].get(task_id, None) + + @classmethod + def has_subagent_result( + cls, session_id: str, agent_name: str, task_id: str | None = None + ) -> bool: + """检查 SubAgent 是否有结果 + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + task_id: 任务ID,如果为None则检查是否有任何已完成的任务 + """ + session = cls.get_session(session_id) + task_store = cls._ensure_task_store(session, agent_name) + if not session or not task_store: + return False + + if task_id is None: + # 检查是否有任何已完成的任务 + return any( + r.result != "" or r.completed_at > 0 for r in task_store.values() + ) + + if task_id not in task_store: + return False + result = task_store[task_id] + return result.result != "" or result.completed_at > 0 + + @classmethod + def clear_subagent_result( + cls, session_id: str, agent_name: str, task_id: str | None = None + ) -> None: + """清除 SubAgent 的执行结果 + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + task_id: 任务ID,如果为None则清除该Agent所有任务 + """ + session = cls.get_session(session_id) + task_store = cls._ensure_task_store(session, agent_name) + if not session or not task_store: + return + + if task_id is None: + # 清除所有任务 + session.subagent_background_results.pop(agent_name, None) + session.background_task_counters.pop(agent_name, None) + else: + # 清除特定任务 + task_store.pop(task_id, None) + + @classmethod + def get_subagent_status(cls, session_id: str, agent_name: str) -> str: + """获取 SubAgent 的状态: IDLE, RUNNING, COMPLETED, FAILED + + Args: + session_id: Session ID + agent_name: SubAgent 名称 + """ + session = cls.get_session(session_id) + if not session: + return SubAgentStatus.UNKNOWN + return session.subagent_status.get(agent_name, SubAgentStatus.UNKNOWN) + + @classmethod + def get_all_subagent_status(cls, session_id: str) -> dict: + """获取所有 SubAgent 的状态""" + session = cls.get_session(session_id) + if not session: + return {} + return { + name: cls.get_subagent_status(session_id, name) + for name in session.subagents + } diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index c6c595dfc9..18da2f4b6f 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -15,8 +15,9 @@ class SubAgentOrchestrator: """Loads subagent definitions from config and registers handoff tools. - This is intentionally lightweight: it does not execute agents itself. - Execution happens via HandoffTool in FunctionToolExecutor. + Static subagents from config are registered into SubAgentManager so they + can enjoy unified lifecycle management, shared context, history retention, + and other advanced features alongside dynamically created subagents. """ def __init__( @@ -25,6 +26,7 @@ def __init__( self._tool_mgr = tool_mgr self._persona_mgr = persona_mgr self.handoffs: list[HandoffTool] = [] + self.handoff_skills: list[Any] = [] async def reload_from_config(self, cfg: dict[str, Any]) -> None: from astrbot.core.astr_agent_context import AstrAgentContext @@ -35,6 +37,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: return handoffs: list[HandoffTool] = [] + handoff_skills: list[Any] = [] for item in agents: if not isinstance(item, dict): continue @@ -61,6 +64,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: if provider_id is not None: provider_id = str(provider_id).strip() or None tools = item.get("tools", []) + skills = item.get("skills", []) begin_dialogs = None if persona_data: @@ -71,6 +75,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: persona_data.get("_begin_dialogs_processed") ) tools = persona_data.get("tools") + skills = persona_data.get("skills") if public_description == "" and prompt: public_description = prompt[:120] if tools is None: @@ -80,6 +85,12 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: else: tools = [str(t).strip() for t in tools if str(t).strip()] + if skills is None: + skills = None + elif not isinstance(skills, list): + skills = [] + else: + skills = [str(s).strip() for s in skills if str(s).strip()] agent = Agent[AstrAgentContext]( name=name, instructions=instructions, @@ -97,8 +108,50 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: handoff.provider_id = provider_id handoffs.append(handoff) + handoff_skills.append(skills) for handoff in handoffs: logger.info(f"Registered subagent handoff tool: {handoff.name}") self.handoffs = handoffs + self.handoff_skills = handoff_skills + + def register_static_subagents_to_manager(self, session_id: str) -> None: + """Register all static subagents (from config) into SubAgentManager. + + This makes static subagents enjoy the same unified management as + dynamically created subagents: shared context, history retention, + lifecycle management, etc. + + Static subagents are always protected from auto-cleanup. + """ + + try: + from astrbot.core.subagent_manager import SubAgentManager + except ImportError: + return + + for handoff, skills in zip(self.handoffs, self.handoff_skills): + try: + workdir = None + # Try to get skills from the handoff tool or agent + agent = handoff.agent + # The agent.tools may contain skill names; we pass them along + # SubAgentManager will filter and build skills prompt as needed + SubAgentManager.register_static_subagent( + session_id=session_id, + handoff_tool=handoff, + skills=skills, + workdir=workdir, + ) + logger.debug( + "[SubAgentOrchestrator] Registered static subagent '%s' to SubAgentManager for session %s", + agent.name, + session_id, + ) + except Exception as e: + logger.warning( + "[SubAgentOrchestrator] Failed to register static subagent '%s' to manager: %s", + getattr(handoff.agent, "name", "unknown"), + e, + ) diff --git a/astrbot/core/subagent_tools.py b/astrbot/core/subagent_tools.py new file mode 100644 index 0000000000..1b26338a82 --- /dev/null +++ b/astrbot/core/subagent_tools.py @@ -0,0 +1,936 @@ +""" +SubAgent Tools +Tool definitions for SubAgent management. +These tools are used by the main agent to create, manage, and interact with subagents. +""" + +from __future__ import annotations + +import asyncio +import os +import platform +import re +import time +import uuid +from dataclasses import dataclass, field + +from astrbot import logger +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.subagent_dag import ( + DAGExecutionContext, + DAGNodeStatus, + DAGTaskNode, + SubAgentDAGEngine, +) +from astrbot.core.subagent_manager import ( + RET_DYNAMIC_TOOL_CREATE_FAILED, + RET_DYNAMIC_TOOL_CREATED, + RET_HISTORY_CLEARED, + RET_SHARED_CONTEXT_ADDED, + RET_SUBAGENT_REMOVED, + SubAgentConfig, + SubAgentManager, +) + + +@dataclass +class CreateSubAgentTool(FunctionTool): + name: str = "create_subagent" + description: str = "Create a subagent. After creation, use transfer_to_subagent(name=...) to delegate." + + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Subagent name"}, + "system_prompt": { + "type": "string", + "description": "Subagent system_prompt", + }, + "subagent_description": { + "type": "string", + "description": "Brief description of what this subagent does and when to use it (e.g., 'Analyzes Python code for bugs and performance issues'). Shown in list_subagents output.", + }, + "tools": { + "type": "array", + "items": {"type": "string"}, + "description": "Tools available to subagent, can be empty.", + }, + "skills": { + "type": "array", + "items": {"type": "string"}, + "description": "Skills available to subagent, can be empty", + }, + "workdir": { + "type": "string", + "description": "Subagent working directory(absolute path), can be empty(same to main agent). Fill only when the user has clearly specified the path.", + }, + "provider_id": { + "type": "string", + "description": "LLM provider ID for this subagent. If not provided, uses the same provider as the main agent.", + }, + }, + "required": ["name", "system_prompt"], + } + ) + + def _check_path_safety(self, path_str: str) -> bool: + """ + 检查路径是否合法、安全 + """ + if not path_str or not isinstance(path_str, str): + return False + + if not os.path.isabs(path_str): + return False + + try: + resolved = os.path.realpath(path_str) + except (OSError, ValueError): + return False + + # 使用路径组件匹配而非子字符串匹配 + path_parts = {part.lower() for part in os.path.normpath(resolved).split(os.sep)} + + # Windows 特殊目录检查(作为独立的路径组件) + windows_dangerous_components = { + "windows", + "system32", + "syswow64", + "boot", + "recovery", + "programdata", + "$recycle.bin", + "system volume information", + } + + system = platform.system().lower() + if system == "windows": + if path_parts & windows_dangerous_components: + return False + elif system == "linux": + # 检查是否在危险目录下(前缀匹配) + linux_dangerous_prefixes = [ + "/etc", + "/bin", + "/sbin", + "/lib", + "/lib64", + "/boot", + "/dev", + "/proc", + "/sys", + "/root", + ] + resolved_norm = os.path.normpath(resolved) + for prefix in linux_dangerous_prefixes: + if resolved_norm.startswith(prefix + "/") or resolved_norm == prefix: + return False + elif system == "darwin": + darwin_dangerous_prefixes = [ + "/System", + "/Library", + "/private/var", + "/usr", + ] + resolved_norm = os.path.normpath(resolved) + for prefix in darwin_dangerous_prefixes: + if resolved_norm.startswith(prefix + "/") or resolved_norm == prefix: + return False + + # 通用检查:父目录跳转 + if ".." in path_str: + return False + + if not os.path.exists(resolved): + return False + + return True + + async def call(self, context, **kwargs) -> str: + name = kwargs.get("name", "") + + if not name: + return "Error: subagent name required" + if name == "subagent": + return "Error: 'subagent' cannot be a name" + # 验证名称格式:只允许英文字母、数字和下划线,长度限制;避免Windows保留名 + SAFE_IDENTIFIER = re.compile( + r"^(?!^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$)[a-zA-Z][a-zA-Z0-9_]{0,32}$", + re.IGNORECASE, + ) + if not bool(SAFE_IDENTIFIER.match(name)): + return "Error: SubAgent name must start with letter, contain only letters/numbers/underscores, max 32 characters" + + system_prompt = kwargs.get("system_prompt", "") + subagent_description = kwargs.get("subagent_description", "") + tools = kwargs.get("tools", {}) + skills = kwargs.get("skills", {}) + workdir = kwargs.get("workdir") + provider_id = kwargs.get("provider_id") + + session_id = context.context.event.unified_msg_origin + if not self._check_path_safety(workdir): + workdir = None + config = SubAgentConfig( + name=name, + system_prompt=system_prompt, + tools=set(tools), + skills=set(skills), + workdir=workdir, + provider_id=provider_id, + description=subagent_description, + ) + + tool_name, handoff_tool = await SubAgentManager.create_subagent( + session_id=session_id, config=config + ) + if handoff_tool: + return f"{RET_DYNAMIC_TOOL_CREATED}:{tool_name}:{handoff_tool.name}:Created. Use transfer_to_subagent(name='{name}', ...) to delegate." + else: + return f"{RET_DYNAMIC_TOOL_CREATE_FAILED}:{tool_name}" + + +@dataclass +class RemoveSubagentTool(FunctionTool): + name: str = "remove_subagent" + description: str = "Remove subagent by name. Use 'all' to remove all subagents." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Subagent name to remove. Use 'all' to remove all subagents.", + } + }, + "required": ["name"], + } + ) + + async def call(self, context, **kwargs) -> str: + name = kwargs.get("name", "") + if not name: + return "Error: name required" + session_id = context.context.event.unified_msg_origin + remove_status = SubAgentManager.remove_subagent(session_id, name) + if remove_status.startswith(RET_SUBAGENT_REMOVED): + return f"Cleaned {name} Subagent" + else: + return remove_status + + +@dataclass +class ListSubagentsTool(FunctionTool): + name: str = "list_subagents" + description: str = "List subagents with their status." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "include_status": { + "type": "boolean", + "description": "Include status", + "default": True, + } + }, + } + ) + + async def call(self, context, **kwargs) -> str: + include_status = kwargs.get("include_status", True) + session_id = context.context.event.unified_msg_origin + session = SubAgentManager.get_session(session_id) + if not session or not session.subagents: + return "No subagents available." + + count = len(session.subagents) + lines = [f"Available Subagents ({count}):"] + lines.append("---") + + for name, config in session.subagents.items(): + lines.append(f"Name: {name}") + + if include_status: + status = SubAgentManager.get_subagent_status(session_id, name) + lines.append(f"Status: {status}") + + protected = name in session.protected_agents + lines.append(f"Protected: {'Yes' if protected else 'No'}") + + if config.description: + lines.append(f"Description: {config.description}") + + tools_list = ", ".join(sorted(config.tools)) if config.tools else "none" + lines.append(f"Tools: {tools_list}") + + lines.append("---") + + return "\n".join(lines) + + +@dataclass +class ManageSubagentProtectionTool(FunctionTool): + """Tool to protect or unprotect a subagent from auto cleanup""" + + name: str = "manage_subagent_protection" + description: str = "Protect or unprotect a subagent from automatic cleanup. Use this to prevent important subagents from being removed, or to allow them to be auto cleaned." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Subagent name to manage"}, + "protected": { + "type": "boolean", + "description": "Whether to protect (true) or unprotect (false) the subagent", + }, + }, + "required": ["name", "protected"], + } + ) + + async def call(self, context, **kwargs) -> str: + name = kwargs.get("name", "") + protected = kwargs.get("protected", True) + if not name: + return "Error: name required" + session_id = context.context.event.unified_msg_origin + session = SubAgentManager._get_or_create_session(session_id) + if name not in session.subagents: + return f"Error: Subagent {name} not found. Available subagents: {session.subagents.keys()}" + if protected: + SubAgentManager.protect_subagent(session_id, name) + return f"Subagent {name} is now protected from auto cleanup" + else: + if name in session.protected_agents: + session.protected_agents.discard(name) + return f"Subagent {name} is no longer protected" + return f"Subagent {name} was not protected" + + +@dataclass +class ResetSubAgentTool(FunctionTool): + """Tool to reset a subagent""" + + name: str = "reset_subagent" + description: str = "Reset an existing subagent. This will clean the dialog history of the subagent. Used before assigning a new task to an existing subagent." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Subagent name to reset"}, + }, + "required": ["name"], + } + ) + + async def call(self, context, **kwargs) -> str: + name = kwargs.get("name", "") + if not name: + return "Error: name required" + session_id = context.context.event.unified_msg_origin + reset_status = SubAgentManager.clear_subagent_history(session_id, name) + if reset_status == RET_HISTORY_CLEARED: + return f"Subagent {name} was reset" + else: + return reset_status + + +# Shared Context Tools +@dataclass +class BroadCastSharedContextTool(FunctionTool): + """Tool to send a message to the shared context (visible to all agents)""" + + name: str = "broadcast_shared_context" + description: str = ( + """Send a message to one or all subagents when they are running.""" + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "context_type": { + "type": "string", + "description": "Type of context: message (to other agents), system (global announcement)", + "enum": ["message", "system"], + }, + "content": {"type": "string", "description": "Content to share"}, + "target": { + "type": "string", + "description": "Target agent name or 'all' for broadcast", + "default": "all", + }, + }, + "required": ["context_type", "content", "target"], + } + ) + + async def call(self, context, **kwargs) -> str: + context_type = kwargs.get("context_type", "message") + content = kwargs.get("content", "") + target = kwargs.get("target", "all") + if not content: + return "Error: content is required" + session_id = context.context.event.unified_msg_origin + add_status = SubAgentManager.add_shared_context( + session_id, "System", context_type, content, target + ) + if add_status == RET_SHARED_CONTEXT_ADDED: + return f"Shared context updated: [{context_type}] System -> {target}: {content[:100]}{'...' if len(content) > 100 else ''}" + else: + return add_status + + +@dataclass +class SendSharedContextTool(FunctionTool): + """Tool to send a message to the shared context (visible to all agents)""" + + name: str = "send_shared_context" + description: str = """Send a message to the shared context that will be visible to other subagents. +Use this to share information, status updates, or coordinate with other subagents. +Not used for informing the main agent, return the results directly instead. +""" + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "context_type": { + "type": "string", + "description": "Type of context: `status` (your current task progress), `message` (to other agents)", + "enum": ["status", "message"], + }, + "content": {"type": "string", "description": "Content to share"}, + "sender": { + "type": "string", + "description": "Sender agent name", + "default": "YourName", + }, + "target": { + "type": "string", + "description": "Target agent name or 'all' for broadcast.", + "default": "all", + }, + }, + "required": ["context_type", "content", "sender", "target"], + } + ) + + async def call(self, context, **kwargs) -> str: + context_type = kwargs.get("context_type", "message") + content = kwargs.get("content", "") + target = kwargs.get("target", "all") + sender = kwargs.get("sender", "YourName") + if not content: + return "Error: content is required" + session_id = context.context.event.unified_msg_origin + add_status = SubAgentManager.add_shared_context( + session_id, sender, context_type, content, target + ) + if add_status == RET_SHARED_CONTEXT_ADDED: + return f"Shared context updated: [{context_type}] {sender} -> {target}: {content[:100]}{'...' if len(content) > 100 else ''}" + else: + return add_status + + +@dataclass +class ViewSharedContextTool(FunctionTool): + """Tool to view the shared context (mainly for main agent)""" + + name: str = "view_shared_context" + description: str = """View the shared context between all agents. This shows all messages including status updates, +inter-agent messages, and system announcements.""" + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": {}, + } + ) + + async def call(self, context, **kwargs) -> str: + session_id = context.context.event.unified_msg_origin + shared_context = SubAgentManager.get_shared_context(session_id) + + if not shared_context: + return "Shared context is empty." + + lines = ["=== Shared Context ===\n"] + for msg in shared_context: + ts = time.strftime("%H:%M:%S", time.localtime(msg["timestamp"])) + msg_type = msg["type"] + sender = msg["sender"] + target = msg["target"] + content = msg["content"] + lines.append(f"[{ts}] [{msg_type}] {sender} -> {target}:") + lines.append(f" {content}") + lines.append("") + + return "\n".join(lines) + + +@dataclass +class WaitForSubagentTool(FunctionTool): + """等待 SubAgent 结果的工具""" + + name: str = "wait_for_subagent" + description: str = """Waiting for the execution result of the specified SubAgent. +Usage scenario: +- After assigning a background task to SubAgent, you need to wait for its result before proceeding to the next step. + CAUTION: Whenever you have a task that does not depend on the output of a subagent, please execute THAT TASK FIRST instead of waiting. +- Avoids repeatedly executing tasks that have already been completed by SubAgent +parameter +- subagent_name: The name of the SubAgent to wait for +- task_id: Task ID (optional). If not filled in, the latest task result of the Agent will be obtained. +- timeout: Maximum waiting time (in seconds), default 60 +- poll_interval: polling interval (in seconds), default 5 +""" + + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "subagent_name": { + "type": "string", + "description": "The name of the SubAgent to wait for", + }, + "timeout": { + "type": "number", + "description": "Maximum waiting time (seconds)", + "default": 60, + }, + "poll_interval": { + "type": "number", + "description": "Poll interval (seconds)", + "default": 5, + }, + "task_id": { + "type": "string", + "description": "Task ID (optional; if not filled in, the latest task result will be obtained)", + }, + }, + "required": ["subagent_name"], + } + ) + + async def call(self, context, **kwargs) -> str: + subagent_name = kwargs.get("subagent_name") + if not subagent_name: + return "Error: subagent_name is required" + + task_id = kwargs.get("task_id") # 可选,不填则获取最新的 + timeout = kwargs.get("timeout", 60) + if timeout > 3600 or timeout <= 0: + return "Error: timeout is invalid. Must be between 1 and 3600" + poll_interval = kwargs.get("poll_interval", 5) + if poll_interval > 60 or poll_interval <= 0: + return "Error: poll_interval is invalid. Must be between 1 and 60" + session_id = context.context.event.unified_msg_origin + session = SubAgentManager.get_session(session_id) + + if not session: + return "Error: No session found" + if subagent_name not in session.subagents: + return f"Error: SubAgent '{subagent_name}' not found. Available: {list(session.subagents.keys())}" + + # 如果没有指定 task_id,尝试获取最新创建的 pending 任务 + if not task_id: + pending_tasks = SubAgentManager.get_pending_subagent_tasks( + session_id, subagent_name + ) + if pending_tasks: + # 使用最新的 pending 任务 + task_id = pending_tasks[-1] + else: + # 没有 pending 任务,检查是否有已完成的最新任务 + latest = SubAgentManager.get_subagent_result(session_id, subagent_name) + if latest: + return f"SubAgent '{subagent_name}' has no pending tasks. Latest completed task id: {latest.task_id}. Task id {latest.task_id} Results:\n{latest.result}" + return f"Error: SubAgent '{subagent_name}' has no tasks." + start_time = time.time() + + while time.time() - start_time < timeout: + session = SubAgentManager.get_session(session_id) + if not session: + return "Error: Session Not Found" + if subagent_name not in session.subagents: + return ( + f"Error: SubAgent '{subagent_name}' not found. It may be removed." + ) + + status = SubAgentManager.get_subagent_status(session_id, subagent_name) + + if status == "IDLE": + return f"Error: SubAgent '{subagent_name}' is running no tasks." + elif status == "COMPLETED": + result = SubAgentManager.get_subagent_result( + session_id, subagent_name, task_id + ) + if result and (result.result != "" or result.completed_at > 0): + return f"SubAgent '{result.agent_name}' execution completed\nTask id: {result.task_id}\nExecution time: {result.execution_time:.1f}s\n--- Result ---\n{result.result}\n" + else: + return f"SubAgent '{subagent_name}' task {task_id} execution completed with empty results." + elif status == "FAILED": + result = SubAgentManager.get_subagent_result( + session_id, subagent_name, task_id + ) + if result and (result.result != "" or result.completed_at > 0): + return ( + f"SubAgent '{result.agent_name}' execution failed\n" + f"Task id: {result.task_id}\n" + f"Execution time: {result.execution_time:.1f}s\n" + f"Error: {result.error or 'Unknown error'}\n" + ) + else: + return f"SubAgent '{subagent_name}' failed task {task_id} with empty results. Error: {result.error or 'Unknown error'}" + else: + pass + + await asyncio.sleep(poll_interval) + + target = f"Task {task_id}" + return f"Timeout! SubAgent '{subagent_name}' has not finished '{target}' in {timeout}s. The task may be still running. You can continue waiting by `wait_for_subagent` again." + + +@dataclass +class TransferToSubagentTool(FunctionTool): + """Unified handoff tool that delegates to a specific subagent by name. + + This replaces the dynamic transfer_to_{name} tools with a single fixed tool, + preserving LLM prefix cache across subagent creation. + """ + + name: str = "transfer_to_subagent" + description: str = ( + "Delegate a task to a specific subagent by name. " + "The subagent must have been created previously using create_subagent, " + "or be a statically configured subagent. " + "Use list_subagents to see available subagent names." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the subagent to delegate to. Use list_subagents to see available names.", + }, + "input": { + "type": "string", + "description": "The input/task to be handed off to the subagent. This should be a clear and concise request or task.", + }, + "image_urls": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional: An array of image sources (public HTTP URLs or local file paths) used as references in multimodal tasks.", + }, + "background_task": { + "type": "boolean", + "description": ( + "Defaults to false. " + "Set to true if the task may take noticeable time, involves external tools, or the user does not need to wait. " + "Use false only for quick, immediate tasks." + ), + }, + }, + "required": ["name", "input"], + } + ) + + async def call(self, context, **kwargs) -> str: + # This tool is a "virtual" tool — its actual execution is intercepted by + # FunctionToolExecutor.execute() which resolves the target HandoffTool + # from SubAgentManager and delegates. This call() is a fallback. + name = kwargs.get("name", "") + if not name: + return "Error: subagent name is required for transfer_to_subagent" + return f"Error: transfer_to_subagent execution should be handled by FunctionToolExecutor. name={name}" + + +@dataclass +class OrchestrateTasksTool(FunctionTool): + """Orchestrate multiple subagent tasks with DAG dependency management.""" + + name: str = "orchestrate_tasks" + description: str = ( + "Orchestrate multiple subagent tasks with automatic dependency management." + " Define tasks with their dependencies and the orchestrator will:" + " (1) Automatically determine which tasks can run in parallel," + " (2) Execute dependent tasks sequentially in waves," + " (3) Auto-inject predecessor results as context for successor tasks," + " (4) Aggregate all results into a single summary." + " Use this when you have 2+ subtasks where some produce output that others" + " consume. For simple single-agent delegation, use transfer_to_subagent directly." + ) + + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "tasks": { + "type": "array", + "description": "List of tasks to orchestrate with dependencies", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique task ID, e.g. 'step1'", + }, + "agent": { + "type": "string", + "description": "Target subagent name", + }, + "prompt": { + "type": "string", + "description": "Task description for the subagent", + }, + "depends_on": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "IDs of tasks that must complete first. " + "Results auto-injected as context." + ), + }, + }, + "required": ["id", "agent", "prompt"], + }, + }, + "max_parallel": { + "type": "integer", + "default": 5, + "minimum": 1, + "maximum": 10, + "description": "Maximum concurrent subagents", + }, + }, + "required": ["tasks"], + } + ) + + async def call(self, context, **kwargs) -> str: + tasks_data = kwargs.get("tasks", []) + max_parallel = kwargs.get("max_parallel", 5) + session_id = context.context.event.unified_msg_origin + + if not tasks_data: + return "Error: At least one task is required." + + cfg = self._get_dag_config(context) + max_nodes = cfg.get("dag_max_nodes", 10) + cfg_max_parallel = cfg.get("dag_max_parallel", 5) + + if len(tasks_data) > max_nodes: + return f"Error: Maximum {max_nodes} tasks per DAG. Got {len(tasks_data)}." + + max_parallel = min(max_parallel, cfg_max_parallel) + + active_dag = SubAgentManager.get_active_dag(session_id) + if active_dag and active_dag.status == "RUNNING": + completed = sum( + 1 + for n in active_dag.nodes.values() + if n.status == DAGNodeStatus.COMPLETED + ) + return ( + f"Error: A DAG is already running for this session " + f"(dag_id={active_dag.dag_id[:8]}..., " + f"{completed}/{len(active_dag.nodes)} completed)." + ) + + session = SubAgentManager.get_session(session_id) + if not session: + return "Error: No session found. Create subagents first." + + nodes: list[DAGTaskNode] = [] + for t in tasks_data: + agent_name = t.get("agent", "") + if agent_name not in session.subagents: + available = list(session.subagents.keys()) + return ( + f"Error: SubAgent '{agent_name}' not found. Available: {available}" + ) + node = DAGTaskNode( + id=t["id"], + agent_name=agent_name, + prompt=t["prompt"], + depends_on=t.get("depends_on", []), + ) + nodes.append(node) + + valid, error = SubAgentDAGEngine.validate_dag(nodes) + if not valid: + return f"Error: Invalid DAG — {error}" + + topo_layers = SubAgentDAGEngine._kahn_sort(nodes) + + node_map = {n.id: n for n in nodes} + adj: dict[str, set[str]] = {n.id: set() for n in nodes} + rev_adj: dict[str, set[str]] = {n.id: set() for n in nodes} + for n in nodes: + for dep in n.depends_on: + adj[dep].add(n.id) + rev_adj[n.id].add(dep) + + dag_ctx = DAGExecutionContext( + dag_id=uuid.uuid4().hex[:12], + session_id=session_id, + nodes=node_map, + adjacency=adj, + reverse_adjacency=rev_adj, + topo_layers=topo_layers, + fail_fast=True, + max_parallel=max_parallel, + created_at=time.time(), + ) + + SubAgentManager.register_dag(session_id, dag_ctx) + + # Build launch callback that actually triggers subagent execution + tool_context = context # ContextWrapper[AstrAgentContext] + + async def _launch_dag_node(node, _sid, injected_context, task_id): + import mcp.types as _mcp_types + + session = SubAgentManager.get_session(_sid) + if not session or node.agent_name not in session.handoff_tools: + logger.error(f"[SubAgent:DAG] No handoff tool for {node.agent_name}") + SubAgentManager.store_subagent_result( + _sid, + node.agent_name, + False, + "", + task_id=task_id, + error=f"No handoff for {node.agent_name}", + execution_time=0.0, + ) + return + + handoff = session.handoff_tools[node.agent_name] + prompt = node.prompt + if injected_context: + ctx_text = injected_context[0]["content"] + prompt = ctx_text + "Your task:" + chr(10) + prompt + + try: + from astrbot.core.astr_agent_tool_exec import ( + FunctionToolExecutor, + ) + + result_text = "" + async for r in FunctionToolExecutor._execute_handoff( + tool=handoff, + run_context=tool_context, + input=prompt, + ): + if isinstance(r, _mcp_types.CallToolResult): + for c in r.content: + if isinstance(c, _mcp_types.TextContent): + result_text += c.text + chr(10) + + # Detect task status from subagent output. + # Priority: 1) [TASK RESULT: ...] marker 2) error: prefix 3) empty + success = True + error_reason = None + stripped = result_text.strip() + status_match = re.search( + r"\[TASK\s*RESULT\s*:\s*(SUCCESS|FAILURE)\]", + stripped, + re.IGNORECASE, + ) + if status_match: + success = status_match.group(1).upper() == "SUCCESS" + if not success: + # Extract failure reason for concise error reporting + reason_match = re.search( + r"\[FAILURE\s*REASON\s*:\s*(.+?)\]", + stripped, + re.IGNORECASE, + ) + if reason_match: + error_reason = reason_match.group(1).strip() + else: + error_reason = "No reason provided" + elif not stripped or stripped.lower().startswith("error:"): + success = False + + SubAgentManager.store_subagent_result( + _sid, + node.agent_name, + success, + result_text, + task_id=task_id, + execution_time=0.0, + error=error_reason, + ) + except Exception as e: + logger.error(f"[SubAgent:DAG] Launch error for {node.agent_name}: {e}") + SubAgentManager.store_subagent_result( + _sid, + node.agent_name, + False, + "", + task_id=task_id, + error=str(e), + execution_time=0.0, + ) + + try: + result = await SubAgentDAGEngine.execute_dag( + ctx=dag_ctx, + session_id=session_id, + max_inject_length=cfg.get("dag_max_inject_length", 4000), + launch_fn=_launch_dag_node, + ) + dag_ctx.status = "COMPLETED" if result["failed"] == 0 else "FAILED" + dag_ctx.completed_at = time.time() + + session = SubAgentManager.get_session(session_id) + if session: + session.dag_history.append(dag_ctx) + session.active_dag = None + + return result["formatted"] + + except Exception as e: + logger.error(f"[SubAgent:DAG] Execution error: {e}", exc_info=True) + dag_ctx.status = "FAILED" + dag_ctx.completed_at = time.time() + session = SubAgentManager.get_session(session_id) + if session: + session.dag_history.append(dag_ctx) + session.active_dag = None + return f"Error: DAG execution failed — {e}" + + @staticmethod + def _get_dag_config(context) -> dict: + try: + ctx = context.context.context + cfg = ctx.get_config(umo=context.context.event.unified_msg_origin) + orch_cfg = cfg.get("subagent_orchestrator", {}) + return { + "dag_max_nodes": orch_cfg.get("dag_max_nodes", 10), + "dag_max_parallel": orch_cfg.get("dag_max_parallel", 5), + "dag_max_inject_length": orch_cfg.get("dag_max_inject_length", 4000), + } + except Exception: + return { + "dag_max_nodes": 10, + "dag_max_parallel": 5, + "dag_max_inject_length": 4000, + } + + +ORCHESTRATE_TASKS_TOOL = OrchestrateTasksTool() + + +# Tool instances +CREATE_SUBAGENT_TOOL = CreateSubAgentTool() +REMOVE_SUBAGENT_TOOL = RemoveSubagentTool() +LIST_SUBAGENTS_TOOL = ListSubagentsTool() +RESET_SUBAGENT_TOOL = ResetSubAgentTool() +MANAGE_SUBAGENT_PROTECTION_TOOL = ManageSubagentProtectionTool() +SEND_SHARED_CONTEXT_TOOL = SendSharedContextTool() +BROADCAST_SHARED_CONTEXT_TOOL = BroadCastSharedContextTool() +VIEW_SHARED_CONTEXT_TOOL = ViewSharedContextTool() +WAIT_FOR_SUBAGENT_TOOL = WaitForSubagentTool() +TRANSFER_TO_SUBAGENT_TOOL = TransferToSubagentTool() diff --git a/astrbot/core/utils/trace.py b/astrbot/core/utils/trace.py index 7b095dbc01..e65eb5d208 100644 --- a/astrbot/core/utils/trace.py +++ b/astrbot/core/utils/trace.py @@ -41,12 +41,14 @@ def __init__( umo: str | None = None, sender_name: str | None = None, message_outline: str | None = None, + parent_span_id: str | None = None, ) -> None: self.span_id = str(uuid.uuid4()) self.name = name self.umo = umo self.sender_name = sender_name self.message_outline = message_outline + self.parent_span_id = parent_span_id self.started_at = time.time() def record(self, action: str, **fields: Any) -> None: @@ -59,6 +61,7 @@ def record(self, action: str, **fields: Any) -> None: "level": "TRACE", "time": time.time(), "span_id": self.span_id, + "parent_span_id": self.parent_span_id, "name": self.name, "umo": self.umo, "sender_name": self.sender_name, diff --git a/astrbot/dashboard/routes/subagent.py b/astrbot/dashboard/routes/subagent.py index e3d77f73ad..06c2d24041 100644 --- a/astrbot/dashboard/routes/subagent.py +++ b/astrbot/dashboard/routes/subagent.py @@ -36,21 +36,44 @@ async def get_config(self): data = { "main_enable": False, "remove_main_duplicate_tools": False, + "router_system_prompt": "", "agents": [], + "dynamic_agents": { + "enabled": False, + "max_subagent_count": 3, + "auto_cleanup_per_turn": True, + "default_provider_id": "", + "tools_blacklist": [], + "tools_inherent": [], + }, + "history_enabled": True, + "shared_context_enabled": False, + "shared_context_maxlen": 200, + "subagent_history_maxlen": 500, + "execution_timeout": 600, } - # Backward compatibility: older config used `enable`. - if ( - isinstance(data, dict) - and "main_enable" not in data - and "enable" in data - ): - data["main_enable"] = bool(data.get("enable", False)) - # Ensure required keys exist. data.setdefault("main_enable", False) data.setdefault("remove_main_duplicate_tools", False) + data.setdefault("router_system_prompt", "") data.setdefault("agents", []) + data.setdefault("dynamic_agents", {}) + data.setdefault("history_enabled", True) + data.setdefault("shared_context_enabled", False) + data.setdefault("shared_context_maxlen", 200) + data.setdefault("subagent_history_maxlen", 500) + data.setdefault("execution_timeout", 600) + + # Ensure dynamic_agents sub-keys exist. + dyn = data["dynamic_agents"] + if isinstance(dyn, dict): + dyn.setdefault("enabled", False) + dyn.setdefault("max_subagent_count", 3) + dyn.setdefault("auto_cleanup_per_turn", True) + dyn.setdefault("default_provider_id", "") + dyn.setdefault("tools_blacklist", []) + dyn.setdefault("tools_inherent", []) # Backward/forward compatibility: ensure each agent contains provider_id. # None means follow global/default provider settings. @@ -97,7 +120,7 @@ async def get_available_tools(self): tools_dict = [] for tool in tool_mgr.func_list: # Prevent recursive routing: subagents should not be able to select - # the handoff (transfer_to_*) tools as their own mounted tools. + # the handoff (transfer_to_subagent) tools as their own mounted tools. if isinstance(tool, HandoffTool): continue if tool.handler_module_path == "core.subagent_orchestrator": diff --git a/dashboard/src/components/shared/TraceDisplayer.vue b/dashboard/src/components/shared/TraceDisplayer.vue index 62c57ef479..d356015fc2 100644 --- a/dashboard/src/components/shared/TraceDisplayer.vue +++ b/dashboard/src/components/shared/TraceDisplayer.vue @@ -6,17 +6,21 @@ import { EventSourcePolyfill } from 'event-source-polyfill';