Skip to content

Commit 160fffe

Browse files
author
LehaoLin
committed
fix(provider): clean orphaned tool messages after truncation
1 parent 22e24e5 commit 160fffe

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

astrbot/core/provider/provider.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,48 @@ async def pop_record(self, context: list) -> None:
172172
for idx in reversed(indexs_to_pop):
173173
context.pop(idx)
174174

175+
context[:] = self._fix_tool_call_pairs_in_dict_context(context)
176+
177+
@staticmethod
178+
def _fix_tool_call_pairs_in_dict_context(context: list[dict]) -> list[dict]:
179+
"""Remove orphaned tool call chains from dict-based message history."""
180+
if not context:
181+
return context
182+
183+
fixed_context: list[dict] = []
184+
pending_assistant: dict | None = None
185+
pending_tools: list[dict] = []
186+
187+
def flush_pending_if_valid() -> None:
188+
nonlocal pending_assistant, pending_tools
189+
if pending_assistant is not None and pending_tools:
190+
fixed_context.append(pending_assistant)
191+
fixed_context.extend(pending_tools)
192+
pending_assistant = None
193+
pending_tools = []
194+
195+
for message in context:
196+
role = message.get("role")
197+
if role == "tool":
198+
if pending_assistant is not None:
199+
pending_tools.append(message)
200+
continue
201+
202+
if (
203+
role == "assistant"
204+
and message.get("tool_calls") is not None
205+
and len(message.get("tool_calls")) > 0
206+
):
207+
flush_pending_if_valid()
208+
pending_assistant = message
209+
continue
210+
211+
flush_pending_if_valid()
212+
fixed_context.append(message)
213+
214+
flush_pending_if_valid()
215+
return fixed_context
216+
175217
def _ensure_message_to_dicts(
176218
self,
177219
messages: list[dict] | list[Message] | None,

tests/test_openai_source.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,87 @@ async def test_handle_api_error_model_not_vlm_after_fallback_raises():
165165
await provider.terminate()
166166

167167

168+
@pytest.mark.asyncio
169+
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
170+
provider = _make_provider()
171+
try:
172+
payloads = {
173+
"messages": [
174+
{"role": "system", "content": "system"},
175+
{"role": "user", "content": "Run tool"},
176+
{
177+
"role": "assistant",
178+
"content": "",
179+
"tool_calls": [
180+
{
181+
"id": "call_1",
182+
"type": "function",
183+
"function": {"name": "search", "arguments": "{}"},
184+
}
185+
],
186+
},
187+
{"role": "tool", "content": "Tool result", "tool_call_id": "call_1"},
188+
{"role": "assistant", "content": "Final answer"},
189+
]
190+
}
191+
context_query = payloads["messages"]
192+
193+
success, *_rest = await provider._handle_api_error(
194+
Exception("maximum context length exceeded"),
195+
payloads=payloads,
196+
context_query=context_query,
197+
func_tool=None,
198+
chosen_key="test-key",
199+
available_api_keys=["test-key"],
200+
retry_cnt=0,
201+
max_retries=10,
202+
)
203+
204+
assert success is False
205+
assert payloads["messages"] == [
206+
{"role": "system", "content": "system"},
207+
{"role": "assistant", "content": "Final answer"},
208+
]
209+
finally:
210+
await provider.terminate()
211+
212+
213+
@pytest.mark.asyncio
214+
async def test_handle_api_error_context_length_preserves_remaining_valid_messages():
215+
provider = _make_provider()
216+
try:
217+
payloads = {
218+
"messages": [
219+
{"role": "system", "content": "system"},
220+
{"role": "user", "content": "old question"},
221+
{"role": "assistant", "content": "old answer"},
222+
{"role": "user", "content": "new question"},
223+
{"role": "assistant", "content": "new answer"},
224+
]
225+
}
226+
context_query = payloads["messages"]
227+
228+
success, *_rest = await provider._handle_api_error(
229+
Exception("maximum context length exceeded"),
230+
payloads=payloads,
231+
context_query=context_query,
232+
func_tool=None,
233+
chosen_key="test-key",
234+
available_api_keys=["test-key"],
235+
retry_cnt=0,
236+
max_retries=10,
237+
)
238+
239+
assert success is False
240+
assert payloads["messages"] == [
241+
{"role": "system", "content": "system"},
242+
{"role": "user", "content": "new question"},
243+
{"role": "assistant", "content": "new answer"},
244+
]
245+
finally:
246+
await provider.terminate()
247+
248+
168249
@pytest.mark.asyncio
169250
async def test_handle_api_error_content_moderated_with_unserializable_body():
170251
provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})

0 commit comments

Comments
 (0)