diff --git a/.vs/AstrBot.slnx/FileContentIndex/729256a7-0a3e-40db-915e-2ca6b1a0943f.vsidx b/.vs/AstrBot.slnx/FileContentIndex/729256a7-0a3e-40db-915e-2ca6b1a0943f.vsidx new file mode 100644 index 0000000000..9a52fdba8f Binary files /dev/null and b/.vs/AstrBot.slnx/FileContentIndex/729256a7-0a3e-40db-915e-2ca6b1a0943f.vsidx differ diff --git a/.vs/AstrBot.slnx/v18/.wsuo b/.vs/AstrBot.slnx/v18/.wsuo new file mode 100644 index 0000000000..f51a8fe959 Binary files /dev/null and b/.vs/AstrBot.slnx/v18/.wsuo differ diff --git a/.vs/AstrBot.slnx/v18/DocumentLayout.json b/.vs/AstrBot.slnx/v18/DocumentLayout.json new file mode 100644 index 0000000000..18e482b44e --- /dev/null +++ b/.vs/AstrBot.slnx/v18/DocumentLayout.json @@ -0,0 +1,23 @@ +{ + "Version": 1, + "WorkspaceRootPath": "E:\\AstrBot\\", + "Documents": [], + "DocumentGroupContainers": [ + { + "Orientation": 0, + "VerticalTabListWidth": 256, + "DocumentGroups": [ + { + "DockedWidth": 200, + "SelectedChildIndex": -1, + "Children": [ + { + "$type": "Bookmark", + "Name": "ST:0:0:{3ae79031-e1bc-11d0-8f78-00a0c9110057}" + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/.vs/AstrBot/v18/workspaceFileList.bin b/.vs/AstrBot/v18/workspaceFileList.bin new file mode 100644 index 0000000000..353dbffa3c Binary files /dev/null and b/.vs/AstrBot/v18/workspaceFileList.bin differ diff --git a/.vs/ProjectSettings.json b/.vs/ProjectSettings.json new file mode 100644 index 0000000000..f8b4888565 --- /dev/null +++ b/.vs/ProjectSettings.json @@ -0,0 +1,3 @@ +{ + "CurrentProjectSetting": null +} \ No newline at end of file diff --git a/.vs/VSWorkspaceState.json b/.vs/VSWorkspaceState.json new file mode 100644 index 0000000000..6b6114114f --- /dev/null +++ b/.vs/VSWorkspaceState.json @@ -0,0 +1,6 @@ +{ + "ExpandedNodes": [ + "" + ], + "PreviewInSolutionExplorer": false +} \ No newline at end of file diff --git a/.vs/slnx.sqlite b/.vs/slnx.sqlite new file mode 100644 index 0000000000..6b932cc86b Binary files /dev/null and b/.vs/slnx.sqlite differ diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 9861e669c4..f1ac270d13 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -51,7 +51,8 @@ retrieve_knowledge_base, ) from astrbot.core.conversation_mgr import Conversation -from astrbot.core.message.components import File, Image, Record, Reply +from astrbot.core.exceptions import UnsupportedToolCapabilityError +from astrbot.core.message.components import File, Image, Reply from astrbot.core.persona_error_reply import ( extract_persona_custom_error_message_from_persona, set_persona_custom_error_message_on_event, @@ -99,6 +100,10 @@ class MainAgentBuildConfig: """ tool_schema_mode: str = "full" """The tool schema mode, can be 'full' or 'skills-like'.""" + tool_capability_strategy: str = "fallback_provider" + """How to handle tool requests when the selected model does not support tool calls. + Supported values: fallback_provider, chat_only, hard_fail. + """ provider_wake_prefix: str = "" """The wake prefix for the provider. If the user message does not start with this prefix, the main agent will not be triggered.""" @@ -151,6 +156,7 @@ class MainAgentBuildResult: provider_request: ProviderRequest provider: Provider reset_coro: Coroutine | None = None + allow_follow_up: bool = True def _select_provider( @@ -771,26 +777,81 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: else: req.prompt = placeholder req.image_urls = [] - if req.audio_urls: - provider_cfg = provider.provider_config.get("modalities", ["audio"]) - if "audio" not in provider_cfg: - logger.debug( - "Provider %s does not support audio, using placeholder.", provider - ) - audio_count = len(req.audio_urls) - placeholder = " ".join(["[Audio]"] * audio_count) - if req.prompt: - req.prompt = f"{placeholder} {req.prompt}" - else: - req.prompt = placeholder - req.audio_urls = [] - if req.func_tool: - provider_cfg = provider.provider_config.get("modalities", ["tool_use"]) - if "tool_use" not in provider_cfg: - logger.debug( - "Provider %s does not support tool_use, clearing tools.", provider - ) - req.func_tool = None + + +def _request_has_tools(req: ProviderRequest) -> bool: + return req.func_tool is not None and not req.func_tool.empty() + + +def _get_effective_model_name( + provider: Provider, + requested_model: str | None = None, +) -> str: + model_name = ( + requested_model + or provider.get_model() + or provider.provider_config.get("model", "") + ) + return str(model_name or "") + + +def _get_tool_call_support( + provider: Provider, + *, + requested_model: str | None = None, +) -> bool | None: + modalities = provider.provider_config.get("modalities", None) + if isinstance(modalities, list) and "tool_use" not in modalities: + return False + + model_name = _get_effective_model_name(provider, requested_model) + if not model_name: + return None + + metadata = LLM_METADATAS.get(model_name) + if metadata is None: + return None + return bool(metadata["tool_call"]) + + +def _strip_tool_state_from_request(req: ProviderRequest) -> None: + req.func_tool = None + req.tool_calls_result = None + + if not isinstance(req.contexts, list) or not req.contexts: + return + + sanitized_contexts: list[dict] = [] + for message in req.contexts: + if not isinstance(message, dict): + continue + role = message.get("role") + if role == "tool": + continue + + new_message = copy.deepcopy(message) + if role == "assistant": + new_message.pop("tool_calls", None) + new_message.pop("tool_call_id", None) + content = new_message.get("content") + if not content: + continue + if isinstance(content, str) and not content.strip(): + continue + + sanitized_contexts.append(new_message) + + req.contexts = sanitized_contexts + + +def _normalize_tool_capability_strategy(strategy: str) -> str: + if strategy in {"fallback_provider", "chat_only", "hard_fail"}: + return strategy + logger.warning( + "Unsupported tool_capability_strategy `%s`, fallback to `fallback_provider`.", + strategy, + ) + return "fallback_provider" def _sanitize_context_by_modalities( @@ -1112,6 +1173,93 @@ def _get_fallback_chat_providers( return fallbacks +def _get_tool_capable_fallback_provider( + provider: Provider, + plugin_context: Context, + provider_settings: dict, +) -> tuple[Provider | None, bool]: + unknown_capability_provider: Provider | None = None + + for fallback_provider in _get_fallback_chat_providers( + provider, plugin_context, provider_settings + ): + support = _get_tool_call_support(fallback_provider) + if support is True: + return fallback_provider, True + if support is None and unknown_capability_provider is None: + unknown_capability_provider = fallback_provider + + return unknown_capability_provider, False + + +def _resolve_tool_capability_strategy( + *, + provider: Provider, + req: ProviderRequest, + plugin_context: Context, + config: MainAgentBuildConfig, +) -> tuple[Provider, bool]: + if not _request_has_tools(req): + return provider, True + + support = _get_tool_call_support(provider, requested_model=req.model) + if support is not False: + return provider, True + + strategy = _normalize_tool_capability_strategy(config.tool_capability_strategy) + model_name = _get_effective_model_name(provider, req.model) + provider_id = str(provider.provider_config.get("id", "")) + + if strategy == "fallback_provider": + fallback_provider, confirmed = _get_tool_capable_fallback_provider( + provider, + plugin_context, + config.provider_settings, + ) + if fallback_provider is not None: + fallback_model = _get_effective_model_name(fallback_provider) + if confirmed: + logger.info( + "Model `%s` on provider `%s` does not support tool calls, " + "switching to fallback provider `%s` (%s).", + model_name, + provider_id, + fallback_provider.provider_config.get("id", ""), + fallback_model, + ) + else: + logger.info( + "Model `%s` on provider `%s` does not support tool calls, " + "trying fallback provider `%s` without explicit tool metadata.", + model_name, + provider_id, + fallback_provider.provider_config.get("id", ""), + ) + req.model = fallback_model or None + return fallback_provider, True + + logger.warning( + "Model `%s` on provider `%s` does not support tool calls and no suitable " + "fallback provider was found; degrading to chat_only.", + model_name, + provider_id, + ) + strategy = "chat_only" + + if strategy == "hard_fail": + raise UnsupportedToolCapabilityError( + f"Model `{model_name}` on provider `{provider_id}` does not support tool calls." + ) + + logger.info( + "Model `%s` on provider `%s` does not support tool calls, degrading to chat_only.", + model_name, + provider_id, + ) + _strip_tool_state_from_request(req) + return provider, False + + async def build_main_agent( *, event: AstrMessageEvent, @@ -1293,9 +1441,7 @@ async def build_main_agent( if not req.session_id: req.session_id = event.unified_msg_origin - _modalities_fix(provider, req) _plugin_tool_fix(event, req) - _sanitize_context_by_modalities(config, provider, req) if config.llm_safety_mode: _apply_llm_safety_mode(config, req) @@ -1319,6 +1465,16 @@ async def build_main_agent( req.func_tool = ToolSet() req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + provider, allow_follow_up = _resolve_tool_capability_strategy( + provider=provider, + req=req, + plugin_context=plugin_context, + config=config, + ) + + _modalities_fix(provider, req) + _sanitize_context_by_modalities(config, provider, req) + if provider.provider_config.get("max_context_tokens", 0) <= 0: model = provider.get_model() if model_info := LLM_METADATAS.get(model): @@ -1370,4 +1526,5 @@ async def build_main_agent( provider_request=req, provider=provider, reset_coro=reset_coro if not apply_reset else None, + allow_follow_up=allow_follow_up, ) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ee42ca99d0..37f402cd65 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -151,6 +151,7 @@ "max_agent_step": 30, "tool_call_timeout": 120, "tool_schema_mode": "full", + "tool_capability_strategy": "fallback_provider", "llm_safety_mode": True, "safety_mode_strategy": "system_prompt", # TODO: llm judge "file_extract": { @@ -2777,6 +2778,9 @@ class ChatProviderTemplate(TypedDict): "tool_schema_mode": { "type": "string", }, + "tool_capability_strategy": { + "type": "string", + }, "file_extract": { "type": "object", "items": { @@ -3533,6 +3537,24 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.tool_capability_strategy": { + "description": "模型不支持工具调用时的处理策略", + "type": "string", + "options": [ + "fallback_provider", + "chat_only", + "hard_fail", + ], + "labels": [ + "自动切换回退模型", + "降级为纯文本聊天", + "直接报错", + ], + "hint": "当当前模型不支持函数工具调用时,优先切换到回退模型;也可以直接降级为纯文本聊天或中止请求。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", "type": "string", diff --git a/astrbot/core/exceptions.py b/astrbot/core/exceptions.py index f10af57ea8..f85170a2eb 100644 --- a/astrbot/core/exceptions.py +++ b/astrbot/core/exceptions.py @@ -11,3 +11,7 @@ class ProviderNotFoundError(AstrBotError): class EmptyModelOutputError(AstrBotError): """Raised when the model response contains no usable assistant output.""" + + +class UnsupportedToolCapabilityError(AstrBotError): + """Raised when the selected model cannot satisfy the requested tool mode.""" diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index e0ba2463ca..1d10609b21 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -56,6 +56,10 @@ async def initialize(self, ctx: PipelineContext) -> None: self.max_step: int = settings.get("max_agent_step", 30) self.tool_call_timeout: int = settings.get("tool_call_timeout", 60) self.tool_schema_mode: str = settings.get("tool_schema_mode", "full") + self.tool_capability_strategy: str = settings.get( + "tool_capability_strategy", + "fallback_provider", + ) if self.tool_schema_mode not in ("skills_like", "full"): logger.warning( "Unsupported tool_schema_mode: %s, fallback to skills_like", @@ -116,6 +120,7 @@ async def initialize(self, ctx: PipelineContext) -> None: self.main_agent_cfg = MainAgentBuildConfig( tool_call_timeout=self.tool_call_timeout, tool_schema_mode=self.tool_schema_mode, + tool_capability_strategy=self.tool_capability_strategy, sanitize_context_by_modalities=self.sanitize_context_by_modalities, kb_agentic_mode=self.kb_agentic_mode, file_extract_enabled=self.file_extract_enabled, @@ -236,8 +241,9 @@ async def process( if reset_coro: await reset_coro - register_active_runner(event.unified_msg_origin, agent_runner) - runner_registered = True + if build_result.allow_follow_up: + register_active_runner(event.unified_msg_origin, agent_runner) + runner_registered = True action_type = event.get_extra("action_type") event.trace.record( diff --git a/tests/test_astr_main_agent_tool_capability.py b/tests/test_astr_main_agent_tool_capability.py new file mode 100644 index 0000000000..a4ff91abf5 --- /dev/null +++ b/tests/test_astr_main_agent_tool_capability.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.astr_main_agent import ( + MainAgentBuildConfig, + _resolve_tool_capability_strategy, +) +from astrbot.core.exceptions import UnsupportedToolCapabilityError +from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.provider.provider import Provider +from astrbot.core.utils.llm_metadata import LLM_METADATAS + + +class DummyProvider(Provider): + def __init__( + self, + provider_id: str, + model: str, + modalities: list[str] | None = None, + ) -> None: + provider_config: dict[str, Any] = { + "id": provider_id, + "type": "openai_chat_completion", + "model": model, + } + if modalities is not None: + provider_config["modalities"] = modalities + super().__init__(provider_config, {}) + self.set_model(model) + + def get_current_key(self) -> str: + return "test-key" + + def set_key(self, key: str) -> None: + return None + + async def get_models(self) -> list[str]: + return [self.get_model()] + + async def text_chat(self, **kwargs): + raise NotImplementedError + + +class DummyPluginContext: + def __init__(self, providers: dict[str, Provider]) -> None: + self.providers = providers + + def get_provider_by_id(self, provider_id: str) -> Provider | None: + return self.providers.get(provider_id) + + +def _make_metadata(tool_call: bool) -> dict[str, Any]: + return { + "id": "test-model", + "reasoning": False, + "tool_call": tool_call, + "knowledge": "none", + "release_date": "", + "modalities": {"input": ["text"], "output": ["text"]}, + "open_weights": False, + "limit": {"context": 0, "output": 0}, + } + + +def _make_tool_set() -> ToolSet: + return ToolSet( + [ + FunctionTool( + name="test_tool", + description="Test tool", + parameters={"type": "object", "properties": {}}, + handler=None, + ) + ] + ) + + +def test_tool_capability_strategy_switches_to_fallback_provider( + monkeypatch: pytest.MonkeyPatch, +) -> None: + primary = DummyProvider("primary", "deepseek-r1:7b", ["text", "tool_use"]) + fallback = DummyProvider("fallback", "gpt-4.1-mini", ["text", "tool_use"]) + plugin_context = DummyPluginContext( + { + "primary": primary, + "fallback": fallback, + } + ) + req = ProviderRequest( + func_tool=_make_tool_set(), + model="deepseek-r1:7b", + ) + config = MainAgentBuildConfig( + tool_call_timeout=30, + tool_capability_strategy="fallback_provider", + provider_settings={"fallback_chat_models": ["fallback"]}, + ) + + monkeypatch.setitem(LLM_METADATAS, "deepseek-r1:7b", _make_metadata(False)) + monkeypatch.setitem(LLM_METADATAS, "gpt-4.1-mini", _make_metadata(True)) + + selected_provider, allow_follow_up = _resolve_tool_capability_strategy( + provider=primary, + req=req, + plugin_context=plugin_context, + config=config, + ) + + assert selected_provider is fallback + assert allow_follow_up is True + assert req.func_tool is not None + assert req.model == "gpt-4.1-mini" + + +def test_tool_capability_strategy_chat_only_clears_tool_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + primary = DummyProvider("primary", "deepseek-r1:7b", ["text", "tool_use"]) + plugin_context = DummyPluginContext({"primary": primary}) + req = ProviderRequest( + func_tool=_make_tool_set(), + contexts=[ + {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": ""}, + {"role": "tool", "content": "tool output"}, + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "Plain answer"}, + ], + ) + config = MainAgentBuildConfig( + tool_call_timeout=30, + tool_capability_strategy="chat_only", + provider_settings={}, + ) + + monkeypatch.setitem(LLM_METADATAS, "deepseek-r1:7b", _make_metadata(False)) + + selected_provider, allow_follow_up = _resolve_tool_capability_strategy( + provider=primary, + req=req, + plugin_context=plugin_context, + config=config, + ) + + assert selected_provider is primary + assert allow_follow_up is False + assert req.func_tool is None + assert req.tool_calls_result is None + assert req.contexts == [ + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "Plain answer"}, + ] + + +def test_tool_capability_strategy_hard_fail_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + primary = DummyProvider("primary", "deepseek-r1:7b", ["text", "tool_use"]) + plugin_context = DummyPluginContext({"primary": primary}) + req = ProviderRequest(func_tool=_make_tool_set()) + config = MainAgentBuildConfig( + tool_call_timeout=30, + tool_capability_strategy="hard_fail", + provider_settings={}, + ) + + # Capture initial tool and context state to ensure they are not mutated + original_func_tool = req.func_tool + original_contexts = list(req.contexts) + + monkeypatch.setitem(LLM_METADATAS, "deepseek-r1:7b", _make_metadata(False)) + + with pytest.raises(UnsupportedToolCapabilityError): + _resolve_tool_capability_strategy( + provider=primary, + req=req, + plugin_context=plugin_context, + config=config, + ) + + # Hard-fail strategy should not mutate tool or context state + assert req.func_tool is original_func_tool + assert req.contexts == original_contexts