|
| 1 | +"""Pair-aware conversation history truncation. |
| 2 | +
|
| 3 | +Replaces naive ``conversation[-N:]`` slicing with a walker that keeps |
| 4 | +``assistant.tool_calls`` and their matching ``role="tool"`` messages as an |
| 5 | +atomic block — never half a pair, never orphan tool messages. |
| 6 | +
|
| 7 | +Why: OpenAI Responses API and Chat Completions both reject input where a |
| 8 | +``function_call_output`` / ``role="tool"`` message has no matching |
| 9 | +``function_call`` / ``assistant.tool_calls`` earlier in the input. Naive |
| 10 | +``[-N:]`` slicing can leave such orphans at the head when the cut lands |
| 11 | +between an assistant message and its tool results. This is the failure mode |
| 12 | +reported in issue #446. |
| 13 | +
|
| 14 | +Orphan detection is by ``tool_call_id`` matching, not by adjacency — a |
| 15 | +tool message inserted between a valid pair and other messages (from |
| 16 | +malformed persistence or upstream truncation) is dropped, not folded |
| 17 | +into an adjacent block. This makes the helper robust against orphans |
| 18 | +at any position, not just at the slice head. |
| 19 | +
|
| 20 | +Input is expected to be in OpenAI chat-completion format (post-reorganization |
| 21 | +from DB ``role="tool_call"`` rows). |
| 22 | +""" |
| 23 | + |
| 24 | +from __future__ import annotations |
| 25 | + |
| 26 | +from typing import Any |
| 27 | + |
| 28 | + |
| 29 | +def _identify_orphans(messages: list[dict[str, Any]]) -> set[int]: |
| 30 | + """Return indices of ``role="tool"`` messages whose ``tool_call_id`` has |
| 31 | + no matching ``assistant.tool_calls`` earlier in the conversation. |
| 32 | +
|
| 33 | + OpenAI rejects the request the moment a ``function_call_output`` is |
| 34 | + sent without its matching ``function_call``, regardless of whether |
| 35 | + that tool message is at the head, middle, or end. So orphan detection |
| 36 | + is by ID matching, not by position. |
| 37 | + """ |
| 38 | + orphans: set[int] = set() |
| 39 | + for i, msg in enumerate(messages): |
| 40 | + if msg.get("role") != "tool": |
| 41 | + continue |
| 42 | + tcid = msg.get("tool_call_id") |
| 43 | + if not tcid: |
| 44 | + orphans.add(i) |
| 45 | + continue |
| 46 | + # Search backward for an assistant whose tool_calls contains this id. |
| 47 | + # Walks past intervening user / system / other-assistant messages. |
| 48 | + found = False |
| 49 | + j = i - 1 |
| 50 | + while j >= 0: |
| 51 | + m = messages[j] |
| 52 | + if m.get("role") == "assistant" and m.get("tool_calls"): |
| 53 | + ids = {tc.get("id") for tc in m["tool_calls"]} |
| 54 | + if tcid in ids: |
| 55 | + found = True |
| 56 | + break |
| 57 | + j -= 1 |
| 58 | + if not found: |
| 59 | + orphans.add(i) |
| 60 | + return orphans |
| 61 | + |
| 62 | + |
| 63 | +def truncate_by_message_count( |
| 64 | + messages: list[dict[str, Any]], |
| 65 | + max_messages: int, |
| 66 | +) -> list[dict[str, Any]]: |
| 67 | + """Keep at most ``max_messages`` recent messages, preserving tool-call pairs. |
| 68 | +
|
| 69 | + A "block" is either: |
| 70 | + - a single non-tool, non-tool-calling message (user / system / assistant text), or |
| 71 | + - an ``assistant`` with ``tool_calls`` plus every matching ``role="tool"`` |
| 72 | + message (identified by ``tool_call_id``, not adjacency). |
| 73 | +
|
| 74 | + Blocks are atomic: included whole or not at all. Orphan ``role="tool"`` |
| 75 | + messages — those whose ``tool_call_id`` has no matching assistant — are |
| 76 | + silently dropped regardless of budget. Sending them to OpenAI causes the |
| 77 | + #446 error. |
| 78 | +
|
| 79 | + Args: |
| 80 | + messages: Conversation list in OpenAI format. Empty list is fine. |
| 81 | + max_messages: Soft upper bound on the number of returned entries. |
| 82 | + Values ``<= 0`` return ``[]``. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + A new list (input is never mutated) of at most ``max_messages`` entries |
| 86 | + from the tail of ``messages``, with all tool-call pairs intact. |
| 87 | + """ |
| 88 | + if max_messages <= 0 or not messages: |
| 89 | + return [] |
| 90 | + |
| 91 | + orphans = _identify_orphans(messages) |
| 92 | + n = len(messages) |
| 93 | + consumed: set[int] = set(orphans) # orphans drop unconditionally |
| 94 | + blocks: list[set[int]] = [] # tail-to-head order |
| 95 | + |
| 96 | + for i in range(n - 1, -1, -1): |
| 97 | + if i in consumed: |
| 98 | + continue |
| 99 | + msg = messages[i] |
| 100 | + role = msg.get("role") |
| 101 | + |
| 102 | + if role == "tool": |
| 103 | + # Find this tool's owning assistant by matching tool_call_id |
| 104 | + tcid = msg.get("tool_call_id") |
| 105 | + asst_idx = -1 |
| 106 | + j = i - 1 |
| 107 | + while j >= 0: |
| 108 | + m = messages[j] |
| 109 | + if m.get("role") == "assistant" and m.get("tool_calls"): |
| 110 | + ids = {tc.get("id") for tc in m["tool_calls"]} |
| 111 | + if tcid in ids: |
| 112 | + asst_idx = j |
| 113 | + break |
| 114 | + j -= 1 |
| 115 | + if asst_idx < 0: |
| 116 | + # Defensive — orphan detection should have caught this |
| 117 | + consumed.add(i) |
| 118 | + continue |
| 119 | + # Block = assistant + ALL of its matching tool messages (siblings) |
| 120 | + asst_tc_ids = {tc.get("id") for tc in messages[asst_idx]["tool_calls"]} |
| 121 | + block = {asst_idx} |
| 122 | + for k in range(asst_idx + 1, n): |
| 123 | + if k in consumed: |
| 124 | + continue |
| 125 | + m = messages[k] |
| 126 | + if ( |
| 127 | + m.get("role") == "tool" |
| 128 | + and m.get("tool_call_id") in asst_tc_ids |
| 129 | + ): |
| 130 | + block.add(k) |
| 131 | + consumed |= block |
| 132 | + blocks.append(block) |
| 133 | + elif role == "assistant" and msg.get("tool_calls"): |
| 134 | + # Encountered the assistant before any of its tools (e.g. tools |
| 135 | + # were truncated upstream or are still in flight). Group with |
| 136 | + # whatever matching tools follow it. |
| 137 | + asst_tc_ids = {tc.get("id") for tc in msg["tool_calls"]} |
| 138 | + block = {i} |
| 139 | + for k in range(i + 1, n): |
| 140 | + if k in consumed: |
| 141 | + continue |
| 142 | + m = messages[k] |
| 143 | + if ( |
| 144 | + m.get("role") == "tool" |
| 145 | + and m.get("tool_call_id") in asst_tc_ids |
| 146 | + ): |
| 147 | + block.add(k) |
| 148 | + consumed |= block |
| 149 | + blocks.append(block) |
| 150 | + else: |
| 151 | + consumed.add(i) |
| 152 | + blocks.append({i}) |
| 153 | + |
| 154 | + # Walk blocks tail-to-head, taking until budget exhausted. |
| 155 | + keep: set[int] = set() |
| 156 | + budget = max_messages |
| 157 | + for block in blocks: |
| 158 | + size = len(block) |
| 159 | + if size <= budget: |
| 160 | + keep |= block |
| 161 | + budget -= size |
| 162 | + else: |
| 163 | + # Block doesn't fit — stop. Do NOT partial-include (would split pair). |
| 164 | + break |
| 165 | + |
| 166 | + return [messages[k] for k in sorted(keep)] |
0 commit comments