Skip to content

Commit 0f309c8

Browse files
committed
perf: improve truncate algo
1 parent c589714 commit 0f309c8

1 file changed

Lines changed: 46 additions & 70 deletions

File tree

astrbot/core/agent/context/truncator.py

Lines changed: 46 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ def _has_tool_calls(self, message: Message) -> bool:
1212
and len(message.tool_calls) > 0
1313
)
1414

15-
def _split_system_and_rest(
16-
self, messages: list[Message]
15+
@staticmethod
16+
def _split_system_rest(
17+
messages: list[Message],
1718
) -> tuple[list[Message], list[Message]]:
1819
"""Split messages into system messages and the rest.
1920
@@ -25,66 +26,36 @@ def _split_system_and_rest(
2526
if msg.role != "system":
2627
first_non_system = i
2728
break
28-
2929
return messages[:first_non_system], messages[first_non_system:]
3030

31-
def _ensure_first_user_message(
32-
self,
31+
@staticmethod
32+
def _ensure_user_message(
3333
system_messages: list[Message],
34-
non_system_messages: list[Message],
34+
truncated: list[Message],
3535
original_messages: list[Message],
3636
) -> list[Message]:
3737
"""Ensure the result always contains the first user message right after
3838
system messages. This is required by many LLM APIs (e.g. Zhipu) that
3939
mandate a ``user`` message immediately following the ``system`` message.
40-
41-
If the truncated ``non_system_messages`` already starts with a ``user``
42-
message, the list is returned as-is (with ``fix_messages`` applied).
43-
Otherwise the first ``user`` message from the *original* full message
44-
list is located and prepended.
45-
46-
Args:
47-
system_messages: The system messages extracted earlier.
48-
non_system_messages: The truncated non-system messages.
49-
original_messages: The full, untruncated message list (used to
50-
locate the original first ``user`` message when it has been
51-
removed by truncation).
52-
53-
Returns:
54-
A well-formed message list: ``system + [first_user +] rest``.
5540
"""
56-
# Fast path: already starts with a user message – nothing to fix.
57-
if non_system_messages and non_system_messages[0].role == "user":
58-
return self.fix_messages(system_messages + non_system_messages)
41+
if truncated and truncated[0].role == "user":
42+
return system_messages + truncated
5943

6044
# Locate the first user message from the *original* list.
61-
first_user_msg: Message | None = None
62-
for msg in original_messages:
63-
if msg.role == "user":
64-
first_user_msg = msg
65-
break
45+
first_user = next((m for m in original_messages if m.role == "user"), None)
46+
if first_user is None:
47+
return system_messages + truncated
6648

67-
if first_user_msg is None:
68-
# Degenerate case: no user message exists at all.
69-
return self.fix_messages(system_messages + non_system_messages)
70-
71-
# Avoid duplicate: if the located message is already in the truncated
72-
# list (identity check), don't prepend again.
73-
if any(m is first_user_msg for m in non_system_messages):
74-
return self.fix_messages(system_messages + non_system_messages)
75-
76-
# Prepend the first user message so the sequence is valid.
77-
result = system_messages + [first_user_msg] + non_system_messages
78-
return self.fix_messages(result)
49+
return system_messages + [first_user] + truncated
7950

8051
def fix_messages(self, messages: list[Message]) -> list[Message]:
81-
"""修复消息列表,确保 tool call tool response 的配对关系有效。
52+
"""Fix the message list to ensure the validity of tool call and tool response pairing.
8253
83-
此方法确保:
84-
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
85-
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
54+
This method ensures that:
55+
1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`.
56+
2. Each `assistant` message containing `tool_calls` is followed by corresponding `
8657
87-
这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。
58+
This is a requirement of the OpenAI Chat Completions API specification (Gemini enforces this strictly).
8859
"""
8960
if not messages:
9061
return messages
@@ -103,24 +74,25 @@ def flush_pending_if_valid() -> None:
10374

10475
for msg in messages:
10576
if msg.role == "tool":
106-
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
77+
# Only record tool responses when there is a pending assistant(tool_calls)
10778
if pending_assistant is not None:
10879
pending_tools.append(msg)
109-
# else: 孤立的 tool 消息,直接忽略
80+
# Isolated tool messages without a preceding assistant(tool_calls) are ignored
11081
continue
11182

11283
if self._has_tool_calls(msg):
113-
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending
84+
# When encountering a new assistant(tool_calls), first process the old pending chain
11485
flush_pending_if_valid()
11586
pending_assistant = msg
11687
continue
11788

118-
# tool,且不含 tool_calls 的消息
119-
# 先结束任何 pending 链,再正常追加
89+
# Non-tool messages that do not contain tool_calls will break the pending chain.
90+
# Flush any pending chain first, then append the current message normally.
12091
flush_pending_if_valid()
12192
fixed_messages.append(msg)
12293

123-
# 结束时处理最后一个 pending 链
94+
# Flush the last pending chain at the end,
95+
# ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list.
12496
flush_pending_if_valid()
12597

12698
return fixed_messages
@@ -131,22 +103,23 @@ def truncate_by_turns(
131103
keep_most_recent_turns: int,
132104
drop_turns: int = 1,
133105
) -> list[Message]:
134-
"""截断上下文列表,确保不超过最大长度。
135-
一个 turn 包含一个 user 消息和一个 assistant 消息。
136-
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
106+
"""
107+
Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
108+
A turn consists of a user message and an assistant message.
109+
This method ensures that the truncated context list conforms to OpenAI's context format.
137110
138111
Args:
139-
messages: 上下文列表
140-
keep_most_recent_turns: 保留最近的对话轮数
141-
drop_turns: 一次性丢弃的对话轮数
112+
messages: The original list of messages in the context.
113+
keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation).
114+
drop_turns: The number of turns to drop from the beginning.
142115
143116
Returns:
144-
截断后的上下文列表
117+
The truncated list of messages.
145118
"""
146119
if keep_most_recent_turns == -1:
147120
return messages
148121

149-
system_messages, non_system_messages = self._split_system_and_rest(messages)
122+
system_messages, non_system_messages = self._split_system_rest(messages)
150123

151124
if len(non_system_messages) // 2 <= keep_most_recent_turns:
152125
return messages
@@ -157,70 +130,73 @@ def truncate_by_turns(
157130
else:
158131
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
159132

160-
# 找到第一个 role 为 user 的索引,确保上下文格式正确
133+
# Find the first user message
161134
index = next(
162135
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
163136
None,
164137
)
165138
if index is not None and index > 0:
166139
truncated_contexts = truncated_contexts[index:]
167140

168-
return self._ensure_first_user_message(
141+
result = self._ensure_user_message(
169142
system_messages, truncated_contexts, messages
170143
)
144+
return self.fix_messages(result)
171145

172146
def truncate_by_dropping_oldest_turns(
173147
self,
174148
messages: list[Message],
175149
drop_turns: int = 1,
176150
) -> list[Message]:
177-
"""丢弃最旧的 N 个对话轮次。"""
151+
"""Drop the oldest N turns, regardless of the number of turns to keep."""
178152
if drop_turns <= 0:
179153
return messages
180154

181-
system_messages, non_system_messages = self._split_system_and_rest(messages)
155+
system_messages, non_system_messages = self._split_system_rest(messages)
182156

183157
if len(non_system_messages) // 2 <= drop_turns:
184158
truncated_non_system = []
185159
else:
186160
truncated_non_system = non_system_messages[drop_turns * 2 :]
187161

162+
# Find the first user message
188163
index = next(
189164
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
190165
None,
191166
)
192167
if index is not None:
193168
truncated_non_system = truncated_non_system[index:]
194-
elif truncated_non_system:
195-
truncated_non_system = []
196169

197-
return self._ensure_first_user_message(
170+
result = self._ensure_user_message(
198171
system_messages, truncated_non_system, messages
199172
)
173+
return self.fix_messages(result)
200174

201175
def truncate_by_halving(
202176
self,
203177
messages: list[Message],
204178
) -> list[Message]:
205-
"""对半砍策略,删除 50% 的消息"""
179+
"""Halve the number of messages, keeping the most recent ones."""
206180
if len(messages) <= 2:
207181
return messages
208182

209-
system_messages, non_system_messages = self._split_system_and_rest(messages)
183+
system_messages, non_system_messages = self._split_system_rest(messages)
210184

211185
messages_to_delete = len(non_system_messages) // 2
212186
if messages_to_delete == 0:
213187
return messages
214188

215189
truncated_non_system = non_system_messages[messages_to_delete:]
216190

191+
# Find the first user message
217192
index = next(
218193
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
219194
None,
220195
)
221196
if index is not None:
222197
truncated_non_system = truncated_non_system[index:]
223198

224-
return self._ensure_first_user_message(
199+
result = self._ensure_user_message(
225200
system_messages, truncated_non_system, messages
226201
)
202+
return self.fix_messages(result)

0 commit comments

Comments
 (0)