Skip to content

Commit 7db7f4a

Browse files
authored
feat(agent-runner): add tool_choice parameter to fix empty tool calls response in "skills-like" tool call mode (#7101)
fixes: #7049
1 parent 77419e0 commit 7db7f4a

File tree

5 files changed

+142
-31
lines changed

5 files changed

+142
-31
lines changed

astrbot/core/agent/runners/tool_loop_agent_runner.py

Lines changed: 89 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,32 @@ def _get_persona_custom_error_message(self) -> str | None:
100100
event = getattr(self.run_context.context, "event", None)
101101
return extract_persona_custom_error_message_from_event(event)
102102

103+
async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None:
104+
"""Finalize the current step as a plain assistant response with no tool calls."""
105+
self.final_llm_resp = llm_resp
106+
self._transition_state(AgentState.DONE)
107+
self.stats.end_time = time.time()
108+
109+
parts = []
110+
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
111+
parts.append(
112+
ThinkPart(
113+
think=llm_resp.reasoning_content,
114+
encrypted=llm_resp.reasoning_signature,
115+
)
116+
)
117+
if llm_resp.completion_text:
118+
parts.append(TextPart(text=llm_resp.completion_text))
119+
if len(parts) == 0:
120+
logger.warning("LLM returned empty assistant message with no tool calls.")
121+
self.run_context.messages.append(Message(role="assistant", content=parts))
122+
123+
try:
124+
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
125+
except Exception as e:
126+
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
127+
self._resolve_unconsumed_follow_ups()
128+
103129
@override
104130
async def reset(
105131
self,
@@ -463,34 +489,7 @@ async def step(self):
463489
return
464490

465491
if not llm_resp.tools_call_name:
466-
# 如果没有工具调用,转换到完成状态
467-
self.final_llm_resp = llm_resp
468-
self._transition_state(AgentState.DONE)
469-
self.stats.end_time = time.time()
470-
471-
# record the final assistant message
472-
parts = []
473-
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
474-
parts.append(
475-
ThinkPart(
476-
think=llm_resp.reasoning_content,
477-
encrypted=llm_resp.reasoning_signature,
478-
)
479-
)
480-
if llm_resp.completion_text:
481-
parts.append(TextPart(text=llm_resp.completion_text))
482-
if len(parts) == 0:
483-
logger.warning(
484-
"LLM returned empty assistant message with no tool calls."
485-
)
486-
self.run_context.messages.append(Message(role="assistant", content=parts))
487-
488-
# call the on_agent_done hook
489-
try:
490-
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
491-
except Exception as e:
492-
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
493-
self._resolve_unconsumed_follow_ups()
492+
await self._complete_with_assistant_response(llm_resp)
494493

495494
# 返回 LLM 结果
496495
if llm_resp.result_chain:
@@ -510,6 +509,24 @@ async def step(self):
510509
if llm_resp.tools_call_name:
511510
if self.tool_schema_mode == "skills_like":
512511
llm_resp, _ = await self._resolve_tool_exec(llm_resp)
512+
if not llm_resp.tools_call_name:
513+
logger.warning(
514+
"skills_like tool re-query returned no tool calls; fallback to assistant response."
515+
)
516+
if llm_resp.result_chain:
517+
yield AgentResponse(
518+
type="llm_result",
519+
data=AgentResponseData(chain=llm_resp.result_chain),
520+
)
521+
elif llm_resp.completion_text:
522+
yield AgentResponse(
523+
type="llm_result",
524+
data=AgentResponseData(
525+
chain=MessageChain().message(llm_resp.completion_text),
526+
),
527+
)
528+
await self._complete_with_assistant_response(llm_resp)
529+
return
513530

514531
tool_call_result_blocks = []
515532
cached_images = [] # Collect cached images for LLM visibility
@@ -873,7 +890,9 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
873890
)
874891

875892
def _build_tool_requery_context(
876-
self, tool_names: list[str]
893+
self,
894+
tool_names: list[str],
895+
extra_instruction: str | None = None,
877896
) -> list[dict[str, T.Any]]:
878897
"""Build contexts for re-querying LLM with param-only tool schemas."""
879898
contexts: list[dict[str, T.Any]] = []
@@ -888,13 +907,20 @@ def _build_tool_requery_context(
888907
+ ". Now call the tool(s) with required arguments using the tool schema, "
889908
"and follow the existing tool-use rules."
890909
)
910+
if extra_instruction:
911+
instruction = f"{instruction}\n{extra_instruction}"
891912
if contexts and contexts[0].get("role") == "system":
892913
content = contexts[0].get("content") or ""
893914
contexts[0]["content"] = f"{content}\n{instruction}"
894915
else:
895916
contexts.insert(0, {"role": "system", "content": instruction})
896917
return contexts
897918

919+
@staticmethod
920+
def _has_meaningful_assistant_reply(llm_resp: LLMResponse) -> bool:
921+
text = (llm_resp.completion_text or "").strip()
922+
return bool(text)
923+
898924
def _build_tool_subset(self, tool_set: ToolSet, tool_names: list[str]) -> ToolSet:
899925
"""Build a subset of tools from the given tool set based on tool names."""
900926
subset = ToolSet()
@@ -932,11 +958,45 @@ async def _resolve_tool_exec(
932958
model=self.req.model,
933959
session_id=self.req.session_id,
934960
extra_user_content_parts=self.req.extra_user_content_parts,
961+
tool_choice="required",
935962
abort_signal=self._abort_signal,
936963
)
937964
if requery_resp:
938965
llm_resp = requery_resp
939966

967+
# If the re-query still returns no tool calls, and also does not have a meaningful assistant reply,
968+
# we consider it as a failure of the LLM to follow the tool-use instruction,
969+
# and we will retry once with a stronger instruction that explicitly requires the LLM to either call the tool or give an explanation.
970+
if (
971+
not llm_resp.tools_call_name
972+
and not self._has_meaningful_assistant_reply(llm_resp)
973+
):
974+
logger.warning(
975+
"skills_like tool re-query returned no tool calls and no explanation; retrying with stronger instruction."
976+
)
977+
repair_contexts = self._build_tool_requery_context(
978+
tool_names,
979+
extra_instruction=(
980+
"This is the second-stage tool execution step. "
981+
"You must do exactly one of the following: "
982+
"1. Call one of the selected tools using the provided tool schema. "
983+
"2. If calling a tool is no longer possible or appropriate, reply to the user with a brief explanation of why. "
984+
"Do not return an empty response. "
985+
"Do not ignore the selected tools without explanation."
986+
),
987+
)
988+
repair_resp = await self.provider.text_chat(
989+
contexts=repair_contexts,
990+
func_tool=param_subset,
991+
model=self.req.model,
992+
session_id=self.req.session_id,
993+
extra_user_content_parts=self.req.extra_user_content_parts,
994+
tool_choice="required",
995+
abort_signal=self._abort_signal,
996+
)
997+
if repair_resp:
998+
llm_resp = repair_resp
999+
9401000
return llm_resp, subset
9411001

9421002
def done(self) -> bool:

astrbot/core/provider/provider.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import asyncio
33
import os
44
from collections.abc import AsyncGenerator
5-
from typing import TypeAlias, Union
5+
from typing import Literal, TypeAlias, Union
66

77
from astrbot.core.agent.message import ContentPart, Message
88
from astrbot.core.agent.tool import ToolSet
@@ -104,6 +104,7 @@ async def text_chat(
104104
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
105105
model: str | None = None,
106106
extra_user_content_parts: list[ContentPart] | None = None,
107+
tool_choice: Literal["auto", "required"] = "auto",
107108
**kwargs,
108109
) -> LLMResponse:
109110
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -113,6 +114,7 @@ async def text_chat(
113114
session_id: 会话 ID(此属性已经被废弃)
114115
image_urls: 图片 URL 列表
115116
tools: tool set
117+
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
116118
contexts: 上下文,和 prompt 二选一使用
117119
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
118120
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
@@ -135,6 +137,7 @@ async def text_chat_stream(
135137
system_prompt: str | None = None,
136138
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
137139
model: str | None = None,
140+
tool_choice: Literal["auto", "required"] = "auto",
138141
**kwargs,
139142
) -> AsyncGenerator[LLMResponse, None]:
140143
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -144,6 +147,7 @@ async def text_chat_stream(
144147
session_id: 会话 ID(此属性已经被废弃)
145148
image_urls: 图片 URL 列表
146149
tools: tool set
150+
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
147151
contexts: 上下文,和 prompt 二选一使用
148152
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
149153
kwargs: 其他参数

astrbot/core/provider/sources/anthropic_source.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import base64
22
import json
33
from collections.abc import AsyncGenerator
4+
from typing import Literal
45

56
import anthropic
67
import httpx
@@ -258,6 +259,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
258259
if tools:
259260
if tool_list := tools.get_func_desc_anthropic_style():
260261
payloads["tools"] = tool_list
262+
payloads["tool_choice"] = {
263+
"type": "any"
264+
if payloads.get("tool_choice") == "required"
265+
else "auto"
266+
}
261267

262268
extra_body = self.provider_config.get("custom_extra_body", {})
263269

@@ -334,6 +340,11 @@ async def _query_stream(
334340
if tools:
335341
if tool_list := tools.get_func_desc_anthropic_style():
336342
payloads["tools"] = tool_list
343+
payloads["tool_choice"] = {
344+
"type": "any"
345+
if payloads.get("tool_choice") == "required"
346+
else "auto"
347+
}
337348

338349
# 用于累积工具调用信息
339350
tool_use_buffer = {}
@@ -483,6 +494,7 @@ async def text_chat(
483494
tool_calls_result=None,
484495
model=None,
485496
extra_user_content_parts=None,
497+
tool_choice: Literal["auto", "required"] = "auto",
486498
**kwargs,
487499
) -> LLMResponse:
488500
if contexts is None:
@@ -516,6 +528,8 @@ async def text_chat(
516528
model = model or self.get_model()
517529

518530
payloads = {"messages": new_messages, "model": model}
531+
if func_tool and not func_tool.empty():
532+
payloads["tool_choice"] = tool_choice
519533

520534
# Anthropic has a different way of handling system prompts
521535
if system_prompt:
@@ -540,6 +554,7 @@ async def text_chat_stream(
540554
tool_calls_result=None,
541555
model=None,
542556
extra_user_content_parts=None,
557+
tool_choice: Literal["auto", "required"] = "auto",
543558
**kwargs,
544559
):
545560
if contexts is None:
@@ -572,6 +587,8 @@ async def text_chat_stream(
572587
model = model or self.get_model()
573588

574589
payloads = {"messages": new_messages, "model": model}
590+
if func_tool and not func_tool.empty():
591+
payloads["tool_choice"] = tool_choice
575592

576593
# Anthropic has a different way of handling system prompts
577594
if system_prompt:

astrbot/core/provider/sources/gemini_source.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import logging
55
import random
66
from collections.abc import AsyncGenerator
7-
from typing import cast
7+
from typing import Literal, cast
88

99
from google import genai
1010
from google.genai import types
@@ -131,6 +131,7 @@ async def _prepare_query_config(
131131
self,
132132
payloads: dict,
133133
tools: ToolSet | None = None,
134+
tool_choice: Literal["auto", "required"] = "auto",
134135
system_instruction: str | None = None,
135136
modalities: list[str] | None = None,
136137
temperature: float = 0.7,
@@ -207,6 +208,18 @@ async def _prepare_query_config(
207208
types.Tool(function_declarations=func_desc["function_declarations"]),
208209
]
209210

211+
tool_config = None
212+
if tools and tool_list:
213+
tool_config = types.ToolConfig(
214+
function_calling_config=types.FunctionCallingConfig(
215+
mode=(
216+
types.FunctionCallingConfigMode.ANY
217+
if tool_choice == "required"
218+
else types.FunctionCallingConfigMode.AUTO
219+
)
220+
)
221+
)
222+
210223
# oper thinking config
211224
thinking_config = None
212225
if model_name in [
@@ -272,6 +285,7 @@ async def _prepare_query_config(
272285
seed=payloads.get("seed"),
273286
response_modalities=modalities,
274287
tools=cast(types.ToolListUnion | None, tool_list),
288+
tool_config=tool_config,
275289
safety_settings=self.safety_settings if self.safety_settings else None,
276290
thinking_config=thinking_config,
277291
automatic_function_calling=types.AutomaticFunctionCallingConfig(
@@ -535,6 +549,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
535549
config = await self._prepare_query_config(
536550
payloads,
537551
tools,
552+
payloads.get("tool_choice", "auto"),
538553
system_instruction,
539554
modalities,
540555
temperature,
@@ -616,6 +631,7 @@ async def _query_stream(
616631
config = await self._prepare_query_config(
617632
payloads,
618633
tools,
634+
payloads.get("tool_choice", "auto"),
619635
system_instruction,
620636
)
621637
result = await self.client.models.generate_content_stream(
@@ -728,6 +744,7 @@ async def text_chat(
728744
tool_calls_result=None,
729745
model=None,
730746
extra_user_content_parts=None,
747+
tool_choice: Literal["auto", "required"] = "auto",
731748
**kwargs,
732749
) -> LLMResponse:
733750
if contexts is None:
@@ -758,6 +775,8 @@ async def text_chat(
758775
model = model or self.get_model()
759776

760777
payloads = {"messages": context_query, "model": model}
778+
if func_tool and not func_tool.empty():
779+
payloads["tool_choice"] = tool_choice
761780

762781
retry = 10
763782
keys = self.api_keys.copy()
@@ -783,6 +802,7 @@ async def text_chat_stream(
783802
tool_calls_result=None,
784803
model=None,
785804
extra_user_content_parts=None,
805+
tool_choice: Literal["auto", "required"] = "auto",
786806
**kwargs,
787807
) -> AsyncGenerator[LLMResponse, None]:
788808
if contexts is None:
@@ -813,6 +833,8 @@ async def text_chat_stream(
813833
model = model or self.get_model()
814834

815835
payloads = {"messages": context_query, "model": model}
836+
if func_tool and not func_tool.empty():
837+
payloads["tool_choice"] = tool_choice
816838

817839
retry = 10
818840
keys = self.api_keys.copy()

0 commit comments

Comments
 (0)