-
-
Notifications
You must be signed in to change notification settings - Fork 2k
fix: handle tool-incompatible models before agent execution #7164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
930c420
e721106
316f384
0eb0fe1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| { | ||
| "Version": 1, | ||
| "WorkspaceRootPath": "E:\\AstrBot\\", | ||
| "Documents": [], | ||
| "DocumentGroupContainers": [ | ||
| { | ||
| "Orientation": 0, | ||
| "VerticalTabListWidth": 256, | ||
| "DocumentGroups": [ | ||
| { | ||
| "DockedWidth": 200, | ||
| "SelectedChildIndex": -1, | ||
| "Children": [ | ||
| { | ||
| "$type": "Bookmark", | ||
| "Name": "ST:0:0:{3ae79031-e1bc-11d0-8f78-00a0c9110057}" | ||
| } | ||
| ] | ||
| } | ||
| ] | ||
| } | ||
| ] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| { | ||
| "CurrentProjectSetting": null | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| { | ||
| "ExpandedNodes": [ | ||
| "" | ||
| ], | ||
| "PreviewInSolutionExplorer": false | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,7 +51,8 @@ | |
| retrieve_knowledge_base, | ||
| ) | ||
| from astrbot.core.conversation_mgr import Conversation | ||
| from astrbot.core.message.components import File, Image, Record, Reply | ||
| from astrbot.core.exceptions import UnsupportedToolCapabilityError | ||
| from astrbot.core.message.components import File, Image, Reply | ||
| from astrbot.core.persona_error_reply import ( | ||
| extract_persona_custom_error_message_from_persona, | ||
| set_persona_custom_error_message_on_event, | ||
|
|
@@ -99,6 +100,10 @@ class MainAgentBuildConfig: | |
| """ | ||
| tool_schema_mode: str = "full" | ||
| """The tool schema mode, can be 'full' or 'skills-like'.""" | ||
| tool_capability_strategy: str = "fallback_provider" | ||
| """How to handle tool requests when the selected model does not support tool calls. | ||
| Supported values: fallback_provider, chat_only, hard_fail. | ||
| """ | ||
| provider_wake_prefix: str = "" | ||
| """The wake prefix for the provider. If the user message does not start with this prefix, | ||
| the main agent will not be triggered.""" | ||
|
|
@@ -151,6 +156,7 @@ class MainAgentBuildResult: | |
| provider_request: ProviderRequest | ||
| provider: Provider | ||
| reset_coro: Coroutine | None = None | ||
| allow_follow_up: bool = True | ||
|
|
||
|
|
||
| def _select_provider( | ||
|
|
@@ -771,26 +777,81 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: | |
| else: | ||
| req.prompt = placeholder | ||
| req.image_urls = [] | ||
| if req.audio_urls: | ||
| provider_cfg = provider.provider_config.get("modalities", ["audio"]) | ||
| if "audio" not in provider_cfg: | ||
| logger.debug( | ||
| "Provider %s does not support audio, using placeholder.", provider | ||
| ) | ||
| audio_count = len(req.audio_urls) | ||
| placeholder = " ".join(["[Audio]"] * audio_count) | ||
| if req.prompt: | ||
| req.prompt = f"{placeholder} {req.prompt}" | ||
| else: | ||
| req.prompt = placeholder | ||
| req.audio_urls = [] | ||
| if req.func_tool: | ||
| provider_cfg = provider.provider_config.get("modalities", ["tool_use"]) | ||
| if "tool_use" not in provider_cfg: | ||
| logger.debug( | ||
| "Provider %s does not support tool_use, clearing tools.", provider | ||
| ) | ||
| req.func_tool = None | ||
|
|
||
|
|
||
| def _request_has_tools(req: ProviderRequest) -> bool: | ||
| return req.func_tool is not None and not req.func_tool.empty() | ||
|
|
||
|
|
||
| def _get_effective_model_name( | ||
| provider: Provider, | ||
| requested_model: str | None = None, | ||
| ) -> str: | ||
| model_name = ( | ||
| requested_model | ||
| or provider.get_model() | ||
| or provider.provider_config.get("model", "") | ||
| ) | ||
| return str(model_name or "") | ||
|
|
||
|
|
||
| def _get_tool_call_support( | ||
| provider: Provider, | ||
| *, | ||
| requested_model: str | None = None, | ||
| ) -> bool | None: | ||
| modalities = provider.provider_config.get("modalities", None) | ||
| if isinstance(modalities, list) and "tool_use" not in modalities: | ||
| return False | ||
|
|
||
| model_name = _get_effective_model_name(provider, requested_model) | ||
| if not model_name: | ||
| return None | ||
|
|
||
| metadata = LLM_METADATAS.get(model_name) | ||
| if metadata is None: | ||
| return None | ||
| return bool(metadata["tool_call"]) | ||
|
|
||
|
|
||
| def _strip_tool_state_from_request(req: ProviderRequest) -> None: | ||
| req.func_tool = None | ||
| req.tool_calls_result = None | ||
|
|
||
| if not isinstance(req.contexts, list) or not req.contexts: | ||
| return | ||
|
|
||
| sanitized_contexts: list[dict] = [] | ||
| for message in req.contexts: | ||
| if not isinstance(message, dict): | ||
| continue | ||
| role = message.get("role") | ||
| if role == "tool": | ||
| continue | ||
|
|
||
| new_message = copy.deepcopy(message) | ||
| if role == "assistant": | ||
| new_message.pop("tool_calls", None) | ||
| new_message.pop("tool_call_id", None) | ||
| content = new_message.get("content") | ||
| if not content: | ||
| continue | ||
| if isinstance(content, str) and not content.strip(): | ||
| continue | ||
|
|
||
| sanitized_contexts.append(new_message) | ||
|
|
||
| req.contexts = sanitized_contexts | ||
|
|
||
|
|
||
| def _normalize_tool_capability_strategy(strategy: str) -> str: | ||
| if strategy in {"fallback_provider", "chat_only", "hard_fail"}: | ||
| return strategy | ||
|
Comment on lines
+848
to
+849
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve maintainability, consider defining the valid tool capability strategy names as constants or an Enum in a shared location, rather than hardcoding them as a set here. This would prevent inconsistencies, as these names are also used in For example, you could define an Enum: from enum import Enum
class ToolCapabilityStrategy(str, Enum):
FALLBACK_PROVIDER = "fallback_provider"
CHAT_ONLY = "chat_only"
HARD_FAIL = "hard_fail"
VALID_TOOL_CAPABILITY_STRATEGIES = {s.value for s in ToolCapabilityStrategy}And then use |
||
| logger.warning( | ||
| "Unsupported tool_capability_strategy `%s`, fallback to `fallback_provider`.", | ||
| strategy, | ||
| ) | ||
| return "fallback_provider" | ||
|
|
||
|
|
||
| def _sanitize_context_by_modalities( | ||
|
|
@@ -1112,6 +1173,93 @@ def _get_fallback_chat_providers( | |
| return fallbacks | ||
|
|
||
|
|
||
| def _get_tool_capable_fallback_provider( | ||
| provider: Provider, | ||
| plugin_context: Context, | ||
| provider_settings: dict, | ||
| ) -> tuple[Provider | None, bool]: | ||
| unknown_capability_provider: Provider | None = None | ||
|
|
||
| for fallback_provider in _get_fallback_chat_providers( | ||
| provider, plugin_context, provider_settings | ||
| ): | ||
| support = _get_tool_call_support(fallback_provider) | ||
| if support is True: | ||
| return fallback_provider, True | ||
| if support is None and unknown_capability_provider is None: | ||
| unknown_capability_provider = fallback_provider | ||
|
|
||
| return unknown_capability_provider, False | ||
|
|
||
|
|
||
| def _resolve_tool_capability_strategy( | ||
| *, | ||
| provider: Provider, | ||
| req: ProviderRequest, | ||
| plugin_context: Context, | ||
| config: MainAgentBuildConfig, | ||
| ) -> tuple[Provider, bool]: | ||
| if not _request_has_tools(req): | ||
| return provider, True | ||
|
|
||
| support = _get_tool_call_support(provider, requested_model=req.model) | ||
| if support is not False: | ||
| return provider, True | ||
|
|
||
| strategy = _normalize_tool_capability_strategy(config.tool_capability_strategy) | ||
| model_name = _get_effective_model_name(provider, req.model) | ||
| provider_id = str(provider.provider_config.get("id", "")) | ||
|
|
||
| if strategy == "fallback_provider": | ||
| fallback_provider, confirmed = _get_tool_capable_fallback_provider( | ||
| provider, | ||
| plugin_context, | ||
| config.provider_settings, | ||
| ) | ||
| if fallback_provider is not None: | ||
| fallback_model = _get_effective_model_name(fallback_provider) | ||
| if confirmed: | ||
| logger.info( | ||
| "Model `%s` on provider `%s` does not support tool calls, " | ||
| "switching to fallback provider `%s` (%s).", | ||
| model_name, | ||
| provider_id, | ||
| fallback_provider.provider_config.get("id", ""), | ||
| fallback_model, | ||
| ) | ||
| else: | ||
| logger.info( | ||
| "Model `%s` on provider `%s` does not support tool calls, " | ||
| "trying fallback provider `%s` without explicit tool metadata.", | ||
| model_name, | ||
| provider_id, | ||
| fallback_provider.provider_config.get("id", ""), | ||
| ) | ||
| req.model = fallback_model or None | ||
| return fallback_provider, True | ||
|
|
||
| logger.warning( | ||
| "Model `%s` on provider `%s` does not support tool calls and no suitable " | ||
| "fallback provider was found; degrading to chat_only.", | ||
| model_name, | ||
| provider_id, | ||
| ) | ||
| strategy = "chat_only" | ||
|
|
||
| if strategy == "hard_fail": | ||
| raise UnsupportedToolCapabilityError( | ||
| f"Model `{model_name}` on provider `{provider_id}` does not support tool calls." | ||
| ) | ||
|
|
||
| logger.info( | ||
| "Model `%s` on provider `%s` does not support tool calls, degrading to chat_only.", | ||
| model_name, | ||
| provider_id, | ||
| ) | ||
| _strip_tool_state_from_request(req) | ||
| return provider, False | ||
|
|
||
|
|
||
| async def build_main_agent( | ||
| *, | ||
| event: AstrMessageEvent, | ||
|
|
@@ -1293,9 +1441,7 @@ async def build_main_agent( | |
| if not req.session_id: | ||
| req.session_id = event.unified_msg_origin | ||
|
|
||
| _modalities_fix(provider, req) | ||
| _plugin_tool_fix(event, req) | ||
| _sanitize_context_by_modalities(config, provider, req) | ||
|
|
||
| if config.llm_safety_mode: | ||
| _apply_llm_safety_mode(config, req) | ||
|
|
@@ -1319,6 +1465,16 @@ async def build_main_agent( | |
| req.func_tool = ToolSet() | ||
| req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) | ||
|
|
||
| provider, allow_follow_up = _resolve_tool_capability_strategy( | ||
| provider=provider, | ||
| req=req, | ||
| plugin_context=plugin_context, | ||
| config=config, | ||
| ) | ||
|
|
||
| _modalities_fix(provider, req) | ||
| _sanitize_context_by_modalities(config, provider, req) | ||
|
|
||
| if provider.provider_config.get("max_context_tokens", 0) <= 0: | ||
| model = provider.get_model() | ||
| if model_info := LLM_METADATAS.get(model): | ||
|
|
@@ -1370,4 +1526,5 @@ async def build_main_agent( | |
| provider_request=req, | ||
| provider=provider, | ||
| reset_coro=reset_coro if not apply_reset else None, | ||
| allow_follow_up=allow_follow_up, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Guard against missing
tool_callkey inLLM_METADATASentries to avoidKeyError.Using
metadata["tool_call"]will raise aKeyErrorif that key is missing for any model. Prefermetadata.get("tool_call")(with a default likeNoneorFalse) before casting tobool, so the behavior matches themetadata is Nonecase instead of breaking the flow.