Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file added .vs/AstrBot.slnx/v18/.wsuo
Binary file not shown.
23 changes: 23 additions & 0 deletions .vs/AstrBot.slnx/v18/DocumentLayout.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"Version": 1,
"WorkspaceRootPath": "E:\\AstrBot\\",
"Documents": [],
"DocumentGroupContainers": [
{
"Orientation": 0,
"VerticalTabListWidth": 256,
"DocumentGroups": [
{
"DockedWidth": 200,
"SelectedChildIndex": -1,
"Children": [
{
"$type": "Bookmark",
"Name": "ST:0:0:{3ae79031-e1bc-11d0-8f78-00a0c9110057}"
}
]
}
]
}
]
}
Binary file added .vs/AstrBot/v18/workspaceFileList.bin
Binary file not shown.
3 changes: 3 additions & 0 deletions .vs/ProjectSettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"CurrentProjectSetting": null
}
6 changes: 6 additions & 0 deletions .vs/VSWorkspaceState.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ExpandedNodes": [
""
],
"PreviewInSolutionExplorer": false
}
Binary file added .vs/slnx.sqlite
Binary file not shown.
203 changes: 180 additions & 23 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
retrieve_knowledge_base,
)
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Record, Reply
from astrbot.core.exceptions import UnsupportedToolCapabilityError
from astrbot.core.message.components import File, Image, Reply
from astrbot.core.persona_error_reply import (
extract_persona_custom_error_message_from_persona,
set_persona_custom_error_message_on_event,
Expand Down Expand Up @@ -99,6 +100,10 @@ class MainAgentBuildConfig:
"""
tool_schema_mode: str = "full"
"""The tool schema mode, can be 'full' or 'skills-like'."""
tool_capability_strategy: str = "fallback_provider"
"""How to handle tool requests when the selected model does not support tool calls.
Supported values: fallback_provider, chat_only, hard_fail.
"""
provider_wake_prefix: str = ""
"""The wake prefix for the provider. If the user message does not start with this prefix,
the main agent will not be triggered."""
Expand Down Expand Up @@ -151,6 +156,7 @@ class MainAgentBuildResult:
provider_request: ProviderRequest
provider: Provider
reset_coro: Coroutine | None = None
allow_follow_up: bool = True


def _select_provider(
Expand Down Expand Up @@ -771,26 +777,81 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
else:
req.prompt = placeholder
req.image_urls = []
if req.audio_urls:
provider_cfg = provider.provider_config.get("modalities", ["audio"])
if "audio" not in provider_cfg:
logger.debug(
"Provider %s does not support audio, using placeholder.", provider
)
audio_count = len(req.audio_urls)
placeholder = " ".join(["[Audio]"] * audio_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.audio_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
if "tool_use" not in provider_cfg:
logger.debug(
"Provider %s does not support tool_use, clearing tools.", provider
)
req.func_tool = None


def _request_has_tools(req: ProviderRequest) -> bool:
return req.func_tool is not None and not req.func_tool.empty()


def _get_effective_model_name(
provider: Provider,
requested_model: str | None = None,
) -> str:
model_name = (
requested_model
or provider.get_model()
or provider.provider_config.get("model", "")
)
return str(model_name or "")


def _get_tool_call_support(
provider: Provider,
*,
requested_model: str | None = None,
) -> bool | None:
modalities = provider.provider_config.get("modalities", None)
if isinstance(modalities, list) and "tool_use" not in modalities:
return False

model_name = _get_effective_model_name(provider, requested_model)
if not model_name:
return None

metadata = LLM_METADATAS.get(model_name)
if metadata is None:
return None
return bool(metadata["tool_call"])
Comment on lines +811 to +814
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Guard against missing tool_call key in LLM_METADATAS entries to avoid KeyError.

Using metadata["tool_call"] will raise a KeyError if that key is missing for any model. Prefer metadata.get("tool_call") (with a default like None or False) before casting to bool, so the behavior matches the metadata is None case instead of breaking the flow.



def _strip_tool_state_from_request(req: ProviderRequest) -> None:
req.func_tool = None
req.tool_calls_result = None

if not isinstance(req.contexts, list) or not req.contexts:
return

sanitized_contexts: list[dict] = []
for message in req.contexts:
if not isinstance(message, dict):
continue
role = message.get("role")
if role == "tool":
continue

new_message = copy.deepcopy(message)
if role == "assistant":
new_message.pop("tool_calls", None)
new_message.pop("tool_call_id", None)
content = new_message.get("content")
if not content:
continue
if isinstance(content, str) and not content.strip():
continue

sanitized_contexts.append(new_message)

req.contexts = sanitized_contexts


def _normalize_tool_capability_strategy(strategy: str) -> str:
if strategy in {"fallback_provider", "chat_only", "hard_fail"}:
return strategy
Comment on lines +848 to +849
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability, consider defining the valid tool capability strategy names as constants or an Enum in a shared location, rather than hardcoding them as a set here. This would prevent inconsistencies, as these names are also used in astrbot/core/config/default.py for the configuration schema.

For example, you could define an Enum:

from enum import Enum

class ToolCapabilityStrategy(str, Enum):
    FALLBACK_PROVIDER = "fallback_provider"
    CHAT_ONLY = "chat_only"
    HARD_FAIL = "hard_fail"

VALID_TOOL_CAPABILITY_STRATEGIES = {s.value for s in ToolCapabilityStrategy}

And then use VALID_TOOL_CAPABILITY_STRATEGIES for validation. This would make the code more robust and easier to update.

logger.warning(
"Unsupported tool_capability_strategy `%s`, fallback to `fallback_provider`.",
strategy,
)
return "fallback_provider"


def _sanitize_context_by_modalities(
Expand Down Expand Up @@ -1112,6 +1173,93 @@ def _get_fallback_chat_providers(
return fallbacks


def _get_tool_capable_fallback_provider(
provider: Provider,
plugin_context: Context,
provider_settings: dict,
) -> tuple[Provider | None, bool]:
unknown_capability_provider: Provider | None = None

for fallback_provider in _get_fallback_chat_providers(
provider, plugin_context, provider_settings
):
support = _get_tool_call_support(fallback_provider)
if support is True:
return fallback_provider, True
if support is None and unknown_capability_provider is None:
unknown_capability_provider = fallback_provider

return unknown_capability_provider, False


def _resolve_tool_capability_strategy(
*,
provider: Provider,
req: ProviderRequest,
plugin_context: Context,
config: MainAgentBuildConfig,
) -> tuple[Provider, bool]:
if not _request_has_tools(req):
return provider, True

support = _get_tool_call_support(provider, requested_model=req.model)
if support is not False:
return provider, True

strategy = _normalize_tool_capability_strategy(config.tool_capability_strategy)
model_name = _get_effective_model_name(provider, req.model)
provider_id = str(provider.provider_config.get("id", ""))

if strategy == "fallback_provider":
fallback_provider, confirmed = _get_tool_capable_fallback_provider(
provider,
plugin_context,
config.provider_settings,
)
if fallback_provider is not None:
fallback_model = _get_effective_model_name(fallback_provider)
if confirmed:
logger.info(
"Model `%s` on provider `%s` does not support tool calls, "
"switching to fallback provider `%s` (%s).",
model_name,
provider_id,
fallback_provider.provider_config.get("id", ""),
fallback_model,
)
else:
logger.info(
"Model `%s` on provider `%s` does not support tool calls, "
"trying fallback provider `%s` without explicit tool metadata.",
model_name,
provider_id,
fallback_provider.provider_config.get("id", ""),
)
req.model = fallback_model or None
return fallback_provider, True

logger.warning(
"Model `%s` on provider `%s` does not support tool calls and no suitable "
"fallback provider was found; degrading to chat_only.",
model_name,
provider_id,
)
strategy = "chat_only"

if strategy == "hard_fail":
raise UnsupportedToolCapabilityError(
f"Model `{model_name}` on provider `{provider_id}` does not support tool calls."
)

logger.info(
"Model `%s` on provider `%s` does not support tool calls, degrading to chat_only.",
model_name,
provider_id,
)
_strip_tool_state_from_request(req)
return provider, False


async def build_main_agent(
*,
event: AstrMessageEvent,
Expand Down Expand Up @@ -1293,9 +1441,7 @@ async def build_main_agent(
if not req.session_id:
req.session_id = event.unified_msg_origin

_modalities_fix(provider, req)
_plugin_tool_fix(event, req)
_sanitize_context_by_modalities(config, provider, req)

if config.llm_safety_mode:
_apply_llm_safety_mode(config, req)
Expand All @@ -1319,6 +1465,16 @@ async def build_main_agent(
req.func_tool = ToolSet()
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)

provider, allow_follow_up = _resolve_tool_capability_strategy(
provider=provider,
req=req,
plugin_context=plugin_context,
config=config,
)

_modalities_fix(provider, req)
_sanitize_context_by_modalities(config, provider, req)

if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
Expand Down Expand Up @@ -1370,4 +1526,5 @@ async def build_main_agent(
provider_request=req,
provider=provider,
reset_coro=reset_coro if not apply_reset else None,
allow_follow_up=allow_follow_up,
)
22 changes: 22 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
"max_agent_step": 30,
"tool_call_timeout": 120,
"tool_schema_mode": "full",
"tool_capability_strategy": "fallback_provider",
"llm_safety_mode": True,
"safety_mode_strategy": "system_prompt", # TODO: llm judge
"file_extract": {
Expand Down Expand Up @@ -2777,6 +2778,9 @@ class ChatProviderTemplate(TypedDict):
"tool_schema_mode": {
"type": "string",
},
"tool_capability_strategy": {
"type": "string",
},
"file_extract": {
"type": "object",
"items": {
Expand Down Expand Up @@ -3533,6 +3537,24 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_capability_strategy": {
"description": "模型不支持工具调用时的处理策略",
"type": "string",
"options": [
"fallback_provider",
"chat_only",
"hard_fail",
],
"labels": [
"自动切换回退模型",
"降级为纯文本聊天",
"直接报错",
],
"hint": "当当前模型不支持函数工具调用时,优先切换到回退模型;也可以直接降级为纯文本聊天或中止请求。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class ProviderNotFoundError(AstrBotError):

class EmptyModelOutputError(AstrBotError):
"""Raised when the model response contains no usable assistant output."""


class UnsupportedToolCapabilityError(AstrBotError):
"""Raised when the selected model cannot satisfy the requested tool mode."""
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
self.tool_schema_mode: str = settings.get("tool_schema_mode", "full")
self.tool_capability_strategy: str = settings.get(
"tool_capability_strategy",
"fallback_provider",
)
if self.tool_schema_mode not in ("skills_like", "full"):
logger.warning(
"Unsupported tool_schema_mode: %s, fallback to skills_like",
Expand Down Expand Up @@ -116,6 +120,7 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.main_agent_cfg = MainAgentBuildConfig(
tool_call_timeout=self.tool_call_timeout,
tool_schema_mode=self.tool_schema_mode,
tool_capability_strategy=self.tool_capability_strategy,
sanitize_context_by_modalities=self.sanitize_context_by_modalities,
kb_agentic_mode=self.kb_agentic_mode,
file_extract_enabled=self.file_extract_enabled,
Expand Down Expand Up @@ -236,8 +241,9 @@ async def process(
if reset_coro:
await reset_coro

register_active_runner(event.unified_msg_origin, agent_runner)
runner_registered = True
if build_result.allow_follow_up:
register_active_runner(event.unified_msg_origin, agent_runner)
runner_registered = True
action_type = event.get_extra("action_type")

event.trace.record(
Expand Down
Loading