Skip to content

Commit e286da7

Browse files
he-yufengRC-CHN
andauthored
fix: 截断器丢失唯一 user 消息导致智谱等 provider 返回 400 (AstrBotDevs#6581)
* fix: 截断器丢失唯一 user 消息导致 API 400 修复 AstrBotDevs#6196 当对话只有一条 user 消息(长 tool chain 场景:system → user → assistant → tool → assistant → tool → ...),三个截断方法都会把这条 user 消息丢掉, 导致智谱、Gemini 等要求 user 消息的 provider 返回 400。 改动: - 提取 `_split_system_rest()` 去掉三个方法里重复的 system/non-system 拆分 - 新增 `_ensure_user_message()`:截断后如果没有 user 了,从原始消息里补回 第一条 user,避免违反 API 格式要求 - 删掉 `truncate_by_dropping_oldest_turns` 里把没有 user 就清空全部消息的逻辑 - 5 个新测试覆盖单 user + 长 tool chain 场景,3 个旧测试更新断言 * style: format code --------- Co-authored-by: Yufeng He <40085740+universeplayer@users.noreply.github.com> Co-authored-by: RC-CHN <1051989940@qq.com>
1 parent 7008a46 commit e286da7

2 files changed

Lines changed: 112 additions & 39 deletions

File tree

astrbot/core/agent/context/truncator.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,37 @@ def _has_tool_calls(self, message: Message) -> bool:
1212
and len(message.tool_calls) > 0
1313
)
1414

