async def pop_record(self, context: list) -> None:
    """Drop the oldest non-system records from ``context`` in place.

    Records are removed in atomic "units" so that a tool-call exchange is
    never split:

    * an ``assistant`` record with a non-empty ``tool_calls`` entry, plus
      every ``tool`` record that immediately follows it;
    * a run of consecutive ``tool`` records with no preceding assistant
      call (leading orphans), removed together;
    * any other non-system record on its own.

    Units are removed oldest-first until at least two records are gone.
    The first unit is always removed whatever its size; a later unit is
    only removed if the running total stays at or below three records.

    Args:
        context: Conversation history as a list of message dicts. Each
            dict is read through ``.get("role")`` / ``.get("tool_calls")``.
            The list is mutated in place; ``system`` records are never
            removed.

    Returns:
        None. The result is the mutation of ``context``.
    """
    # Aim for this many removed records per call ...
    target_records = 2
    # ... but tolerate this many so a unit is never cut in half.
    max_records = 3

    removed = 0
    while removed < target_records:
        # Oldest record that is not a system prompt, if any is left.
        start = next(
            (
                idx
                for idx, record in enumerate(context)
                if record.get("role") != "system"
            ),
            None,
        )
        if start is None:
            break

        head = context[start]
        role = head.get("role")
        stop = start + 1  # exclusive end of the unit
        # Both an assistant tool-call record and a leading orphan tool
        # record swallow the tool records that directly follow them.
        if role == "tool" or (role == "assistant" and head.get("tool_calls")):
            while stop < len(context) and context[stop].get("role") == "tool":
                stop += 1

        unit_size = stop - start
        # Once something has been removed, do not overshoot the budget;
        # the very first unit is exempt so every call makes progress.
        if removed and removed + unit_size > max_records:
            break

        del context[start:stop]
        removed += unit_size
@pytest.mark.asyncio
async def test_pop_record_assistant_with_multiple_tool_calls():
    """An assistant record with two tool calls is dropped with both tool results."""
    history = [
        {"role": "system", "content": "system"},
        {
            "role": "assistant",
            "tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
            "content": None,
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "result1"},
        {"role": "tool", "tool_call_id": "call_2", "content": "result2"},
        {"role": "user", "content": "keep me"},
    ]
    # Only the system prompt and the trailing user record should survive.
    expected = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "keep me"},
    ]

    provider = _make_provider()
    try:
        await provider.pop_record(history)
        assert history == expected
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_only_system_messages():
    """A history holding nothing but a system record is left untouched."""
    history = [{"role": "system", "content": "system"}]
    expected = [{"role": "system", "content": "system"}]

    provider = _make_provider()
    try:
        await provider.pop_record(history)
        assert history == expected
    finally:
        await provider.terminate()
provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -793,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter(): async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -804,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://server/share/quoted-image.png") == ( - "//server/share/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://server/share/quoted-image.png") + assert Path(resolved) == Path("//server/share/quoted-image.png") finally: await provider.terminate() @@ -977,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag image_path = tmp_path / "quoted-image.png" PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path) - localhost_uri = f"file://localhost{image_path.as_posix()}" + parsed_local_uri = urlparse(image_path.as_uri()) + localhost_uri = urlunparse( + ("file", "localhost", parsed_local_uri.path, "", "", "") + ) payloads, _ = await provider._prepare_chat_payload( prompt=None, contexts=[