Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
a74d29f
refactor(ltm): redesign long-term memory with context compaction
RC-CHN May 18, 2026
b42c1d1
fix(ltm): handle malformed JSON in tool args and clean up lock on ses…
RC-CHN May 18, 2026
3826e13
fix(ltm): guard against duplicate system prompt note injection
RC-CHN May 18, 2026
bd2500c
fix(ltm): fall back to user message when internal marker parsing fails
RC-CHN May 18, 2026
e10d6d3
fix(ltm): release session lock during LLM summary generation
RC-CHN May 18, 2026
6aa43d6
fix(ltm): trim raw_records in handle_message to prevent unbounded growth
RC-CHN May 18, 2026
c1bd4ad
perf(ltm): use len(s) instead of len(s.encode()) in trim loop
RC-CHN May 18, 2026
8a501bd
feat(ltm): make user segment truncation limits configurable
RC-CHN May 18, 2026
031290f
feat(ltm): pre-fill default LTM summary prompt in config and i18n
RC-CHN May 19, 2026
6591fa9
Merge remote-tracking branch 'upstream/master' into refactor-ltm
w31r4 May 23, 2026
7e94aed
refactor(ltm): hardcode internal segment/trim constants
RC-CHN May 25, 2026
7a6e853
refactor(ltm): unify compaction strategy with main agent runner
RC-CHN May 25, 2026
ad9d4ed
feat(ltm): add @mention weight marker for group chat messages
RC-CHN May 25, 2026
3c87f95
test: fix test failures from LTM compaction unification
RC-CHN May 25, 2026
ffbc134
chore(dashboard): remove obsolete LTM compaction i18n metadata
RC-CHN May 25, 2026
194c6c1
chore: shrink codebase
Soulter May 30, 2026
f564de1
feat(group-chat): implement group chat context management and related…
Soulter May 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions astrbot/builtin_stars/astrbot/group_chat_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import asyncio
import datetime
import random
import uuid
from collections import defaultdict, deque

from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager

"""
Group chat context awareness.
"""

GROUP_HISTORY_HEADER = (
"<system_reminder>"
"You are in a group chat. "
"Belows are group chat context after your last reply:\n"
"--- BEGIN CONTEXT---\n"
)
GROUP_HISTORY_FOOTER = "\n--- END CONTEXT ---\n</system_reminder>"
DEFAULT_GROUP_MESSAGE_MAX_CNT = 300


class GroupChatContext:
def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
self.acm = acm
self.context = context
self._locks: dict[str, asyncio.Lock] = {}
self.raw_records: dict[str, deque[str]] = defaultdict(deque)
self._record_ids: dict[str, deque[str]] = defaultdict(deque)

def _get_lock(self, umo: str) -> asyncio.Lock:
lock = self._locks.get(umo)
if lock is None:
lock = asyncio.Lock()
self._locks[umo] = lock
return lock

def cfg(self, event: AstrMessageEvent):
cfg = self.context.get_config(umo=event.unified_msg_origin)
group_context_cfg = cfg["provider_ltm_settings"]
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
image_caption_provider_id = group_context_cfg.get("image_caption_provider_id")
image_caption = group_context_cfg["image_caption"] and bool(
image_caption_provider_id
)
active_reply = group_context_cfg["active_reply"]
enable_active_reply = active_reply.get("enable", False)
ar_method = active_reply["method"]
ar_possibility = active_reply["possibility_reply"]
ar_prompt = active_reply.get("prompt", "")
ar_whitelist = active_reply.get("whitelist", [])
return {
"group_message_max_cnt": _positive_int(
group_context_cfg.get(
"group_message_max_cnt",
DEFAULT_GROUP_MESSAGE_MAX_CNT,
),
DEFAULT_GROUP_MESSAGE_MAX_CNT,
),
"image_caption": image_caption,
"image_caption_prompt": image_caption_prompt,
"image_caption_provider_id": image_caption_provider_id,
"enable_active_reply": enable_active_reply,
"ar_method": ar_method,
"ar_possibility": ar_possibility,
"ar_prompt": ar_prompt,
"ar_whitelist": ar_whitelist,
}

