@@ -12,8 +12,9 @@ def _has_tool_calls(self, message: Message) -> bool:
1212 and len (message .tool_calls ) > 0
1313 )
1414
15- def _split_system_and_rest (
16- self , messages : list [Message ]
15+ @staticmethod
16+ def _split_system_rest (
17+ messages : list [Message ],
1718 ) -> tuple [list [Message ], list [Message ]]:
1819 """Split messages into system messages and the rest.
1920
@@ -25,66 +26,36 @@ def _split_system_and_rest(
2526 if msg .role != "system" :
2627 first_non_system = i
2728 break
28-
2929 return messages [:first_non_system ], messages [first_non_system :]
3030
31- def _ensure_first_user_message (
32- self ,
31+ @ staticmethod
32+ def _ensure_user_message (
3333 system_messages : list [Message ],
34- non_system_messages : list [Message ],
34+ truncated : list [Message ],
3535 original_messages : list [Message ],
3636 ) -> list [Message ]:
3737 """Ensure the result always contains the first user message right after
3838 system messages. This is required by many LLM APIs (e.g. Zhipu) that
3939 mandate a ``user`` message immediately following the ``system`` message.
40-
41- If the truncated ``non_system_messages`` already starts with a ``user``
42- message, the list is returned as-is (with ``fix_messages`` applied).
43- Otherwise the first ``user`` message from the *original* full message
44- list is located and prepended.
45-
46- Args:
47- system_messages: The system messages extracted earlier.
48- non_system_messages: The truncated non-system messages.
49- original_messages: The full, untruncated message list (used to
50- locate the original first ``user`` message when it has been
51- removed by truncation).
52-
53- Returns:
54- A well-formed message list: ``system + [first_user +] rest``.
5540 """
56- # Fast path: already starts with a user message – nothing to fix.
57- if non_system_messages and non_system_messages [0 ].role == "user" :
58- return self .fix_messages (system_messages + non_system_messages )
41+ if truncated and truncated [0 ].role == "user" :
42+ return system_messages + truncated
5943
6044 # Locate the first user message from the *original* list.
61- first_user_msg : Message | None = None
62- for msg in original_messages :
63- if msg .role == "user" :
64- first_user_msg = msg
65- break
45+ first_user = next ((m for m in original_messages if m .role == "user" ), None )
46+ if first_user is None :
47+ return system_messages + truncated
6648
67- if first_user_msg is None :
68- # Degenerate case: no user message exists at all.
69- return self .fix_messages (system_messages + non_system_messages )
70-
71- # Avoid duplicate: if the located message is already in the truncated
72- # list (identity check), don't prepend again.
73- if any (m is first_user_msg for m in non_system_messages ):
74- return self .fix_messages (system_messages + non_system_messages )
75-
76- # Prepend the first user message so the sequence is valid.
77- result = system_messages + [first_user_msg ] + non_system_messages
78- return self .fix_messages (result )
49+ return system_messages + [first_user ] + truncated
7950
8051 def fix_messages (self , messages : list [Message ]) -> list [Message ]:
81- """修复消息列表,确保 tool call 和 tool response 的配对关系有效。
52+ """Fix the message list to ensure the validity of tool call and tool response pairing.
8253
83- 此方法确保:
84- 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
85- 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
54+ This method ensures that:
55+ 1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`.
56+ 2. Each `assistant` message containing `tool_calls` is followed by its corresponding `tool` responses.
8657
87- 这是 OpenAI Chat Completions API 规范的要求( Gemini 对此执行严格检查)。
58+ This is a requirement of the OpenAI Chat Completions API specification ( Gemini enforces this strictly).
8859 """
8960 if not messages :
9061 return messages
@@ -103,24 +74,25 @@ def flush_pending_if_valid() -> None:
10374
10475 for msg in messages :
10576 if msg .role == "tool" :
106- # 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
77+ # Only record tool responses when there is a pending assistant(tool_calls)
10778 if pending_assistant is not None :
10879 pending_tools .append (msg )
109- # else: 孤立的 tool 消息,直接忽略
80+ # Isolated tool messages without a preceding assistant(tool_calls) are ignored
11081 continue
11182
11283 if self ._has_tool_calls (msg ):
113- # 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链
84+ # When encountering a new assistant(tool_calls), first process the old pending chain
11485 flush_pending_if_valid ()
11586 pending_assistant = msg
11687 continue
11788
118- # 非 tool,且不含 tool_calls 的消息
119- # 先结束任何 pending 链,再正常追加
89+ # Non- tool messages that do not contain tool_calls will break the pending chain.
90+ # Flush any pending chain first, then append the current message normally.
12091 flush_pending_if_valid ()
12192 fixed_messages .append (msg )
12293
123- # 结束时处理最后一个 pending 链
94+ # Flush the last pending chain at the end,
95+ # ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list.
12496 flush_pending_if_valid ()
12597
12698 return fixed_messages
@@ -131,22 +103,23 @@ def truncate_by_turns(
131103 keep_most_recent_turns : int ,
132104 drop_turns : int = 1 ,
133105 ) -> list [Message ]:
134- """截断上下文列表,确保不超过最大长度。
135- 一个 turn 包含一个 user 消息和一个 assistant 消息。
136- 这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
106+ """
107+ Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
108+ A turn consists of a user message and an assistant message.
109+ This method ensures that the truncated context list conforms to OpenAI's context format.
137110
138111 Args:
139- messages: 上下文列表
140- keep_most_recent_turns: 保留最近的对话轮数
141- drop_turns: 一次性丢弃的对话轮数
112+ messages: The original list of messages in the context.
113+ keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation).
114+ drop_turns: The number of turns to drop from the beginning.
142115
143116 Returns:
144- 截断后的上下文列表
117+ The truncated list of messages.
145118 """
146119 if keep_most_recent_turns == - 1 :
147120 return messages
148121
149- system_messages , non_system_messages = self ._split_system_and_rest (messages )
122+ system_messages , non_system_messages = self ._split_system_rest (messages )
150123
151124 if len (non_system_messages ) // 2 <= keep_most_recent_turns :
152125 return messages
@@ -157,70 +130,73 @@ def truncate_by_turns(
157130 else :
158131 truncated_contexts = non_system_messages [- num_to_keep * 2 :]
159132
160- # 找到第一个 role 为 user 的索引,确保上下文格式正确
133+ # Find the index of the first user message so the truncated context starts with a user message
161134 index = next (
162135 (i for i , item in enumerate (truncated_contexts ) if item .role == "user" ),
163136 None ,
164137 )
165138 if index is not None and index > 0 :
166139 truncated_contexts = truncated_contexts [index :]
167140
168- return self ._ensure_first_user_message (
141+ result = self ._ensure_user_message (
169142 system_messages , truncated_contexts , messages
170143 )
144+ return self .fix_messages (result )
171145
172146 def truncate_by_dropping_oldest_turns (
173147 self ,
174148 messages : list [Message ],
175149 drop_turns : int = 1 ,
176150 ) -> list [Message ]:
177- """丢弃最旧的 N 个对话轮次。 """
151+ """Drop the oldest N turns, regardless of the number of turns to keep. """
178152 if drop_turns <= 0 :
179153 return messages
180154
181- system_messages , non_system_messages = self ._split_system_and_rest (messages )
155+ system_messages , non_system_messages = self ._split_system_rest (messages )
182156
183157 if len (non_system_messages ) // 2 <= drop_turns :
184158 truncated_non_system = []
185159 else :
186160 truncated_non_system = non_system_messages [drop_turns * 2 :]
187161
162+ # Find the first user message
188163 index = next (
189164 (i for i , item in enumerate (truncated_non_system ) if item .role == "user" ),
190165 None ,
191166 )
192167 if index is not None :
193168 truncated_non_system = truncated_non_system [index :]
194- elif truncated_non_system :
195- truncated_non_system = []
196169
197- return self ._ensure_first_user_message (
170+ result = self ._ensure_user_message (
198171 system_messages , truncated_non_system , messages
199172 )
173+ return self .fix_messages (result )
200174
201175 def truncate_by_halving (
202176 self ,
203177 messages : list [Message ],
204178 ) -> list [Message ]:
205- """对半砍策略,删除 50% 的消息 """
179+ """Halve the number of messages, keeping the most recent ones. """
206180 if len (messages ) <= 2 :
207181 return messages
208182
209- system_messages , non_system_messages = self ._split_system_and_rest (messages )
183+ system_messages , non_system_messages = self ._split_system_rest (messages )
210184
211185 messages_to_delete = len (non_system_messages ) // 2
212186 if messages_to_delete == 0 :
213187 return messages
214188
215189 truncated_non_system = non_system_messages [messages_to_delete :]
216190
191+ # Find the first user message
217192 index = next (
218193 (i for i , item in enumerate (truncated_non_system ) if item .role == "user" ),
219194 None ,
220195 )
221196 if index is not None :
222197 truncated_non_system = truncated_non_system [index :]
223198
224- return self ._ensure_first_user_message (
199+ result = self ._ensure_user_message (
225200 system_messages , truncated_non_system , messages
226201 )
202+ return self .fix_messages (result )
0 commit comments