Skip to content

Commit 33d93ae

Browse files
committed
fix: ensure tool call/response pairing in context truncation
1 parent c24de24 commit 33d93ae

1 file changed

Lines changed: 56 additions & 10 deletions

File tree

astrbot/core/agent/context/truncator.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,65 @@
44
class ContextTruncator:
55
"""Context truncator."""
66

7+
def _has_tool_calls(self, message: Message) -> bool:
8+
"""Check if a message contains tool calls."""
9+
return (
10+
message.role == "assistant"
11+
and message.tool_calls is not None
12+
and len(message.tool_calls) > 0
13+
)
14+
715
def fix_messages(self, messages: list[Message]) -> list[Message]:
8-
fixed_messages = []
9-
for message in messages:
10-
if message.role == "tool":
11-
# tool block 前面必须要有 user 和 assistant block
12-
if len(fixed_messages) < 2:
13-
# 这种情况可能是上下文被截断导致的
14-
# 我们直接将之前的上下文都清空
15-
fixed_messages = []
16+
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。
17+
18+
此方法确保:
19+
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
20+
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
21+
22+
这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。
23+
"""
24+
if not messages:
25+
return messages
26+
27+
# First pass: identify which assistant(tool_calls) have valid tool responses
28+
# Build a set of indices for assistant messages that have valid tool responses
29+
valid_tool_call_indices: set[int] = set()
30+
i = 0
31+
while i < len(messages):
32+
msg = messages[i]
33+
if self._has_tool_calls(msg):
34+
# Check if next message(s) are tool responses
35+
j = i + 1
36+
has_tool_response = False
37+
while j < len(messages) and messages[j].role == "tool":
38+
has_tool_response = True
39+
j += 1
40+
if has_tool_response:
41+
valid_tool_call_indices.add(i)
42+
i += 1
43+
44+
# Second pass: build fixed message list
45+
fixed_messages: list[Message] = []
46+
in_valid_tool_chain = False # 是否处于有效的 tool call 链中
47+
48+
for i, msg in enumerate(messages):
49+
if msg.role == "tool":
50+
# tool 消息:只有在有效的 tool call 链中才保留
51+
if in_valid_tool_chain:
52+
fixed_messages.append(msg)
53+
# else: 孤立的 tool 消息,跳过
54+
elif self._has_tool_calls(msg):
55+
# assistant(tool_calls):只保留有效的(后面有 tool response 的)
56+
if i in valid_tool_call_indices:
57+
fixed_messages.append(msg)
58+
in_valid_tool_chain = True # 进入有效的 tool call 链
1659
else:
17-
fixed_messages.append(message)
60+
in_valid_tool_chain = False # 孤立的 tool_calls,跳过并重置状态
1861
else:
19-
fixed_messages.append(message)
62+
# system, user, 或不含 tool_calls 的 assistant
63+
fixed_messages.append(msg)
64+
in_valid_tool_chain = False # 退出 tool call 链
65+
2066
return fixed_messages
2167

2268
def truncate_by_turns(

0 commit comments

Comments
 (0)