Skip to content

Commit 3393d1e

Browse files
committed
fix(context): restore turn cap and serialize content parts for llm compression
1 parent 25b1344 commit 3393d1e

4 files changed

Lines changed: 64 additions & 2 deletions

File tree

astrbot/core/agent/context/compressor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING, Protocol, runtime_checkable
22

3-
from ..message import Message
3+
from ..message import ContentPart, Message
44

55
if TYPE_CHECKING:
66
from astrbot import logger
@@ -100,7 +100,15 @@ def _message_to_dict(msg: Message) -> dict:
100100
"""Convert a Message to a plain dict suitable for round splitting."""
101101
d = {"role": msg.role}
102102
if msg.content is not None:
103-
d["content"] = msg.content
103+
if isinstance(msg.content, list):
104+
d["content"] = [
105+
part.model_dump_for_context()
106+
if isinstance(part, ContentPart)
107+
else part
108+
for part in msg.content
109+
]
110+
else:
111+
d["content"] = msg.content
104112
if getattr(msg, "tool_calls", None):
105113
d["tool_calls"] = msg.tool_calls
106114
if getattr(msg, "tool_call_id", None):

astrbot/core/astr_main_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,7 @@ async def build_main_agent(
15341534
llm_compress_keep_recent=config.llm_compress_keep_recent,
15351535
llm_compress_provider=_get_compress_provider(config, plugin_context, event),
15361536
truncate_turns=config.dequeue_context_length,
1537+
enforce_max_turns=config.max_context_length,
15371538
tool_schema_mode=config.tool_schema_mode,
15381539
fallback_providers=fallback_providers,
15391540
tool_result_overflow_dir=(

tests/agent/test_context_manager.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,27 @@ async def test_llm_compressor_keeps_history_when_summary_is_empty(self):
113113
"LLM context compression returned an empty summary."
114114
)
115115

116+
@pytest.mark.asyncio
117+
async def test_llm_compressor_handles_textpart_content(self):
118+
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
119+
120+
provider = MockProvider()
121+
compressor = LLMSummaryCompressor(provider=provider, keep_recent=1) # type: ignore[arg-type]
122+
messages = [
123+
Message(role="user", content=[TextPart(text="Hello")]),
124+
Message(role="assistant", content=[TextPart(text="Hi there")]),
125+
Message(role="user", content=[TextPart(text="Summarize our work")]),
126+
Message(role="assistant", content=[TextPart(text="Sure")]),
127+
]
128+
129+
result = await compressor(messages)
130+
131+
assert len(result) == 4
132+
assert result[0].role == "user"
133+
assert isinstance(result[0].content, str)
134+
assert "previous history conversation summary" in result[0].content
135+
assert result[-1].content == [TextPart(text="Sure")]
136+
116137
# ==================== Empty and Edge Cases ====================
117138

118139
@pytest.mark.asyncio

tests/unit/test_astr_main_agent.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,38 @@ async def test_build_main_agent_basic(
985985
assert result is not None
986986
assert isinstance(result, module.MainAgentBuildResult)
987987

988+
@pytest.mark.asyncio
989+
async def test_build_main_agent_passes_max_context_length_to_runner(
990+
self, mock_event, mock_context, mock_provider
991+
):
992+
module = ama
993+
mock_context.get_provider_by_id.return_value = None
994+
mock_context.get_using_provider.return_value = mock_provider
995+
mock_context.get_config.return_value = {}
996+
997+
conv_mgr = mock_context.conversation_manager
998+
_setup_conversation_for_build(conv_mgr)
999+
1000+
with (
1001+
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
1002+
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
1003+
):
1004+
mock_runner = MagicMock()
1005+
mock_runner.reset = AsyncMock()
1006+
mock_runner_cls.return_value = mock_runner
1007+
1008+
result = await module.build_main_agent(
1009+
event=mock_event,
1010+
plugin_context=mock_context,
1011+
config=module.MainAgentBuildConfig(
1012+
tool_call_timeout=60,
1013+
max_context_length=7,
1014+
),
1015+
)
1016+
1017+
assert result is not None
1018+
assert mock_runner.reset.await_args.kwargs["enforce_max_turns"] == 7
1019+
9881020
@pytest.mark.asyncio
9891021
async def test_build_main_agent_no_provider(self, mock_event, mock_context):
9901022
"""Test building main agent when no provider is available."""

0 commit comments

Comments
 (0)