Skip to content

Commit 95d8057

Browse files
RC-CHNTsukumi233w31r4Soulter
authored
refactor(ltm): redesign long-term memory with context compaction (reopen of #8144) (#8226)
* refactor(ltm): redesign long-term memory with context compaction - Add raw_records / contexts / summaries data model per group - Add LLM summary compaction strategy alongside truncation - Add turn-based (_split_into_rounds) granularity - Add image caption integration into LTM history - Add tool_call / tool_result persistence into raw_records - Add active reply support driven by LTM state - Improve summary injection prefix with system note and delimiters - Add info-level logging for summary compaction lifecycle - Clarify default summary prompt with explicit preserve/drop rules - Add context_guard for history overflow protection in agent runner - Add internal agent history compaction in agent_sub_stages - Add comprehensive LTM unit tests and compaction test suites * fix(ltm): handle malformed JSON in tool args and clean up lock on session removal * fix(ltm): guard against duplicate system prompt note injection * fix(ltm): fall back to user message when internal marker parsing fails - Treat lines starting with <T:CALL>, <T:RES, or <BOT/ as regular user messages when their respective parsers return None, instead of silently dropping them. Defensive guard against malformed internal markers. * fix(ltm): release session lock during LLM summary generation * fix(ltm): trim raw_records in handle_message to prevent unbounded growth * perf(ltm): use len(s) instead of len(s.encode()) in trim loop Avoid allocating a new bytes object for every string when calculating buffer size in _trim_raw_records. Character count is sufficient for the approximate memory cap. * feat(ltm): make user segment truncation limits configurable * feat(ltm): pre-fill default LTM summary prompt in config and i18n * refactor(ltm): hardcode internal segment/trim constants * refactor(ltm): unify compaction strategy with main agent runner * feat(ltm): add @mention weight marker for group chat messages * test: fix test failures from LTM compaction unification * chore(dashboard): remove obsolete LTM compaction i18n metadata * chore: shrink codebase * feat(group-chat): implement group chat context management and related functionality --------- Co-authored-by: Tsukumi <112180165+Tsukumi233@users.noreply.github.com> Co-authored-by: zenfun <zenfun510@gmail.com> Co-authored-by: Soulter <905617992@qq.com>
1 parent 61b6813 commit 95d8057

16 files changed

Lines changed: 712 additions & 421 deletions

File tree

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
import asyncio
2+
import datetime
3+
import random
4+
import uuid
5+
from collections import defaultdict, deque
6+
7+
from astrbot import logger
8+
from astrbot.api import star
9+
from astrbot.api.event import AstrMessageEvent
10+
from astrbot.api.message_components import At, Image, Plain
11+
from astrbot.api.platform import MessageType
12+
from astrbot.api.provider import Provider, ProviderRequest
13+
from astrbot.core.agent.message import TextPart
14+
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
15+
16+
"""
17+
Group chat context awareness.
18+
"""
19+
20+
GROUP_HISTORY_HEADER = (
21+
"<system_reminder>"
22+
"You are in a group chat. "
23+
"Belows are group chat context after your last reply:\n"
24+
"--- BEGIN CONTEXT---\n"
25+
)
26+
GROUP_HISTORY_FOOTER = "\n--- END CONTEXT ---\n</system_reminder>"
27+
DEFAULT_GROUP_MESSAGE_MAX_CNT = 300
28+
29+
30+
class GroupChatContext:
31+
def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
32+
self.acm = acm
33+
self.context = context
34+
self._locks: dict[str, asyncio.Lock] = {}
35+
self.raw_records: dict[str, deque[str]] = defaultdict(deque)
36+
self._record_ids: dict[str, deque[str]] = defaultdict(deque)
37+
38+
def _get_lock(self, umo: str) -> asyncio.Lock:
39+
lock = self._locks.get(umo)
40+
if lock is None:
41+
lock = asyncio.Lock()
42+
self._locks[umo] = lock
43+
return lock
44+
45+
def cfg(self, event: AstrMessageEvent):
46+
cfg = self.context.get_config(umo=event.unified_msg_origin)
47+
group_context_cfg = cfg["provider_ltm_settings"]
48+
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
49+
image_caption_provider_id = group_context_cfg.get("image_caption_provider_id")
50+
image_caption = group_context_cfg["image_caption"] and bool(
51+
image_caption_provider_id
52+
)
53+
active_reply = group_context_cfg["active_reply"]
54+
enable_active_reply = active_reply.get("enable", False)
55+
ar_method = active_reply["method"]
56+
ar_possibility = active_reply["possibility_reply"]
57+
ar_prompt = active_reply.get("prompt", "")
58+
ar_whitelist = active_reply.get("whitelist", [])
59+
return {
60+
"group_message_max_cnt": _positive_int(
61+
group_context_cfg.get(
62+
"group_message_max_cnt",
63+
DEFAULT_GROUP_MESSAGE_MAX_CNT,
64+
),
65+
DEFAULT_GROUP_MESSAGE_MAX_CNT,
66+
),
67+
"image_caption": image_caption,
68+
"image_caption_prompt": image_caption_prompt,
69+
"image_caption_provider_id": image_caption_provider_id,
70+
"enable_active_reply": enable_active_reply,
71+
"ar_method": ar_method,
72+
"ar_possibility": ar_possibility,
73+
"ar_prompt": ar_prompt,
74+
"ar_whitelist": ar_whitelist,
75+
}
76+
77+
async def get_image_caption(
78+
self,
79+
image_url: str,
80+
image_caption_provider_id: str,
81+
image_caption_prompt: str,
82+
) -> str:
83+
if not image_caption_provider_id:
84+
provider = self.context.get_using_provider()
85+
else:
86+
provider = self.context.get_provider_by_id(image_caption_provider_id)
87+
if not provider:
88+
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
89+
if not isinstance(provider, Provider):
90+
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
91+
response = await provider.text_chat(
92+
prompt=image_caption_prompt,
93+
session_id=uuid.uuid4().hex,
94+
image_urls=[image_url],
95+
persist=False,
96+
)
97+
return response.completion_text
98+
99+
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
100+
cfg = self.cfg(event)
101+
if not cfg["enable_active_reply"]:
102+
return False
103+
if event.get_message_type() != MessageType.GROUP_MESSAGE:
104+
return False
105+
if event.is_at_or_wake_command:
106+
return False
107+
if cfg["ar_whitelist"] and (
108+
event.unified_msg_origin not in cfg["ar_whitelist"]
109+
and (
110+
event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"]
111+
)
112+
):
113+
return False
114+
match cfg["ar_method"]:
115+
case "possibility_reply":
116+
return random.random() < cfg["ar_possibility"]
117+
return False
118+
119+
async def remove_session(self, event: AstrMessageEvent) -> int:
120+
umo = event.unified_msg_origin
121+
lock = self._get_lock(umo)
122+
async with lock:
123+
cnt = len(self.raw_records.get(umo, deque()))
124+
self.raw_records.pop(umo, None)
125+
self._record_ids.pop(umo, None)
126+
self._locks.pop(umo, None)
127+
return cnt
128+
129+
async def handle_message(self, event: AstrMessageEvent) -> None:
130+
if event.get_message_type() != MessageType.GROUP_MESSAGE:
131+
return
132+
133+
umo = event.unified_msg_origin
134+
cfg = self.cfg(event)
135+
final_message = await self._format_message(event, cfg)
136+
137+
async with self._get_lock(umo):
138+
records = self.raw_records[umo]
139+
record_ids = self._record_ids[umo]
140+
record_id = uuid.uuid4().hex
141+
records.append(final_message)
142+
record_ids.append(record_id)
143+
_trim_left(records, cfg["group_message_max_cnt"], record_ids)
144+
event.set_extra("_group_context_record_id", record_id)
145+
event.set_extra("_group_context_raw_idx", len(records) - 1)
146+
147+
logger.debug(f"group_chat_context | {umo} | {final_message}")
148+
149+
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
150+
umo = event.unified_msg_origin
151+
record_id = event.get_extra("_group_context_record_id", None)
152+
prompt_idx = event.get_extra("_group_context_raw_idx", -1)
153+
if not isinstance(record_id, str) and (
154+
not isinstance(prompt_idx, int) or prompt_idx < 0
155+
):
156+
return
157+
158+
async with self._get_lock(umo):
159+
records = self.raw_records.get(umo)
160+
if not records:
161+
return
162+
163+
raw_list = list(records)
164+
id_list = list(self._record_ids.get(umo, deque()))
165+
if isinstance(record_id, str) and record_id in id_list:
166+
prompt_idx = id_list.index(record_id)
167+
168+
if prompt_idx >= len(raw_list):
169+
return
170+
171+
records_to_inject = raw_list[:prompt_idx]
172+
remaining = raw_list[prompt_idx + 1 :]
173+
remaining_ids = id_list[prompt_idx + 1 :] if id_list else []
174+
records.clear()
175+
records.extend(remaining)
176+
if id_list:
177+
record_ids = self._record_ids[umo]
178+
record_ids.clear()
179+
record_ids.extend(remaining_ids)
180+
181+
if records_to_inject:
182+
req.extra_user_content_parts.append(
183+
TextPart(text=_format_group_history_block(records_to_inject))
184+
)
185+
186+
async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
187+
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
188+
parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "]
189+
190+
for comp in event.get_messages():
191+
if isinstance(comp, Plain):
192+
parts.append(f" {comp.text}")
193+
elif isinstance(comp, Image):
194+
if cfg["image_caption"]:
195+
try:
196+
url = comp.url if comp.url else comp.file
197+
if not url:
198+
raise Exception("图片 URL 为空")
199+
caption = await self.get_image_caption(
200+
url,
201+
cfg["image_caption_provider_id"],
202+
cfg["image_caption_prompt"],
203+
)
204+
parts.append(f" [Image: {caption}]")
205+
except Exception as e:
206+
logger.error(f"获取图片描述失败: {e}")
207+
else:
208+
parts.append(" [Image]")
209+
elif isinstance(comp, At):
210+
is_at_self = str(comp.qq) in (
211+
event.get_self_id(),
212+
"all",
213+
)
214+
if is_at_self:
215+
parts.insert(1, "⚠️[DIRECTED AT YOU] ")
216+
parts.append(f" [At: {comp.name}]")
217+
218+
return "".join(parts)
219+
220+
221+
def _positive_int(value, fallback: int) -> int:
222+
try:
223+
parsed = int(value)
224+
except (TypeError, ValueError):
225+
return fallback
226+
return parsed if parsed > 0 else fallback
227+
228+
229+
def _trim_left(
230+
records: deque[str],
231+
max_records: int,
232+
record_ids: deque[str] | None = None,
233+
) -> None:
234+
while len(records) > max_records:
235+
records.popleft()
236+
if record_ids:
237+
record_ids.popleft()
238+
239+
240+
def _format_group_history_block(records: list[str]) -> str:
241+
return GROUP_HISTORY_HEADER + "\n".join(records) + GROUP_HISTORY_FOOTER

0 commit comments

Comments
 (0)