diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 1fb4b03368..18ac1a446a 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -25,7 +25,6 @@ LOCAL_EXECUTE_SHELL_TOOL, LOCAL_PYTHON_TOOL, PYTHON_TOOL, - SEND_MESSAGE_TO_USER_TOOL, ) from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.components import Image @@ -37,6 +36,7 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.tools.message_tools import SendMessageToUserTool from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history from astrbot.core.utils.image_ref_utils import is_supported_image_ref @@ -515,7 +515,9 @@ async def _wake_main_agent_for_background_result( ) if not req.func_tool: req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + req.func_tool.add_tool( + ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool) + ) result = await build_main_agent( event=cron_event, plugin_context=ctx, config=config, req=req diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index bd0c780ecc..75f5d30e2a 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -32,7 +32,6 @@ FILE_UPLOAD_TOOL, GET_EXECUTION_HISTORY_TOOL, GET_SKILL_PAYLOAD_TOOL, - KNOWLEDGE_BASE_QUERY_TOOL, LIST_SKILL_CANDIDATES_TOOL, LIST_SKILL_RELEASES_TOOL, LIVE_MODE_SYSTEM_PROMPT, @@ -44,11 +43,9 @@ ROLLBACK_SKILL_RELEASE_TOOL, RUN_BROWSER_SKILL_TOOL, SANDBOX_MODE_PROMPT, - SEND_MESSAGE_TO_USER_TOOL, SYNC_SKILL_RELEASE_TOOL, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, - retrieve_knowledge_base, ) from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Record, Reply @@ -63,16 +60,21 @@ from astrbot.core.star.context import Context from astrbot.core.star.star_handler import star_map from astrbot.core.tools.cron_tools import ( - CREATE_CRON_JOB_TOOL, - DELETE_CRON_JOB_TOOL, - LIST_CRON_JOBS_TOOL, + CreateActiveCronTool, + DeleteCronJobTool, + ListCronJobsTool, +) +from astrbot.core.tools.knowledge_base_tools import ( + KnowledgeBaseQueryTool, + retrieve_knowledge_base, ) +from astrbot.core.tools.message_tools import SendMessageToUserTool from astrbot.core.tools.web_search_tools import ( - TAVILY_EXTRACT_WEB_PAGE_TOOL, - WEB_SEARCH_BAIDU_TOOL, - WEB_SEARCH_BOCHA_TOOL, - WEB_SEARCH_BRAVE_TOOL, - WEB_SEARCH_TAVILY_TOOL, + BaiduWebSearchTool, + BochaWebSearchTool, + BraveWebSearchTool, + TavilyExtractWebPageTool, + TavilyWebSearchTool, normalize_legacy_web_search_config, ) from astrbot.core.utils.file_extract import extract_file_moonshotai @@ -226,7 +228,11 @@ async def _apply_kb( else: if req.func_tool is None: req.func_tool = ToolSet() - req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) + req.func_tool.add_tool( + plugin_context.get_llm_tool_manager().get_builtin_tool( + KnowledgeBaseQueryTool + ) + ) async def _apply_file_extract( @@ -1054,12 +1060,13 @@ def _apply_sandbox_tools( req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" -def _proactive_cron_job_tools(req: ProviderRequest) -> None: +def _proactive_cron_job_tools(req: ProviderRequest, plugin_context: Context) -> None: if req.func_tool is None: req.func_tool = ToolSet() - req.func_tool.add_tool(CREATE_CRON_JOB_TOOL) - req.func_tool.add_tool(DELETE_CRON_JOB_TOOL) - req.func_tool.add_tool(LIST_CRON_JOBS_TOOL) + tool_mgr = plugin_context.get_llm_tool_manager() + req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateActiveCronTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(DeleteCronJobTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListCronJobsTool)) async def _apply_web_search_tools( @@ -1077,16 +1084,17 @@ async def _apply_web_search_tools( if req.func_tool is None: req.func_tool = ToolSet() + tool_mgr = plugin_context.get_llm_tool_manager() provider = prov_settings.get("websearch_provider", "tavily") if provider == "tavily": - req.func_tool.add_tool(WEB_SEARCH_TAVILY_TOOL) - req.func_tool.add_tool(TAVILY_EXTRACT_WEB_PAGE_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(TavilyWebSearchTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(TavilyExtractWebPageTool)) elif provider == "bocha": - req.func_tool.add_tool(WEB_SEARCH_BOCHA_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool)) elif provider == "brave": - req.func_tool.add_tool(WEB_SEARCH_BRAVE_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool)) elif provider == "baidu_ai_search": - req.func_tool.add_tool(WEB_SEARCH_BAIDU_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool)) def _get_compress_provider( @@ -1348,12 +1356,16 @@ async def build_main_agent( ) if config.add_cron_tools: - _proactive_cron_job_tools(req) + _proactive_cron_job_tools(req, plugin_context) if event.platform_meta.support_proactive_message: if req.func_tool is None: req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + req.func_tool.add_tool( + plugin_context.get_llm_tool_manager().get_builtin_tool( + SendMessageToUserTool + ) + ) if provider.provider_config.get("max_context_tokens", 0) <= 0: model = provider.get_model() diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 09e77b4cbe..4d1e59c291 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -1,17 +1,5 @@ import base64 -import json -import os -import uuid -from pydantic import Field -from pydantic.dataclasses import dataclass - -import astrbot.core.message.components as Comp -from astrbot.api import logger, sp -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import FunctionTool, ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter from astrbot.core.computer.tools import ( AnnotateExecutionTool, BrowserBatchExecTool, @@ -33,11 +21,6 @@ RunBrowserSkillTool, SyncSkillReleaseTool, ) -from astrbot.core.knowledge_base.kb_helper import KBHelper -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.platform.message_session import MessageSession -from astrbot.core.star.context import Context -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode. @@ -148,352 +131,6 @@ ) -@dataclass -class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): - name: str = "astr_kb_search" - description: str = ( - "Query the knowledge base for facts or relevant context. " - "Use this tool when the user's question requires factual information, " - "definitions, background knowledge, or previously indexed content. " - "Only send short keywords or a concise question as the query." - ) - parameters: dict = Field( - default_factory=lambda: { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "A concise keyword query for the knowledge base.", - }, - }, - "required": ["query"], - } - ) - - async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs - ) -> ToolExecResult: - query = kwargs.get("query", "") - if not query: - return "error: Query parameter is empty." - result = await retrieve_knowledge_base( - query=kwargs.get("query", ""), - umo=context.context.event.unified_msg_origin, - context=context.context.context, - ) - if not result: - return "No relevant knowledge found." - return result - - -@dataclass -class SendMessageToUserTool(FunctionTool[AstrAgentContext]): - name: str = "send_message_to_user" - description: str = ( - "Send message to the user. " - "Supports various message types including `plain`, `image`, `record`, `video`, `file`, and `mention_user`. " - "Use this tool to send media files (`image`, `record`, `video`, `file`), " - "or when you need to proactively message the user(such as cron job). For normal text replies, you can output directly." - ) - - parameters: dict = Field( - default_factory=lambda: { - "type": "object", - "properties": { - "messages": { - "type": "array", - "description": "An ordered list of message components to send. `mention_user` type can be used to mention the user.", - "items": { - "type": "object", - "properties": { - "type": { - "type": "string", - "description": ( - "Component type. One of: " - "plain, image, record, video, file, mention_user. Record is voice message." - ), - }, - "text": { - "type": "string", - "description": "Text content for `plain` type.", - }, - "path": { - "type": "string", - "description": "File path for `image`, `record`, or `file` types. Both local path and sandbox path are supported.", - }, - "url": { - "type": "string", - "description": "URL for `image`, `record`, or `file` types.", - }, - "mention_user_id": { - "type": "string", - "description": "User ID to mention for `mention_user` type.", - }, - }, - "required": ["type"], - }, - }, - }, - "required": ["messages"], - } - ) - - async def _resolve_path_from_sandbox( - self, context: ContextWrapper[AstrAgentContext], path: str - ) -> tuple[str, bool]: - """ - If the path exists locally, return it directly. - Otherwise, check if it exists in the sandbox and download it. - - bool: indicates whether the file was downloaded from sandbox. - """ - if os.path.exists(path): - return path, False - - # Try to check if the file exists in the sandbox - try: - sb = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - # Use shell to check if the file exists in sandbox - result = await sb.shell.exec(f"test -f {path} && echo '_&exists_'") - if "_&exists_" in json.dumps(result): - # Download the file from sandbox - name = os.path.basename(path) - local_path = os.path.join( - get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" - ) - await sb.download_file(path, local_path) - logger.info(f"Downloaded file from sandbox: {path} -> {local_path}") - return local_path, True - except Exception as e: - logger.warning(f"Failed to check/download file from sandbox: {e}") - - # Return the original path (will likely fail later, but that's expected) - return path, False - - async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs - ) -> ToolExecResult: - session = kwargs.get("session") or context.context.event.unified_msg_origin - messages = kwargs.get("messages") - - if not isinstance(messages, list) or not messages: - return "error: messages parameter is empty or invalid." - - components: list[Comp.BaseMessageComponent] = [] - - for idx, msg in enumerate(messages): - if not isinstance(msg, dict): - return f"error: messages[{idx}] should be an object." - - msg_type = str(msg.get("type", "")).lower() - if not msg_type: - return f"error: messages[{idx}].type is required." - - file_from_sandbox = False - - try: - if msg_type == "plain": - text = str(msg.get("text", "")).strip() - if not text: - return f"error: messages[{idx}].text is required for plain component." - components.append(Comp.Plain(text=text)) - elif msg_type == "image": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Image.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Image.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for image component." - elif msg_type == "record": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Record.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Record.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for record component." - elif msg_type == "video": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Video.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Video.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for video component." - elif msg_type == "file": - path = msg.get("path") - url = msg.get("url") - name = ( - msg.get("text") - or (os.path.basename(path) if path else "") - or (os.path.basename(url) if url else "") - or "file" - ) - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.File(name=name, file=local_path)) - elif url: - components.append(Comp.File(name=name, url=url)) - else: - return f"error: messages[{idx}] must include path or url for file component." - elif msg_type == "mention_user": - mention_user_id = msg.get("mention_user_id") - if not mention_user_id: - return f"error: messages[{idx}].mention_user_id is required for mention_user component." - components.append( - Comp.At( - qq=mention_user_id, - ), - ) - else: - return ( - f"error: unsupported message type '{msg_type}' at index {idx}." - ) - except Exception as exc: # 捕获组件构造异常,避免直接抛出 - return f"error: failed to build messages[{idx}] component: {exc}" - - try: - target_session = ( - MessageSession.from_str(session) - if isinstance(session, str) - else session - ) - except Exception as e: - return f"error: invalid session: {e}" - - await context.context.context.send_message( - target_session, - MessageChain(chain=components), - ) - - # if file_from_sandbox: - # try: - # os.remove(local_path) - # except Exception as e: - # logger.error(f"Error removing temp file {local_path}: {e}") - - return f"Message sent to session {target_session}" - - -def check_all_kb(kb_list: list[KBHelper | None]) -> bool: - """检查是否所有的知识库都为空 - Args: - kb_list: 所选的知识库 - Returns: - bool: 是否全为空 - """ - return not any( - kb and (kb.kb.doc_count != 0 or kb.kb.chunk_count != 0) for kb in kb_list - ) - - -async def retrieve_knowledge_base( - query: str, - umo: str, - context: Context, -) -> str | None: - """Inject knowledge base context into the provider request - - Args: - umo: Unique message object (session ID) - p_ctx: Pipeline context - """ - kb_mgr = context.kb_manager - config = context.get_config(umo=umo) - - # 1. 优先读取会话级配置 - session_config = await sp.session_get(umo, "kb_config", default={}) - - if session_config and "kb_ids" in session_config: - # 会话级配置 - kb_ids = session_config.get("kb_ids", []) - - # 如果配置为空列表,明确表示不使用知识库 - if not kb_ids: - logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库") - return - - top_k = session_config.get("top_k", 5) - - # 将 kb_ids 转换为 kb_names - kb_names = [] - invalid_kb_ids = [] - for kb_id in kb_ids: - kb_helper = await kb_mgr.get_kb(kb_id) - if kb_helper: - kb_names.append(kb_helper.kb.kb_name) - else: - logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}") - invalid_kb_ids.append(kb_id) - - if invalid_kb_ids: - logger.warning( - f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", - ) - - if not kb_names: - return - - logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") - else: - kb_names = config.get("kb_names", []) - top_k = config.get("kb_final_top_k", 5) - logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") - - top_k_fusion = config.get("kb_fusion_top_k", 20) - - if not kb_names: - return - - all_kbs = [await kb_mgr.get_kb_by_name(kb) for kb in kb_names] - - if check_all_kb(all_kbs): - logger.debug("所配置的所有知识库全为空,跳过检索过程") - return - - logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}") - kb_context = await kb_mgr.retrieve( - query=query, - kb_names=kb_names, - top_k_fusion=top_k_fusion, - top_m_final=top_k, - ) - - if not kb_context: - return - - formatted = kb_context.get("context_text", "") - if formatted: - results = kb_context.get("results", []) - logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块") - return formatted - - -KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool() -SEND_MESSAGE_TO_USER_TOOL = SendMessageToUserTool() - EXECUTE_SHELL_TOOL = ExecuteShellTool() LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True) PYTHON_TOOL = PythonTool() diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index ff7facd247..c86fc160fa 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -275,8 +275,8 @@ async def _woke_main_agent( ) from astrbot.core.astr_main_agent_resources import ( PROACTIVE_AGENT_CRON_WOKE_SYSTEM_PROMPT, - SEND_MESSAGE_TO_USER_TOOL, ) + from astrbot.core.tools.message_tools import SendMessageToUserTool try: session = ( @@ -342,7 +342,9 @@ async def _woke_main_agent( ) if not req.func_tool: req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + req.func_tool.add_tool( + self.ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool) + ) result = await build_main_agent( event=cron_event, plugin_context=self.ctx, config=config, req=req diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index b93d6ca2e1..bf16a3ec96 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -17,6 +17,12 @@ from astrbot.core import sp from astrbot.core.agent.mcp_client import MCPClient, MCPTool from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.tools.registry import ( + ensure_builtin_tools_loaded, + get_builtin_tool_class, + get_builtin_tool_name, + iter_builtin_tool_classes, +) from astrbot.core.utils.astrbot_path import get_astrbot_data_path DEFAULT_MCP_CONFIG = {"mcpServers": {}} @@ -207,8 +213,12 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class FunctionToolManager: def __init__(self) -> None: self.func_list: list[FuncTool] = [] + """All tools include mcp tools and plugin tools, except astrbot builtin tools.""" + self.builtin_func_list: dict[type[FuncTool], FuncTool] = {} + """All astrbot builtin tools, keyed by their class. Values are instantiated tool objects, created on demand.""" + self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {} - """MCP 服务运行时状态(唯一事实来源)""" + """MCP runtime metadata, keyed by server name. Updated atomically on MCP lifecycle changes.""" self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime) self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime) self._timeout_mismatch_warned = False @@ -320,8 +330,50 @@ def get_func(self, name) -> FuncTool | None: for f in reversed(self.func_list): if f.name == name: return f + if isinstance(name, str): + try: + builtin_tool = self.get_builtin_tool(name) + except KeyError: + return None + if getattr(builtin_tool, "active", True): + return builtin_tool + return builtin_tool return None + def get_builtin_tool(self, tool: str | type[FuncTool]) -> FuncTool: + ensure_builtin_tools_loaded() + + if isinstance(tool, str): + tool_cls = get_builtin_tool_class(tool) + if tool_cls is None: + raise KeyError(f"Builtin tool {tool} is not registered.") + elif isinstance(tool, type) and issubclass(tool, FunctionTool): + tool_cls = tool + if get_builtin_tool_name(tool_cls) is None: + raise KeyError( + f"Builtin tool class {tool_cls.__module__}.{tool_cls.__name__} is not registered.", + ) + else: + raise TypeError("tool must be a builtin tool name or FunctionTool class.") + + cached_tool = self.builtin_func_list.get(tool_cls) + if cached_tool is not None: + return cached_tool + + builtin_tool = tool_cls() # type: ignore + self.builtin_func_list[tool_cls] = builtin_tool + return builtin_tool + + def iter_builtin_tools(self) -> list[FuncTool]: + ensure_builtin_tools_loaded() + return [ + self.get_builtin_tool(tool_cls) for tool_cls in iter_builtin_tool_classes() + ] + + def is_builtin_tool(self, name: str) -> bool: + ensure_builtin_tools_loaded() + return get_builtin_tool_class(name) is not None + def get_full_tool_set(self) -> ToolSet: """获取完整工具集 diff --git a/astrbot/core/tools/cron_tools.py b/astrbot/core/tools/cron_tools.py index b939b53fa8..599957e0ab 100644 --- a/astrbot/core/tools/cron_tools.py +++ b/astrbot/core/tools/cron_tools.py @@ -7,6 +7,7 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.tools.registry import builtin_tool def _extract_job_session(job: Any) -> str | None: @@ -17,6 +18,7 @@ def _extract_job_session(job: Any) -> str | None: return str(session) if session is not None else None +@builtin_tool @dataclass class CreateActiveCronTool(FunctionTool[AstrAgentContext]): name: str = "create_future_task" @@ -105,6 +107,7 @@ async def call( return f"Scheduled future task {job.job_id} ({job.name}) {suffix}." +@builtin_tool @dataclass class DeleteCronJobTool(FunctionTool[AstrAgentContext]): name: str = "delete_future_task" @@ -141,6 +144,7 @@ async def call( return f"Deleted cron job {job_id}." +@builtin_tool @dataclass class ListCronJobsTool(FunctionTool[AstrAgentContext]): name: str = "list_future_tasks" @@ -180,14 +184,7 @@ async def call( return "\n".join(lines) -CREATE_CRON_JOB_TOOL = CreateActiveCronTool() -DELETE_CRON_JOB_TOOL = DeleteCronJobTool() -LIST_CRON_JOBS_TOOL = ListCronJobsTool() - __all__ = [ - "CREATE_CRON_JOB_TOOL", - "DELETE_CRON_JOB_TOOL", - "LIST_CRON_JOBS_TOOL", "CreateActiveCronTool", "DeleteCronJobTool", "ListCronJobsTool", diff --git a/astrbot/core/tools/knowledge_base_tools.py b/astrbot/core/tools/knowledge_base_tools.py new file mode 100644 index 0000000000..e27a883d4a --- /dev/null +++ b/astrbot/core/tools/knowledge_base_tools.py @@ -0,0 +1,129 @@ +from pydantic import Field +from pydantic.dataclasses import dataclass + +from astrbot.api import logger, sp +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool, ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.knowledge_base.kb_helper import KBHelper +from astrbot.core.star.context import Context +from astrbot.core.tools.registry import builtin_tool + + +def check_all_kb(kb_list: list[KBHelper | None]) -> bool: + """检查是否所有的知识库都为空""" + return not any( + kb and (kb.kb.doc_count != 0 or kb.kb.chunk_count != 0) for kb in kb_list + ) + + +async def retrieve_knowledge_base( + query: str, + umo: str, + context: Context, +) -> str | None: + """Retrieve knowledge base context for the given query.""" + kb_mgr = context.kb_manager + config = context.get_config(umo=umo) + + session_config = await sp.session_get(umo, "kb_config", default={}) + if session_config and "kb_ids" in session_config: + kb_ids = session_config.get("kb_ids", []) + if not kb_ids: + logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库") + return None + + top_k = session_config.get("top_k", 5) + kb_names = [] + invalid_kb_ids = [] + for kb_id in kb_ids: + kb_helper = await kb_mgr.get_kb(kb_id) + if kb_helper: + kb_names.append(kb_helper.kb.kb_name) + else: + logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}") + invalid_kb_ids.append(kb_id) + + if invalid_kb_ids: + logger.warning( + f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", + ) + if not kb_names: + return None + logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}") + else: + kb_names = config.get("kb_names", []) + top_k = config.get("kb_final_top_k", 5) + logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}") + + top_k_fusion = config.get("kb_fusion_top_k", 20) + if not kb_names: + return None + + all_kbs = [await kb_mgr.get_kb_by_name(kb) for kb in kb_names] + if check_all_kb(all_kbs): + logger.debug("所配置的所有知识库全为空,跳过检索过程") + return None + + logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}") + kb_context = await kb_mgr.retrieve( + query=query, + kb_names=kb_names, + top_k_fusion=top_k_fusion, + top_m_final=top_k, + ) + if not kb_context: + return None + + formatted = kb_context.get("context_text", "") + if formatted: + results = kb_context.get("results", []) + logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块") + return formatted + return None + + +@builtin_tool +@dataclass +class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): + name: str = "astr_kb_search" + description: str = ( + "Query the knowledge base for facts or relevant context. " + "Use this tool when the user's question requires factual information, " + "definitions, background knowledge, or previously indexed content. " + "Only send short keywords or a concise question as the query." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A concise keyword query for the knowledge base.", + }, + }, + "required": ["query"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], **kwargs + ) -> ToolExecResult: + query = kwargs.get("query", "") + if not query: + return "error: Query parameter is empty." + result = await retrieve_knowledge_base( + query=query, + umo=context.context.event.unified_msg_origin, + context=context.context.context, + ) + if not result: + return "No relevant knowledge found." + return result + + +__all__ = [ + "KnowledgeBaseQueryTool", + "check_all_kb", + "retrieve_knowledge_base", +] diff --git a/astrbot/core/tools/message_tools.py b/astrbot/core/tools/message_tools.py new file mode 100644 index 0000000000..020c1ad5a0 --- /dev/null +++ b/astrbot/core/tools/message_tools.py @@ -0,0 +1,210 @@ +import json +import os +import shlex +import uuid + +from pydantic import Field +from pydantic.dataclasses import dataclass + +import astrbot.core.message.components as Comp +from astrbot.api import logger +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool, ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer.computer_client import get_booter +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.tools.registry import builtin_tool +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + + +@builtin_tool +@dataclass +class SendMessageToUserTool(FunctionTool[AstrAgentContext]): + name: str = "send_message_to_user" + description: str = ( + "Send message to the user. " + "Supports various message types including `plain`, `image`, `record`, `video`, `file`, and `mention_user`. " + "Use this tool to send media files (`image`, `record`, `video`, `file`), " + "or when you need to proactively message the user(such as cron job). For normal text replies, you can output directly." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "messages": { + "type": "array", + "description": "An ordered list of message components to send. `mention_user` type can be used to mention the user.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": ( + "Component type. One of: " + "plain, image, record, video, file, mention_user. Record is voice message." + ), + }, + "text": { + "type": "string", + "description": "Text content for `plain` type.", + }, + "path": { + "type": "string", + "description": "File path for `image`, `record`, `video`, or `file` types. Both local path and sandbox path are supported.", + }, + "url": { + "type": "string", + "description": "URL for `image`, `record`, `video`, or `file` types.", + }, + "mention_user_id": { + "type": "string", + "description": "User ID to mention for `mention_user` type.", + }, + }, + "required": ["type"], + }, + }, + "session": { + "type": "string", + "description": "Optional. Target session string. Defaults to current session.", + }, + }, + "required": ["messages"], + } + ) + + async def _resolve_path_from_sandbox( + self, context: ContextWrapper[AstrAgentContext], path: str + ) -> tuple[str, bool]: + if os.path.exists(path): + return path, False + + try: + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + quoted_path = shlex.quote(path) + result = await sb.shell.exec(f"test -f {quoted_path} && echo '_&exists_'") + if "_&exists_" in json.dumps(result): + name = os.path.basename(path) + local_path = os.path.join( + get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" + ) + await sb.download_file(path, local_path) + logger.info(f"Downloaded file from sandbox: {path} -> {local_path}") + return local_path, True + except Exception as exc: + logger.warning(f"Failed to check/download file from sandbox: {exc}") + + return path, False + + async def call( + self, context: ContextWrapper[AstrAgentContext], **kwargs + ) -> ToolExecResult: + session = kwargs.get("session") or context.context.event.unified_msg_origin + messages = kwargs.get("messages") + if not isinstance(messages, list) or not messages: + return "error: messages parameter is empty or invalid." + + components: list[Comp.BaseMessageComponent] = [] + for idx, msg in enumerate(messages): + if not isinstance(msg, dict): + return f"error: messages[{idx}] should be an object." + + msg_type = str(msg.get("type", "")).lower() + if not msg_type: + return f"error: messages[{idx}].type is required." + + try: + if msg_type == "plain": + text = str(msg.get("text", "")).strip() + if not text: + return f"error: messages[{idx}].text is required for plain component." + components.append(Comp.Plain(text=text)) + elif msg_type == "image": + path = msg.get("path") + url = msg.get("url") + if path: + local_path, _ = await self._resolve_path_from_sandbox( + context, path + ) + components.append(Comp.Image.fromFileSystem(path=local_path)) + elif url: + components.append(Comp.Image.fromURL(url=url)) + else: + return f"error: messages[{idx}] must include path or url for image component." + elif msg_type == "record": + path = msg.get("path") + url = msg.get("url") + if path: + local_path, _ = await self._resolve_path_from_sandbox( + context, path + ) + components.append(Comp.Record.fromFileSystem(path=local_path)) + elif url: + components.append(Comp.Record.fromURL(url=url)) + else: + return f"error: messages[{idx}] must include path or url for record component." + elif msg_type == "video": + path = msg.get("path") + url = msg.get("url") + if path: + local_path, _ = await self._resolve_path_from_sandbox( + context, path + ) + components.append(Comp.Video.fromFileSystem(path=local_path)) + elif url: + components.append(Comp.Video.fromURL(url=url)) + else: + return f"error: messages[{idx}] must include path or url for video component." + elif msg_type == "file": + path = msg.get("path") + url = msg.get("url") + name = ( + msg.get("text") + or (os.path.basename(path) if path else "") + or (os.path.basename(url) if url else "") + or "file" + ) + if path: + local_path, _ = await self._resolve_path_from_sandbox( + context, path + ) + components.append(Comp.File(name=name, file=local_path)) + elif url: + components.append(Comp.File(name=name, url=url)) + else: + return f"error: messages[{idx}] must include path or url for file component." + elif msg_type == "mention_user": + mention_user_id = msg.get("mention_user_id") + if not mention_user_id: + return f"error: messages[{idx}].mention_user_id is required for mention_user component." + components.append(Comp.At(qq=mention_user_id)) + else: + return ( + f"error: unsupported message type '{msg_type}' at index {idx}." + ) + except Exception as exc: + return f"error: failed to build messages[{idx}] component: {exc}" + + try: + target_session = ( + MessageSession.from_str(session) + if isinstance(session, str) + else session + ) + except Exception as exc: + return f"error: invalid session: {exc}" + + await context.context.context.send_message( + target_session, + MessageChain(chain=components), + ) + return f"Message sent to session {target_session}" + + +__all__ = [ + "SendMessageToUserTool", +] diff --git a/astrbot/core/tools/registry.py b/astrbot/core/tools/registry.py new file mode 100644 index 0000000000..eaca4af144 --- /dev/null +++ b/astrbot/core/tools/registry.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from importlib import import_module +from typing import TypeVar + +from astrbot.core.agent.tool import FunctionTool + +TFunctionTool = TypeVar("TFunctionTool", bound=type[FunctionTool]) + +_BUILTIN_TOOL_MODULES = ( + "astrbot.core.tools.cron_tools", + "astrbot.core.tools.knowledge_base_tools", + "astrbot.core.tools.message_tools", + "astrbot.core.tools.web_search_tools", +) + +_builtin_tool_classes_by_name: dict[str, type[FunctionTool]] = {} +_builtin_tool_names_by_class: dict[type[FunctionTool], str] = {} +_builtin_tools_loaded = False + + +def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str: + tool_name = getattr(tool_cls, "name", None) + if isinstance(tool_name, str) and tool_name: + return tool_name + + dataclass_fields = getattr(tool_cls, "__dataclass_fields__", {}) + name_field = dataclass_fields.get("name") + if name_field is not None and isinstance(name_field.default, str): + return name_field.default + + raise ValueError( + f"Builtin tool class {tool_cls.__module__}.{tool_cls.__name__} does not define a valid name.", + ) + + +def builtin_tool(tool_cls: TFunctionTool) -> TFunctionTool: + tool_name = _resolve_builtin_tool_name(tool_cls) + existing = _builtin_tool_classes_by_name.get(tool_name) + if existing is not None and existing is not tool_cls: + raise ValueError( + f"Builtin tool name conflict detected: {tool_name} is already registered by " + f"{existing.__module__}.{existing.__name__}.", + ) + + _builtin_tool_classes_by_name[tool_name] = tool_cls + _builtin_tool_names_by_class[tool_cls] = tool_name + return tool_cls + + +def ensure_builtin_tools_loaded() -> None: + global _builtin_tools_loaded + if _builtin_tools_loaded: + return + + for module_name in _BUILTIN_TOOL_MODULES: + import_module(module_name) + + _builtin_tools_loaded = True + + +def get_builtin_tool_class(name: str) -> type[FunctionTool] | None: + ensure_builtin_tools_loaded() + return _builtin_tool_classes_by_name.get(name) + + +def get_builtin_tool_name(tool_cls: type[FunctionTool]) -> str | None: + ensure_builtin_tools_loaded() + return _builtin_tool_names_by_class.get(tool_cls) + + +def iter_builtin_tool_classes() -> tuple[type[FunctionTool], ...]: + ensure_builtin_tools_loaded() + return tuple(_builtin_tool_classes_by_name.values()) + + +__all__ = [ + "builtin_tool", + "ensure_builtin_tools_loaded", + "get_builtin_tool_class", + "get_builtin_tool_name", + "iter_builtin_tool_classes", +] diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 1aa7f9bc70..5ca8c3e08e 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -11,6 +11,7 @@ from astrbot.core import logger, sp from astrbot.core.agent.tool import FunctionTool, ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.tools.registry import builtin_tool WEB_SEARCH_TOOL_NAMES = [ "web_search_baidu", @@ -275,6 +276,7 @@ async def _baidu_search( ] +@builtin_tool @pydantic_dataclass class TavilyWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_tavily" @@ -357,6 +359,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) +@builtin_tool @pydantic_dataclass class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]): name: str = "tavily_extract_web_page" @@ -403,6 +406,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return ret or "Error: Tavily web searcher does not return any results." +@builtin_tool @pydantic_dataclass class BochaWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_bocha" @@ -466,6 +470,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) +@builtin_tool @pydantic_dataclass class BraveWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_brave" @@ -523,6 +528,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) +@builtin_tool @pydantic_dataclass class BaiduWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_baidu" @@ -585,18 +591,12 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) -WEB_SEARCH_TAVILY_TOOL = TavilyWebSearchTool() -TAVILY_EXTRACT_WEB_PAGE_TOOL = TavilyExtractWebPageTool() -WEB_SEARCH_BOCHA_TOOL = BochaWebSearchTool() -WEB_SEARCH_BRAVE_TOOL = BraveWebSearchTool() -WEB_SEARCH_BAIDU_TOOL = BaiduWebSearchTool() - __all__ = [ - "WEB_SEARCH_BAIDU_TOOL", - "WEB_SEARCH_BOCHA_TOOL", - "WEB_SEARCH_BRAVE_TOOL", - "WEB_SEARCH_TAVILY_TOOL", - "TAVILY_EXTRACT_WEB_PAGE_TOOL", + "BaiduWebSearchTool", + "BochaWebSearchTool", + "BraveWebSearchTool", + "TavilyExtractWebPageTool", + "TavilyWebSearchTool", "WEB_SEARCH_TOOL_NAMES", "normalize_legacy_web_search_config", ] diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 84f8dcc6d7..33b74deffc 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -428,10 +428,20 @@ async def test_mcp_connection(self): async def get_tool_list(self): """Get all registered tools.""" try: - tools = self.tool_mgr.func_list + tools = list(self.tool_mgr.func_list) + existing_names = {tool.name for tool in tools} + for tool in self.tool_mgr.iter_builtin_tools(): + if tool.name not in existing_names: + tools.append(tool) + tools_dict = [] for tool in tools: - if isinstance(tool, MCPTool): + readonly = False + if self.tool_mgr.is_builtin_tool(tool.name): + origin = "builtin" + origin_name = "AstrBot Core" + readonly = True + elif isinstance(tool, MCPTool): origin = "mcp" origin_name = tool.mcp_server_name elif tool.handler_module_path and star_map.get( @@ -451,6 +461,7 @@ async def get_tool_list(self): "active": tool.active, "origin": origin, "origin_name": origin_name, + "readonly": readonly, } tools_dict.append(tool_info) return Response().ok(data=tools_dict).__dict__ @@ -472,6 +483,13 @@ async def toggle_tool(self): .__dict__ ) + if self.tool_mgr.is_builtin_tool(tool_name): + return ( + Response() + .error("Builtin tools are read-only and cannot be toggled.") + .__dict__ + ) + if action: try: ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) diff --git a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue index 7fa4ef1679..f0be7bbc83 100644 --- a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue @@ -4,7 +4,6 @@ import { useModuleI18n } from '@/i18n/composables'; import type { ToolItem } from '../types'; const { tm: tmTool } = useModuleI18n('features/tooluse'); -const { tm: tmCommand } = useModuleI18n('features/command'); const props = defineProps<{ items: ToolItem[]; @@ -16,11 +15,10 @@ const emit = defineEmits<{ }>(); const toolHeaders = computed(() => [ - { title: tmTool('functionTools.title'), key: 'name', minWidth: '160px' }, + { title: tmTool('functionTools.title'), key: 'name', minWidth: '240px' }, { title: tmTool('functionTools.description'), key: 'description' }, { title: tmTool('functionTools.table.origin'), key: 'origin', sortable: false, width: '120px' }, { title: tmTool('functionTools.table.originName'), key: 'origin_name', sortable: false, width: '160px' }, - { title: tmCommand('status.enabled'), key: 'active', sortable: false, width: '120px' }, { title: tmTool('functionTools.table.actions'), key: 'actions', sortable: false, width: '120px' } ]); @@ -39,13 +37,8 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro :loading="props.loading" > @@ -56,7 +49,7 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro @@ -67,14 +60,10 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro - -