|
4 | 4 | class ContextTruncator: |
5 | 5 | """Context truncator.""" |
6 | 6 |
|
| 7 | + def _has_tool_calls(self, message: Message) -> bool: |
| 8 | + """Check if a message contains tool calls.""" |
| 9 | + return ( |
| 10 | + message.role == "assistant" |
| 11 | + and message.tool_calls is not None |
| 12 | + and len(message.tool_calls) > 0 |
| 13 | + ) |
| 14 | + |
7 | 15 | def fix_messages(self, messages: list[Message]) -> list[Message]: |
8 | | - fixed_messages = [] |
9 | | - for message in messages: |
10 | | - if message.role == "tool": |
11 | | - # tool block 前面必须要有 user 和 assistant block |
12 | | - if len(fixed_messages) < 2: |
13 | | - # 这种情况可能是上下文被截断导致的 |
14 | | - # 我们直接将之前的上下文都清空 |
15 | | - fixed_messages = [] |
| 16 | + """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 |
| 17 | +
|
| 18 | + 此方法确保: |
| 19 | + 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息 |
| 20 | + 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应 |
| 21 | +
|
| 22 | + 这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。 |
| 23 | + """ |
| 24 | + if not messages: |
| 25 | + return messages |
| 26 | + |
| 27 | + # First pass: identify which assistant(tool_calls) have valid tool responses |
| 28 | + # Build a set of indices for assistant messages that have valid tool responses |
| 29 | + valid_tool_call_indices: set[int] = set() |
| 30 | + i = 0 |
| 31 | + while i < len(messages): |
| 32 | + msg = messages[i] |
| 33 | + if self._has_tool_calls(msg): |
| 34 | + # Check if next message(s) are tool responses |
| 35 | + j = i + 1 |
| 36 | + has_tool_response = False |
| 37 | + while j < len(messages) and messages[j].role == "tool": |
| 38 | + has_tool_response = True |
| 39 | + j += 1 |
| 40 | + if has_tool_response: |
| 41 | + valid_tool_call_indices.add(i) |
| 42 | + i += 1 |
| 43 | + |
| 44 | + # Second pass: build fixed message list |
| 45 | + fixed_messages: list[Message] = [] |
| 46 | + in_valid_tool_chain = False # 是否处于有效的 tool call 链中 |
| 47 | + |
| 48 | + for i, msg in enumerate(messages): |
| 49 | + if msg.role == "tool": |
| 50 | + # tool 消息:只有在有效的 tool call 链中才保留 |
| 51 | + if in_valid_tool_chain: |
| 52 | + fixed_messages.append(msg) |
| 53 | + # else: 孤立的 tool 消息,跳过 |
| 54 | + elif self._has_tool_calls(msg): |
| 55 | + # assistant(tool_calls):只保留有效的(后面有 tool response 的) |
| 56 | + if i in valid_tool_call_indices: |
| 57 | + fixed_messages.append(msg) |
| 58 | + in_valid_tool_chain = True # 进入有效的 tool call 链 |
16 | 59 | else: |
17 | | - fixed_messages.append(message) |
| 60 | + in_valid_tool_chain = False # 孤立的 tool_calls,跳过并重置状态 |
18 | 61 | else: |
19 | | - fixed_messages.append(message) |
| 62 | + # system, user, 或不含 tool_calls 的 assistant |
| 63 | + fixed_messages.append(msg) |
| 64 | + in_valid_tool_chain = False # 退出 tool call 链 |
| 65 | + |
20 | 66 | return fixed_messages |
21 | 67 |
|
22 | 68 | def truncate_by_turns( |
|
0 commit comments