|
4 | 4 | class ContextTruncator: |
5 | 5 | """Context truncator.""" |
6 | 6 |
|
| 7 | + def _has_tool_calls(self, message: Message) -> bool: |
| 8 | + """Check if a message contains tool calls.""" |
| 9 | + return ( |
| 10 | + message.role == "assistant" |
| 11 | + and message.tool_calls is not None |
| 12 | + and len(message.tool_calls) > 0 |
| 13 | + ) |
| 14 | + |
7 | 15 | def fix_messages(self, messages: list[Message]) -> list[Message]: |
8 | | - fixed_messages = [] |
9 | | - for message in messages: |
10 | | - if message.role == "tool": |
11 | | - # tool block 前面必须要有 user 和 assistant block |
12 | | - if len(fixed_messages) < 2: |
13 | | - # 这种情况可能是上下文被截断导致的 |
14 | | - # 我们直接将之前的上下文都清空 |
15 | | - fixed_messages = [] |
16 | | - else: |
17 | | - fixed_messages.append(message) |
18 | | - else: |
19 | | - fixed_messages.append(message) |
| 16 | + """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 |
| 17 | +
|
| 18 | + 此方法确保: |
| 19 | + 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息 |
| 20 | + 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应 |
| 21 | +
|
| 22 | + 这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。 |
| 23 | + """ |
| 24 | + if not messages: |
| 25 | + return messages |
| 26 | + |
| 27 | + fixed_messages: list[Message] = [] |
| 28 | + pending_assistant: Message | None = None |
| 29 | + pending_tools: list[Message] = [] |
| 30 | + |
| 31 | + def flush_pending_if_valid() -> None: |
| 32 | + nonlocal pending_assistant, pending_tools |
| 33 | + if pending_assistant is not None and pending_tools: |
| 34 | + fixed_messages.append(pending_assistant) |
| 35 | + fixed_messages.extend(pending_tools) |
| 36 | + pending_assistant = None |
| 37 | + pending_tools = [] |
| 38 | + |
| 39 | + for msg in messages: |
| 40 | + if msg.role == "tool": |
| 41 | + # 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应 |
| 42 | + if pending_assistant is not None: |
| 43 | + pending_tools.append(msg) |
| 44 | + # else: 孤立的 tool 消息,直接忽略 |
| 45 | + continue |
| 46 | + |
| 47 | + if self._has_tool_calls(msg): |
| 48 | + # 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链 |
| 49 | + flush_pending_if_valid() |
| 50 | + pending_assistant = msg |
| 51 | + continue |
| 52 | + |
| 53 | + # 非 tool,且不含 tool_calls 的消息 |
| 54 | + # 先结束任何 pending 链,再正常追加 |
| 55 | + flush_pending_if_valid() |
| 56 | + fixed_messages.append(msg) |
| 57 | + |
| 58 | + # 结束时处理最后一个 pending 链 |
| 59 | + flush_pending_if_valid() |
| 60 | + |
20 | 61 | return fixed_messages |
21 | 62 |
|
22 | 63 | def truncate_by_turns( |
|
0 commit comments