Skip to content

Commit 20ba952

Browse files
committed
Backend add token compression feat
1 parent 2d78626 commit 20ba952

8 files changed

Lines changed: 116 additions & 14 deletions

File tree

astrbot/core/agent/context/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ class ContextConfig:
2323
1. Enforce max turns truncation.
2424
2. Truncation by turns compression strategy.
2525
"""
26+
context_limit_type: str = "turn"
27+
"""Compression trigger mode: "turn" uses model context window × 0.82 rate; "token" uses an absolute token threshold."""
28+
compression_token_threshold: int = 4000
29+
"""When context_limit_type is "token", compression triggers when total tokens >= this threshold."""
2630
llm_compress_instruction: str | None = None
2731
"""Instruction prompt for LLM-based compression."""
2832
llm_compress_keep_recent: int = 0

astrbot/core/agent/context/manager.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ def __init__(
4141
truncate_turns=config.truncate_turns
4242
)
4343

44+
def _has_compressible_messages(self, messages: list[Message]) -> bool:
45+
"""Check if there are any compressible (user/assistant) messages beyond system prompts."""
46+
for msg in messages:
47+
if msg.role in ("user", "assistant"):
48+
return True
49+
return False
50+
4451
async def process(
4552
self, messages: list[Message], trusted_token_usage: int = 0
4653
) -> list[Message]:
@@ -56,7 +63,11 @@ async def process(
5663
result = messages
5764

5865
# 1. 基于轮次的截断 (Enforce max turns)
59-
if self.config.enforce_max_turns != -1:
66+
# Skip turn-based truncation in token mode to avoid conflicts with absolute token threshold
67+
if (
68+
self.config.context_limit_type != "token"
69+
and self.config.enforce_max_turns != -1
70+
):
6071
result = self.truncator.truncate_by_turns(
6172
result,
6273
keep_most_recent_turns=self.config.enforce_max_turns,
@@ -69,10 +80,17 @@ async def process(
6980
result, trusted_token_usage
7081
)
7182

72-
if self.compressor.should_compress(
73-
result, total_tokens, self.config.max_context_tokens
74-
):
75-
result = await self._run_compression(result, total_tokens)
83+
if self.config.context_limit_type == "token":
84+
if (
85+
self._has_compressible_messages(result)
86+
and total_tokens >= self.config.compression_token_threshold
87+
):
88+
result = await self._run_compression(result, total_tokens)
89+
else:
90+
if self.compressor.should_compress(
91+
result, total_tokens, self.config.max_context_tokens
92+
):
93+
result = await self._run_compression(result, total_tokens)
7694

7795
return result
7896
except Exception as e:
@@ -100,21 +118,36 @@ async def _run_compression(
100118
tokens_after_summary = self.token_counter.count_tokens(messages)
101119

102120
# calculate compress rate
103-
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
121+
if self.config.context_limit_type == "token":
122+
denominator = self.config.compression_token_threshold
123+
else:
124+
denominator = self.config.max_context_tokens
125+
compress_rate = (
126+
(tokens_after_summary / denominator) * 100 if denominator > 0 else 0
127+
)
104128
logger.info(
105129
f"Compress completed."
106130
f" {prev_tokens} -> {tokens_after_summary} tokens,"
107131
f" compression rate: {compress_rate:.2f}%.",
108132
)
109133

110134
# last check
111-
if self.compressor.should_compress(
112-
messages, tokens_after_summary, self.config.max_context_tokens
113-
):
114-
logger.info(
115-
"Context still exceeds max tokens after compression, applying halving truncation..."
116-
)
117-
# still need compress, truncate by half
118-
messages = self.truncator.truncate_by_halving(messages)
135+
if self.config.context_limit_type == "token":
136+
if (
137+
self._has_compressible_messages(messages)
138+
and tokens_after_summary >= self.config.compression_token_threshold
139+
):
140+
logger.info(
141+
"Context still exceeds compression threshold after compression, applying halving truncation..."
142+
)
143+
messages = self.truncator.truncate_by_halving(messages)
144+
else:
145+
if self.compressor.should_compress(
146+
messages, tokens_after_summary, self.config.max_context_tokens
147+
):
148+
logger.info(
149+
"Context still exceeds max tokens after compression, applying halving truncation..."
150+
)
151+
messages = self.truncator.truncate_by_halving(messages)
119152

120153
return messages

astrbot/core/agent/runners/tool_loop_agent_runner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ async def reset(
220220
llm_compress_provider: Provider | None = None,
221221
# truncate by turns compressor
222222
truncate_turns: int = 1,
223+
# token-threshold compression
224+
context_limit_type: str = "turn",
225+
compression_token_threshold: int = 4000,
223226
# customize
224227
custom_token_counter: TokenCounter | None = None,
225228
custom_compressor: ContextCompressor | None = None,
@@ -236,6 +239,8 @@ async def reset(
236239
self.llm_compress_keep_recent = llm_compress_keep_recent
237240
self.llm_compress_provider = llm_compress_provider
238241
self.truncate_turns = truncate_turns
242+
self.context_limit_type = context_limit_type
243+
self.compression_token_threshold = compression_token_threshold
239244
self.custom_token_counter = custom_token_counter
240245
self.custom_compressor = custom_compressor
241246
self.tool_result_overflow_dir = tool_result_overflow_dir
@@ -250,6 +255,8 @@ async def reset(
250255
# enforce max turns before compression
251256
enforce_max_turns=self.enforce_max_turns,
252257
truncate_turns=self.truncate_turns,
258+
context_limit_type=self.context_limit_type,
259+
compression_token_threshold=self.compression_token_threshold,
253260
llm_compress_instruction=self.llm_compress_instruction,
254261
llm_compress_keep_recent=self.llm_compress_keep_recent,
255262
llm_compress_provider=self.llm_compress_provider,

astrbot/core/astr_main_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ class MainAgentBuildConfig:
161161
"""The number of oldest turns to remove when context length limit is reached."""
162162
fallback_max_context_tokens: int = 128000
163163
"""Fallback max context tokens. When max_context_tokens is 0 and the model is not in LLM_METADATAS, use this value."""
164+
context_limit_type: str = "turn"
165+
"""Compression trigger mode: "turn" uses model context window × 0.82 rate; "token" uses an absolute token threshold."""
166+
compression_token_threshold: int = 4000
167+
"""When context_limit_type is "token", compression triggers when total tokens >= this threshold."""
164168
llm_safety_mode: bool = True
165169
"""This will inject healthy and safe system prompt into the main agent,
166170
to prevent LLM output harmful information"""
@@ -1473,6 +1477,8 @@ async def build_main_agent(
14731477
llm_compress_provider=_get_compress_provider(config, plugin_context),
14741478
truncate_turns=config.dequeue_context_length,
14751479
enforce_max_turns=config.max_context_length,
1480+
context_limit_type=config.context_limit_type,
1481+
compression_token_threshold=config.compression_token_threshold,
14761482
tool_schema_mode=config.tool_schema_mode,
14771483
fallback_providers=_get_fallback_chat_providers(
14781484
provider, plugin_context, config.provider_settings

astrbot/core/config/default.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@
132132
"llm_compress_provider_id": "",
133133
"max_context_length": -1,
134134
"dequeue_context_length": 1,
135+
"context_limit_type": "turn",
136+
"compression_token_threshold": 4000,
135137
"streaming_response": False,
136138
"show_tool_use_status": False,
137139
"show_tool_call_result": False,
@@ -3509,6 +3511,7 @@
35093511
"type": "int",
35103512
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
35113513
"condition": {
3514+
"provider_settings.context_limit_type": "turn",
35123515
"provider_settings.agent_runner_type": "local",
35133516
},
35143517
},
@@ -3566,6 +3569,25 @@
35663569
"provider_settings.agent_runner_type": "local",
35673570
},
35683571
},
3572+
"provider_settings.context_limit_type": {
3573+
"description": "上下文压缩触发模式",
3574+
"type": "string",
3575+
"options": ["turn", "token"],
3576+
"labels": ["按百分比(模型窗口 × 82%)", "按固定 Token 阈值"],
3577+
"hint": '"按百分比"为默认行为:当上下文 Token 数超过模型窗口的 82% 时触发压缩。"按固定 Token 阈值"允许您设置一个绝对的 Token 数作为触发阈值,适用于需要更早触发压缩的场景。',
3578+
"condition": {
3579+
"provider_settings.agent_runner_type": "local",
3580+
},
3581+
},
3582+
"provider_settings.compression_token_threshold": {
3583+
"description": "Token 触发阈值",
3584+
"type": "int",
3585+
"hint": '当"上下文压缩触发模式"设为"按固定 Token 阈值"时生效。当前上下文 Token 数达到此值时触发压缩。',
3586+
"condition": {
3587+
"provider_settings.context_limit_type": "token",
3588+
"provider_settings.agent_runner_type": "local",
3589+
},
3590+
},
35693591
},
35703592
"condition": {
35713593
"provider_settings.agent_runner_type": "local",

astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ async def initialize(self, ctx: PipelineContext) -> None:
111111
self.fallback_max_context_tokens: int = settings.get(
112112
"fallback_max_context_tokens", 128000
113113
)
114+
self.context_limit_type: str = settings.get("context_limit_type", "turn")
115+
self.compression_token_threshold: int = settings.get(
116+
"compression_token_threshold", 4000
117+
)
114118

115119
self.llm_safety_mode = settings.get("llm_safety_mode", True)
116120
self.safety_mode_strategy = settings.get(
@@ -141,6 +145,8 @@ async def initialize(self, ctx: PipelineContext) -> None:
141145
max_context_length=self.max_context_length,
142146
dequeue_context_length=self.dequeue_context_length,
143147
fallback_max_context_tokens=self.fallback_max_context_tokens,
148+
context_limit_type=self.context_limit_type,
149+
compression_token_threshold=self.compression_token_threshold,
144150
llm_safety_mode=self.llm_safety_mode,
145151
safety_mode_strategy=self.safety_mode_strategy,
146152
computer_use_runtime=self.computer_use_runtime,

dashboard/src/i18n/locales/en-US/features/config-metadata.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,18 @@
276276
"fallback_max_context_tokens": {
277277
"description": "Fallback context window size",
278278
"hint": "When max_context_tokens is 0 and the model is not in built-in metadata, use this value as the context window size. Default: 128000."
279+
},
280+
"context_limit_type": {
281+
"description": "Compression Trigger Mode",
282+
"labels": [
283+
"By Percentage (Model Window × 82%)",
284+
"By Fixed Token Threshold"
285+
],
286+
"hint": "\"By Percentage\" is the default behavior. \"By Fixed Token Threshold\" allows setting an absolute token count as the trigger point."
287+
},
288+
"compression_token_threshold": {
289+
"description": "Token Trigger Threshold",
290+
"hint": "Effective when trigger mode is set to \"By Fixed Token Threshold\". Compression triggers when the current context token count reaches this value."
279291
}
280292
}
281293
},

dashboard/src/i18n/locales/zh-CN/features/config-metadata.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,18 @@
278278
"fallback_max_context_tokens": {
279279
"description": "上下文窗口兜底值",
280280
"hint": "当 max_context_tokens 为 0 且模型不在内置元数据中时,使用此值作为上下文窗口大小。默认 128000。"
281+
},
282+
"context_limit_type": {
283+
"description": "上下文压缩触发模式",
284+
"labels": [
285+
"按百分比(模型窗口 × 82%)",
286+
"按固定 Token 阈值"
287+
],
288+
"hint": "\"按百分比\"为默认行为。\"按固定 Token 阈值\"允许设置绝对的 Token 数作为触发阈值。"
289+
},
290+
"compression_token_threshold": {
291+
"description": "Token 触发阈值",
292+
"hint": "当触发模式为\"按固定 Token 阈值\"时生效,当前上下文 Token 数达到此值时触发压缩。"
281293
}
282294
}
283295
},

0 commit comments

Comments
 (0)