async def get_image_caption(
self,
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
) -> str:
if not image_caption_provider_id:
provider = self.context.get_using_provider()
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text

async def need_active_reply(self, event: AstrMessageEvent) -> bool:
cfg = self.cfg(event)
if not cfg["enable_active_reply"]:
return False
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return False
if event.is_at_or_wake_command:
return False
if cfg["ar_whitelist"] and (
event.unified_msg_origin not in cfg["ar_whitelist"]
and (
event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"]
)
):
return False
match cfg["ar_method"]:
case "possibility_reply":
return random.random() < cfg["ar_possibility"]
return False

async def remove_session(self, event: AstrMessageEvent) -> int:
umo = event.unified_msg_origin
lock = self._get_lock(umo)
async with lock:
cnt = len(self.raw_records.get(umo, deque()))
self.raw_records.pop(umo, None)
self._record_ids.pop(umo, None)
self._locks.pop(umo, None)
return cnt

async def handle_message(self, event: AstrMessageEvent) -> None:
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return

umo = event.unified_msg_origin
cfg = self.cfg(event)
final_message = await self._format_message(event, cfg)

async with self._get_lock(umo):
records = self.raw_records[umo]
record_ids = self._record_ids[umo]
record_id = uuid.uuid4().hex
records.append(final_message)
record_ids.append(record_id)
_trim_left(records, cfg["group_message_max_cnt"], record_ids)
event.set_extra("_group_context_record_id", record_id)
event.set_extra("_group_context_raw_idx", len(records) - 1)

logger.debug(f"group_chat_context | {umo} | {final_message}")

async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
umo = event.unified_msg_origin
record_id = event.get_extra("_group_context_record_id", None)
prompt_idx = event.get_extra("_group_context_raw_idx", -1)
if not isinstance(record_id, str) and (
not isinstance(prompt_idx, int) or prompt_idx < 0
):
return

async with self._get_lock(umo):
records = self.raw_records.get(umo)
if not records:
return

raw_list = list(records)
id_list = list(self._record_ids.get(umo, deque()))
if isinstance(record_id, str) and record_id in id_list:
prompt_idx = id_list.index(record_id)

if prompt_idx >= len(raw_list):
return

records_to_inject = raw_list[:prompt_idx]
remaining = raw_list[prompt_idx + 1 :]
remaining_ids = id_list[prompt_idx + 1 :] if id_list else []
records.clear()
records.extend(remaining)
if id_list:
record_ids = self._record_ids[umo]
record_ids.clear()
record_ids.extend(remaining_ids)

if records_to_inject:
req.extra_user_content_parts.append(
TextPart(text=_format_group_history_block(records_to_inject))
)

async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "]

for comp in event.get_messages():
if isinstance(comp, Plain):
parts.append(f" {comp.text}")
elif isinstance(comp, Image):
if cfg["image_caption"]:
try:
url = comp.url if comp.url else comp.file
if not url:
raise Exception("图片 URL 为空")
caption = await self.get_image_caption(
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
)
parts.append(f" [Image: {caption}]")
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
else:
parts.append(" [Image]")
elif isinstance(comp, At):
is_at_self = str(comp.qq) in (
event.get_self_id(),
"all",
)
if is_at_self:
parts.insert(1, "⚠️[DIRECTED AT YOU] ")
parts.append(f" [At: {comp.name}]")

return "".join(parts)


def _positive_int(value, fallback: int) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
return fallback
return parsed if parsed > 0 else fallback


def _trim_left(
records: deque[str],
max_records: int,
record_ids: deque[str] | None = None,
) -> None:
while len(records) > max_records:
records.popleft()
if record_ids:
record_ids.popleft()


def _format_group_history_block(records: list[str]) -> str:
return GROUP_HISTORY_HEADER + "\n".join(records) + GROUP_HISTORY_FOOTER
Loading
Loading