From a2557921f6735f46f0c69edcc62dd04f56abd0c9 Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:41:41 +0800 Subject: [PATCH 1/5] fix: keep tool call message pairs when trimming context --- astrbot/core/provider/provider.py | 54 ++++++++++++++++++++++++------- tests/test_openai_source.py | 44 +++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index fab3ce6104..c65faa9362 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -162,19 +162,49 @@ async def text_chat_stream( raise NotImplementedError() async def pop_record(self, context: list) -> None: - """弹出 context 第一条非系统提示词对话记录""" - poped = 0 - indexs_to_pop = [] - for idx, record in enumerate(context): - if record["role"] == "system": - continue - indexs_to_pop.append(idx) - poped += 1 - if poped == 2: + """Pop earliest non-system records while preserving tool-call pairing.""" + + def _has_tool_calls(message: dict) -> bool: + return bool(message.get("tool_calls")) + + def _first_non_system_index() -> int | None: + for idx, record in enumerate(context): + if record.get("role") != "system": + return idx + return None + + def _pop_earliest_unit() -> int: + start_idx = _first_non_system_index() + if start_idx is None: + return 0 + + record = context[start_idx] + role = record.get("role") + end_idx = start_idx + + if role == "assistant" and _has_tool_calls(record): + # Keep assistant(tool_calls) and following tool messages atomic. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + elif role == "tool": + # Remove leading orphan tool messages together. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + + removed_count = end_idx - start_idx + 1 + del context[start_idx : end_idx + 1] + return removed_count + + removed = 0 + while removed < 2: + removed_now = _pop_earliest_unit() + if removed_now == 0: break - - for idx in reversed(indexs_to_pop): - context.pop(idx) + removed += removed_now def _ensure_message_to_dicts( self, diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 39bb6d3810..ef33cae40f 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -244,6 +244,50 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history(): await provider.terminate() +@pytest.mark.asyncio +async def test_pop_record_removes_assistant_tool_calls_with_following_tools_atomically(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": None}, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_removes_leading_orphan_tool_messages(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "tool", "tool_call_id": "call_1", "content": "orphan"}, + {"role": "user", "content": "old user"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_groq_payload_drops_reasoning_content_from_assistant_history(): provider = _make_groq_provider() From a3aed9b2a813d3f54ea0dea762d041f07dccd50a Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:58:20 +0800 Subject: [PATCH 2/5] test: add regression coverage for pop_record truncation --- astrbot/core/provider/provider.py | 29 ++++++++++++++- tests/test_openai_source.py | 62 +++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index c65faa9362..7031d52fc8 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -162,7 +162,7 @@ async def text_chat_stream( raise NotImplementedError() async def pop_record(self, context: list) -> None: - """Pop earliest non-system records while preserving tool-call pairing.""" + """弹出最早的非 system 记录,同时保持 tool_calls 与 tool 配对完整。""" def _has_tool_calls(message: dict) -> bool: return bool(message.get("tool_calls")) @@ -199,8 +199,35 @@ def _pop_earliest_unit() -> int: del context[start_idx : end_idx + 1] return removed_count + def _peek_earliest_unit_count() -> int: + start_idx = _first_non_system_index() + if start_idx is None: + return 0 + + record = context[start_idx] + role = record.get("role") + end_idx = start_idx + if role == "assistant" and _has_tool_calls(record): + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + elif role == "tool": + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + return end_idx - start_idx + 1 + removed = 0 while removed < 2: + next_unit_count = _peek_earliest_unit_count() + if next_unit_count == 0: + break + # Keep behavior close to the old "pop around 2 records" strategy, + # while still preserving tool-call atomicity. + if removed > 0 and removed + next_unit_count > 3: + break removed_now = _pop_earliest_unit() if removed_now == 0: break diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index ef33cae40f..4ab41e2852 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -288,6 +288,68 @@ async def test_pop_record_removes_leading_orphan_tool_messages(): await provider.terminate() +@pytest.mark.asyncio +async def test_pop_record_normal_messages_no_regression(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user1"}, + {"role": "assistant", "content": "assistant1"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_assistant_with_multiple_tool_calls(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + { + "role": "assistant", + "tool_calls": [{"id": "call_1"}, {"id": "call_2"}], + "content": None, + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_only_system_messages(): + provider = _make_provider() + try: + context = [{"role": "system", "content": "system"}] + + await provider.pop_record(context) + + assert context == [{"role": "system", "content": "system"}] + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_groq_payload_drops_reasoning_content_from_assistant_history(): provider = _make_groq_provider() From a844133e56a4f7acdb24a8a6632f87211e50c5ff Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 20:19:09 +0800 Subject: [PATCH 3/5] test: make file URI tests cross-platform --- tests/test_openai_source.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 4ab41e2852..55b5fb46e5 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,4 +1,6 @@ from types import SimpleNamespace +from pathlib import Path +from urllib.parse import urlparse, urlunparse import pytest from openai.types.chat.chat_completion import ChatCompletion @@ -888,9 +890,8 @@ async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp async def test_file_uri_to_path_preserves_windows_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -899,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter(): async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -910,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://server/share/quoted-image.png") == ( - "//server/share/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://server/share/quoted-image.png") + assert Path(resolved) == Path("//server/share/quoted-image.png") finally: await provider.terminate() @@ -1083,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag image_path = tmp_path / "quoted-image.png" PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path) - localhost_uri = f"file://localhost{image_path.as_posix()}" + parsed_local_uri = urlparse(image_path.as_uri()) + localhost_uri = urlunparse( + ("file", "localhost", parsed_local_uri.path, "", "", "") + ) payloads, _ = await provider._prepare_chat_payload( prompt=None, contexts=[ From cd613e674f5c0c7f37521059c6928725b2d71e79 Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 20:44:32 +0800 Subject: [PATCH 4/5] refactor: deduplicate unit range detection in pop_record --- astrbot/core/provider/provider.py | 49 +++++++++--- tests/test_openai_source.py | 128 +++++++++++++++++++++++++++--- 2 files changed, 155 insertions(+), 22 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index fab3ce6104..b5f291f8e1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -162,19 +162,44 @@ async def text_chat_stream( raise NotImplementedError() async def pop_record(self, context: list) -> None: - """弹出 context 第一条非系统提示词对话记录""" - poped = 0 - indexs_to_pop = [] - for idx, record in enumerate(context): - if record["role"] == "system": - continue - indexs_to_pop.append(idx) - poped += 1 - if poped == 2: + """弹出最早的非 system 记录,同时保持 tool_calls 与 tool 配对完整。""" + + def _has_tool_calls(message: dict) -> bool: + return bool(message.get("tool_calls")) + + def _next_unit_bounds() -> tuple[int, int] | None: + for idx, record in enumerate(context): + if record.get("role") != "system": + end_idx = idx + role = record.get("role") + if role == "assistant" and _has_tool_calls(record): + # Keep assistant(tool_calls) and following tool messages atomic. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + elif role == "tool": + # Remove leading orphan tool messages together. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + return idx, end_idx + return None + + removed = 0 + while removed < 2: + next_unit = _next_unit_bounds() + if next_unit is None: break - - for idx in reversed(indexs_to_pop): - context.pop(idx) + start_idx, end_idx = next_unit + next_unit_count = end_idx - start_idx + 1 + # Keep behavior close to the old "pop around 2 records" strategy, + # while still preserving tool-call atomicity. + if removed > 0 and removed + next_unit_count > 3: + break + del context[start_idx : end_idx + 1] + removed += next_unit_count def _ensure_message_to_dicts( self, diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 39bb6d3810..0f31dea3da 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,4 +1,6 @@ +from pathlib import Path from types import SimpleNamespace +from urllib.parse import urlparse, urlunparse import pytest from openai.types.chat.chat_completion import ChatCompletion @@ -244,6 +246,112 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history(): await provider.terminate() +@pytest.mark.asyncio +async def test_pop_record_removes_assistant_tool_calls_with_following_tools_atomically(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": None}, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_removes_leading_orphan_tool_messages(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "tool", "tool_call_id": "call_1", "content": "orphan"}, + {"role": "user", "content": "old user"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_normal_messages_no_regression(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user1"}, + {"role": "assistant", "content": "assistant1"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_assistant_with_multiple_tool_calls(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + { + "role": "assistant", + "tool_calls": [{"id": "call_1"}, {"id": "call_2"}], + "content": None, + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_only_system_messages(): + provider = _make_provider() + try: + context = [{"role": "system", "content": "system"}] + + await provider.pop_record(context) + + assert context == [{"role": "system", "content": "system"}] + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_groq_payload_drops_reasoning_content_from_assistant_history(): provider = _make_groq_provider() @@ -782,9 +890,8 @@ async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp async def test_file_uri_to_path_preserves_windows_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -793,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter(): async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -804,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://server/share/quoted-image.png") == ( - "//server/share/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://server/share/quoted-image.png") + assert Path(resolved) == Path("//server/share/quoted-image.png") finally: await provider.terminate() @@ -977,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag image_path = tmp_path / "quoted-image.png" PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path) - localhost_uri = f"file://localhost{image_path.as_posix()}" + parsed_local_uri = urlparse(image_path.as_uri()) + localhost_uri = urlunparse( + ("file", "localhost", parsed_local_uri.path, "", "", "") + ) payloads, _ = await provider._prepare_chat_payload( prompt=None, contexts=[ From a90f4803a6095511d1c00eb78b9c8fd096425a7a Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:25:10 +0800 Subject: [PATCH 5/5] refactor: clarify pop_record removal policy constants --- astrbot/core/provider/provider.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index b5f291f8e1..6ca7a343a9 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -187,8 +187,13 @@ def _next_unit_bounds() -> tuple[int, int] | None: return idx, end_idx return None + # Removal policy: try to remove around TARGET_RECORDS messages, + # but allow up to MAX_RECORDS to keep tool-call/message units atomic. + TARGET_RECORDS = 2 + MAX_RECORDS = 3 + removed = 0 - while removed < 2: + while removed < TARGET_RECORDS: next_unit = _next_unit_bounds() if next_unit is None: break @@ -196,7 +201,7 @@ def _next_unit_bounds() -> tuple[int, int] | None: next_unit_count = end_idx - start_idx + 1 # Keep behavior close to the old "pop around 2 records" strategy, # while still preserving tool-call atomicity. - if removed > 0 and removed + next_unit_count > 3: + if removed > 0 and removed + next_unit_count > MAX_RECORDS: break del context[start_idx : end_idx + 1] removed += next_unit_count