15+
@staticmethod
16+
def _split_system_rest(
17+
messages: list[Message],
18+
) -> tuple[list[Message], list[Message]]:
19+
"""把 system 消息和后面的对话消息分开。"""
20+
first_non_system = 0
21+
for i, msg in enumerate(messages):
22+
if msg.role != "system":
23+
first_non_system = i
24+
break
25+
return messages[:first_non_system], messages[first_non_system:]
26+
27+
@staticmethod
28+
def _ensure_user_message(
29+
system_messages: list[Message],
30+
truncated: list[Message],
31+
original_messages: list[Message],
32+
) -> list[Message]:
33+
"""截断后如果没有 user 消息了,从原始列表里把第一条 user 补回来。
34+
很多 provider (智谱、Gemini 等) 要求 system 之后必须紧跟 user,否则直接 400。
35+
"""
36+
if any(m.role == "user" for m in truncated):
37+
return system_messages + truncated
38+
39+
# 从原始消息里找第一条 user
40+
first_user = next((m for m in original_messages if m.role == "user"), None)
41+
if first_user is None:
42+
return system_messages + truncated
43+
44+
return system_messages + [first_user] + truncated
45+
1546
def fix_messages(self, messages: list[Message]) -> list[Message]:
1647
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。
1748
@@ -81,14 +112,7 @@ def truncate_by_turns(
81112
if keep_most_recent_turns == -1:
82113
return messages
83114

84-
first_non_system = 0
85-
for i, msg in enumerate(messages):
86-
if msg.role != "system":
87-
first_non_system = i
88-
break
89-
90-
system_messages = messages[:first_non_system]
91-
non_system_messages = messages[first_non_system:]
115+
system_messages, non_system_messages = self._split_system_rest(messages)
92116

93117
if len(non_system_messages) // 2 <= keep_most_recent_turns:
94118
return messages
@@ -99,16 +123,17 @@ def truncate_by_turns(
99123
else:
100124
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
101125

102-
# 找到第一个 role 为 user 的索引,确保上下文格式正确
126+
# 对齐到第一条 user 消息
103127
index = next(
104128
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
105129
None,
106130
)
107131
if index is not None and index > 0:
108132
truncated_contexts = truncated_contexts[index:]
109133

110-
result = system_messages + truncated_contexts
111-
134+
result = self._ensure_user_message(
135+
system_messages, truncated_contexts, messages
136+
)
112137
return self.fix_messages(result)
113138

114139
def truncate_by_dropping_oldest_turns(
@@ -120,31 +145,24 @@ def truncate_by_dropping_oldest_turns(
120145
if drop_turns <= 0:
121146
return messages
122147

123-
first_non_system = 0
124-
for i, msg in enumerate(messages):
125-
if msg.role != "system":
126-
first_non_system = i
127-
break
128-
129-
system_messages = messages[:first_non_system]
130-
non_system_messages = messages[first_non_system:]
148+
system_messages, non_system_messages = self._split_system_rest(messages)
131149

132150
if len(non_system_messages) // 2 <= drop_turns:
133151
truncated_non_system = []
134152
else:
135153
truncated_non_system = non_system_messages[drop_turns * 2 :]
136154

155+
# 对齐到第一条 user
137156
index = next(
138157
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
139158
None,
140159
)
141160
if index is not None:
142161
truncated_non_system = truncated_non_system[index:]
143-
elif truncated_non_system:
144-
truncated_non_system = []
145-
146-
result = system_messages + truncated_non_system
147162

163+
result = self._ensure_user_message(
164+
system_messages, truncated_non_system, messages
165+
)
148166
return self.fix_messages(result)
149167

150168
def truncate_by_halving(
@@ -155,28 +173,23 @@ def truncate_by_halving(
155173
if len(messages) <= 2:
156174
return messages
157175

158-
first_non_system = 0
159-
for i, msg in enumerate(messages):
160-
if msg.role != "system":
161-
first_non_system = i
162-
break
163-
164-
system_messages = messages[:first_non_system]
165-
non_system_messages = messages[first_non_system:]
176+
system_messages, non_system_messages = self._split_system_rest(messages)
166177

167178
messages_to_delete = len(non_system_messages) // 2
168179
if messages_to_delete == 0:
169180
return messages
170181

171182
truncated_non_system = non_system_messages[messages_to_delete:]
172183

184+
# 对齐到第一条 user
173185
index = next(
174186
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
175187
None,
176188
)
177189
if index is not None:
178190
truncated_non_system = truncated_non_system[index:]
179191

180-
result = system_messages + truncated_non_system
181-
192+
result = self._ensure_user_message(
193+
system_messages, truncated_non_system, messages
194+
)
182195
return self.fix_messages(result)

tests/agent/test_truncator.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def test_truncate_by_turns_zero_keep(self):
104104
messages, keep_most_recent_turns=0, drop_turns=1
105105
)
106106

107-
# Should result in empty or minimal list
108-
assert len(result) == 0
107+
# 截断后至少保留一条 user 消息 (#6196)
108+
assert len(result) >= 1
109+
assert result[0].role == "user"
109110

110111
def test_truncate_by_turns_below_threshold(self):
111112
"""Test truncate_by_turns when messages are below threshold."""
@@ -201,8 +202,9 @@ def test_truncate_by_dropping_oldest_turns_drop_all(self):
201202
messages = self.create_messages(4)
202203
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
203204

204-
# Should drop all turns
205-
assert len(result) == 0
205+
# 即使 drop 掉所有 turn,也会把 user 消息补回来 (#6196)
206+
assert len(result) >= 1
207+
assert result[0].role == "user"
206208

207209
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
208210
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
@@ -211,8 +213,9 @@ def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
211213
messages = self.create_messages(4)
212214
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
213215

214-
# Should result in empty list
215-
assert len(result) == 0
216+
# 同理,user 消息会被保留 (#6196)
217+
assert len(result) >= 1
218+
assert result[0].role == "user"
216219

217220
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
218221
"""Test that result starts with user message after dropping."""
@@ -372,3 +375,60 @@ def test_all_system_messages(self):
372375
assert len(result) >= 0 # May keep system messages or clear all
373376
if len(result) > 0:
374377
assert all(msg.role == "system" for msg in result)
378+
379+
# ==================== #6196: 长 tool chain 只有一条 user 消息 ====================
380+
381+
def _build_tool_chain(self, tool_rounds: int = 20) -> list[Message]:
382+
"""构造 system -> user -> (assistant -> tool) * N 的长链,只有一条 user。"""
383+
msgs = [
384+
self.create_message("system", "You are a helpful assistant."),
385+
self.create_message("user", "帮我查一下天气"),
386+
]
387+
for i in range(tool_rounds):
388+
msgs.append(self.create_message("assistant", f"调用工具 {i}"))
389+
msgs.append(self.create_message("tool", f"工具结果 {i}"))
390+
return msgs
391+
392+
def test_drop_oldest_preserves_sole_user(self):
393+
"""#6196: drop 1 turn 不应丢掉唯一的 user 消息。"""
394+
truncator = ContextTruncator()
395+
msgs = self._build_tool_chain(20) # 1 system + 1 user + 40 asst/tool = 42
396+
result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=1)
397+
roles = [m.role for m in result]
398+
assert "user" in roles, "唯一的 user 消息被丢掉了"
399+
assert roles[0] == "system"
400+
401+
def test_halving_preserves_sole_user(self):
402+
"""#6196: 对半砍不应丢掉唯一的 user 消息。"""
403+
truncator = ContextTruncator()
404+
msgs = self._build_tool_chain(20)
405+
result = truncator.truncate_by_halving(msgs)
406+
roles = [m.role for m in result]
407+
assert "user" in roles, "唯一的 user 消息被丢掉了"
408+
409+
def test_truncate_by_turns_preserves_sole_user(self):
410+
"""#6196: keep_most_recent_turns 也不应丢掉唯一的 user 消息。"""
411+
truncator = ContextTruncator()
412+
msgs = self._build_tool_chain(20)
413+
result = truncator.truncate_by_turns(
414+
msgs, keep_most_recent_turns=3, drop_turns=1
415+
)
416+
roles = [m.role for m in result]
417+
assert "user" in roles, "唯一的 user 消息被丢掉了"
418+
419+
def test_drop_oldest_heavy_drops_still_has_user(self):
420+
"""#6196: 大量 drop 也不会丢 user。"""
421+
truncator = ContextTruncator()
422+
msgs = self._build_tool_chain(30)
423+
result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=10)
424+
roles = [m.role for m in result]
425+
assert "user" in roles
426+
427+
def test_normal_multi_user_not_affected(self):
428+
"""正常多 user 对话不受影响。"""
429+
truncator = ContextTruncator()
430+
msgs = self.create_messages(20, include_system=True)
431+
result_before = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=2)
432+
# 多 user 场景下截断后仍有 user
433+
roles = [m.role for m in result_before]
434+
assert "user" in roles

0 commit comments

Comments
 (0)