diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index e08cdc5157..9a314879a5 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -1,4 +1,5 @@ import datetime +import inspect import random import uuid from collections import defaultdict @@ -10,6 +11,7 @@ from astrbot.api.platform import MessageType from astrbot.api.provider import LLMResponse, Provider, ProviderRequest from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.message.message_event_result import ResultContentType """ 聊天记忆增强 @@ -25,26 +27,50 @@ def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: def cfg(self, event: AstrMessageEvent): cfg = self.context.get_config(umo=event.unified_msg_origin) + ltm_cfg = cfg["provider_ltm_settings"] try: - max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"]) + max_cnt = int(ltm_cfg["group_message_max_cnt"]) except BaseException as e: logger.error(e) max_cnt = 300 + try: + flow_max_records = int(ltm_cfg.get("group_flow_max_records", 5000)) + except BaseException as e: + logger.error(e) + flow_max_records = 5000 + try: + flow_max_delta_messages = int( + ltm_cfg.get("group_flow_max_delta_messages", 200) + ) + except BaseException as e: + logger.error(e) + flow_max_delta_messages = 200 + try: + flow_max_message_chars = int( + ltm_cfg.get("group_flow_max_message_chars", 1000) + ) + except BaseException as e: + logger.error(e) + flow_max_message_chars = 1000 image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] - image_caption_provider_id = cfg["provider_ltm_settings"].get( - "image_caption_provider_id" - ) - image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool( - image_caption_provider_id - ) - active_reply = cfg["provider_ltm_settings"]["active_reply"] + image_caption_provider_id = ltm_cfg.get("image_caption_provider_id") + image_caption = ltm_cfg["image_caption"] and bool(image_caption_provider_id) + active_reply = ltm_cfg["active_reply"] enable_active_reply = active_reply.get("enable", False) ar_method = active_reply["method"] ar_possibility = active_reply["possibility_reply"] ar_prompt = active_reply.get("prompt", "") ar_whitelist = active_reply.get("whitelist", []) ret = { + "group_icl_enable": ltm_cfg.get("group_icl_enable", False), + "group_context_mode": ltm_cfg.get("group_context_mode", "sliding_window"), "max_cnt": max_cnt, + "flow_max_records": flow_max_records, + "flow_max_delta_messages": flow_max_delta_messages, + "flow_max_message_chars": flow_max_message_chars, + "flow_record_bot_messages": ltm_cfg.get( + "group_flow_record_bot_messages", False + ), "image_caption": image_caption, "image_caption_prompt": image_caption_prompt, "image_caption_provider_id": image_caption_provider_id, @@ -56,13 +82,57 @@ def cfg(self, event: AstrMessageEvent): } return ret + def _is_flow_mode(self, event: AstrMessageEvent, cfg: dict | None = None) -> bool: + cfg = cfg or self.cfg(event) + return ( + bool(cfg.get("group_icl_enable")) + and cfg.get("group_context_mode") == "flow" + and event.get_message_type() == MessageType.GROUP_MESSAGE + ) + + def _flow_session_id(self, event: AstrMessageEvent) -> str: + group_id = event.get_group_id() + if group_id: + return f"{event.get_platform_id()}:{MessageType.GROUP_MESSAGE.value}:{group_id}" + return event.unified_msg_origin + + def _append_sliding_message( + self, + event: AstrMessageEvent, + message: str, + max_cnt: int, + ) -> None: + logger.debug(f"ltm | {event.unified_msg_origin} | {message}") + self.session_chats[event.unified_msg_origin].append(message) + if len(self.session_chats[event.unified_msg_origin]) > max_cnt: + self.session_chats[event.unified_msg_origin].pop(0) + async def remove_session(self, event: AstrMessageEvent) -> int: cnt = 0 if event.unified_msg_origin in self.session_chats: cnt = len(self.session_chats[event.unified_msg_origin]) del self.session_chats[event.unified_msg_origin] + if self._is_flow_mode(event): + await self.reset_flow_cursor(event) return cnt + async def reset_flow_cursor(self, event: AstrMessageEvent) -> None: + curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + event.unified_msg_origin + ) + if not curr_cid: + return + flow_session_id = self._flow_session_id(event) + latest_id = await self.context.group_message_flow_manager.get_latest_record_id( + flow_session_id + ) + await self.context.group_message_flow_manager.set_cursor( + platform_id=event.get_platform_id(), + flow_session_id=flow_session_id, + conversation_id=curr_cid, + last_record_id=latest_id, + ) + async def get_image_caption( self, image_url: str, @@ -111,52 +181,217 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: return False + async def _render_group_message( + self, + event: AstrMessageEvent, + cfg: dict, + sender_name: str | None = None, + ) -> str: + """Render one group message in the legacy LTM style.""" + datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + display_name = sender_name or event.get_sender_name() or event.get_sender_id() + parts = [f"[{display_name}/{datetime_str}]: "] + + for comp in event.get_messages(): + if isinstance(comp, Plain): + parts.append(f" {comp.text}") + elif isinstance(comp, Image): + if cfg["image_caption"]: + try: + url = comp.url if comp.url else comp.file + if not url: + raise Exception("图片 URL 为空") + caption = await self.get_image_caption( + url, + cfg["image_caption_provider_id"], + cfg["image_caption_prompt"], + ) + parts.append(f" [Image: {caption}]") + except Exception as e: + logger.error(f"获取图片描述失败: {e}") + parts.append(" [Image]") + else: + parts.append(" [Image]") + elif isinstance(comp, At): + parts.append(f" [At: {comp.name or comp.qq}]") + else: + comp_type = getattr(comp, "type", comp.__class__.__name__) + parts.append(f" [{comp_type}]") + + return "".join(parts) + + async def _components_to_dict(self, components) -> list[dict]: + content = [] + for comp in components: + try: + content.append(await self._component_to_json_safe_dict(comp)) + except Exception as e: + logger.warning(f"Failed to serialize group flow message component: {e}") + return content + + async def _component_to_json_safe_dict(self, comp) -> dict: + if hasattr(comp, "to_dict"): + data = comp.to_dict() + if inspect.isawaitable(data): + data = await data + elif hasattr(comp, "toDict"): + data = comp.toDict() + else: + data = {"type": comp.__class__.__name__, "data": {}} + return await self._json_safe(data) + + async def _json_safe(self, value): + if hasattr(value, "to_dict"): + return await self._component_to_json_safe_dict(value) + if hasattr(value, "toDict"): + return await self._json_safe(value.toDict()) + if isinstance(value, dict): + return {k: await self._json_safe(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [await self._json_safe(item) for item in value] + return value + + async def _message_content_to_dict(self, event: AstrMessageEvent) -> list[dict]: + return await self._components_to_dict(event.get_messages()) + + def _truncate_flow_message_text(self, message: str, max_chars: int) -> str: + if max_chars <= 0: + return message + return message[:max_chars] + + async def _record_flow_message( + self, + event: AstrMessageEvent, + rendered_text: str, + role: str = "user", + content: list[dict] | None = None, + ) -> int | None: + cfg = self.cfg(event) + if not self._is_flow_mode(event, cfg): + return None + flow_session_id = self._flow_session_id(event) + record = await self.context.group_message_flow_manager.insert_record( + platform_id=event.get_platform_id(), + flow_session_id=flow_session_id, + group_id=event.get_group_id() or None, + sender_id=event.get_sender_id() if role == "user" else event.get_self_id(), + sender_name=event.get_sender_name() if role == "user" else "You", + role=role, + content=content + if content is not None + else await self._message_content_to_dict(event), + rendered_text=rendered_text, + ) + await self.context.group_message_flow_manager.prune_records( + flow_session_id, + cfg["flow_max_records"], + ) + return record.id + async def handle_message(self, event: AstrMessageEvent) -> None: """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: - datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + cfg = self.cfg(event) + final_message = await self._render_group_message(event, cfg) - parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] + if cfg["enable_active_reply"] or not self._is_flow_mode(event, cfg): + self._append_sliding_message(event, final_message, cfg["max_cnt"]) - cfg = self.cfg(event) + if self._is_flow_mode(event, cfg): + record_id = await self._record_flow_message(event, final_message) + if record_id: + event.set_extra("_group_message_flow_record_id", record_id) - for comp in event.get_messages(): - if isinstance(comp, Plain): - parts.append(f" {comp.text}") - elif isinstance(comp, Image): - if cfg["image_caption"]: - try: - url = comp.url if comp.url else comp.file - if not url: - raise Exception("图片 URL 为空") - caption = await self.get_image_caption( - url, - cfg["image_caption_provider_id"], - cfg["image_caption_prompt"], - ) - parts.append(f" [Image: {caption}]") - except Exception as e: - logger.error(f"获取图片描述失败: {e}") - else: - parts.append(" [Image]") - elif isinstance(comp, At): - parts.append(f" [At: {comp.name}]") + async def _inject_flow_delta( + self, + event: AstrMessageEvent, + req: ProviderRequest, + cfg: dict, + ) -> None: + if not req.conversation: + return + flow_session_id = self._flow_session_id(event) + cursor = await self.context.group_message_flow_manager.get_cursor( + flow_session_id, + req.conversation.cid, + ) + after_id = cursor.last_record_id if cursor else 0 + current_record_id = event.get_extra("_group_message_flow_record_id") + if isinstance(current_record_id, int) and current_record_id > 0: + before_id = current_record_id + next_cursor_id = current_record_id + else: + before_id = None + next_cursor_id = ( + await self.context.group_message_flow_manager.get_latest_record_id( + flow_session_id + ) + ) - final_message = "".join(parts) - logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") - self.session_chats[event.unified_msg_origin].append(final_message) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + records = await self.context.group_message_flow_manager.get_records_after( + flow_session_id=flow_session_id, + after_id=after_id, + before_id=before_id, + limit=cfg["flow_max_delta_messages"], + ) + if records: + chats_str = "\n---\n".join( + self._truncate_flow_message_text( + record.rendered_text, + cfg["flow_max_message_chars"], + ) + for record in records + ) + req.system_prompt += ( + "\n\n" + "You are now in a chatroom. New group messages since the last turn:\n" + f"{chats_str}\n" + "" + ) - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: - """当触发 LLM 请求前,调用此方法修改 req""" - if event.unified_msg_origin not in self.session_chats: + event.set_extra( + "_group_message_flow_pending_cursor", + { + "platform_id": event.get_platform_id(), + "flow_session_id": flow_session_id, + "conversation_id": req.conversation.cid, + "last_record_id": next_cursor_id, + }, + ) + + async def _commit_pending_flow_cursor( + self, + event: AstrMessageEvent, + llm_resp: LLMResponse, + ) -> None: + if not llm_resp or llm_resp.role == "err": + return + + pending = event.get_extra("_group_message_flow_pending_cursor") + if not isinstance(pending, dict): + return + + platform_id = str(pending.get("platform_id") or "") + flow_session_id = str(pending.get("flow_session_id") or "") + conversation_id = str(pending.get("conversation_id") or "") + last_record_id = int(pending.get("last_record_id") or 0) + if not platform_id or not flow_session_id or not conversation_id: return - chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) + await self.context.group_message_flow_manager.set_cursor( + platform_id=platform_id, + flow_session_id=flow_session_id, + conversation_id=conversation_id, + last_record_id=last_record_id, + ) + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: + """当触发 LLM 请求前,调用此方法修改 req""" cfg = self.cfg(event) if cfg["enable_active_reply"]: + if event.unified_msg_origin not in self.session_chats: + return + chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) prompt = req.prompt req.prompt = ( f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" @@ -165,7 +400,12 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> Non "You MUST use the SAME language as the chatroom is using." ) req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 + elif self._is_flow_mode(event, cfg): + await self._inject_flow_delta(event, req, cfg) else: + if event.unified_msg_origin not in self.session_chats: + return + chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) req.system_prompt += ( "You are now in a chatroom. The chat history is as follows: \n" ) @@ -174,6 +414,10 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> Non async def after_req_llm( self, event: AstrMessageEvent, llm_resp: LLMResponse ) -> None: + cfg = self.cfg(event) + if self._is_flow_mode(event, cfg) and not cfg["enable_active_reply"]: + await self._commit_pending_flow_cursor(event, llm_resp) + return if event.unified_msg_origin not in self.session_chats: return @@ -182,7 +426,30 @@ async def after_req_llm( logger.debug( f"Recorded AI response: {event.unified_msg_origin} | {final_message}" ) - self.session_chats[event.unified_msg_origin].append(final_message) - cfg = self.cfg(event) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + self._append_sliding_message(event, final_message, cfg["max_cnt"]) + + async def record_bot_message(self, event: AstrMessageEvent) -> None: + cfg = self.cfg(event) + if not self._is_flow_mode(event, cfg): + return + if not cfg["flow_record_bot_messages"]: + return + + result = event.get_result() + if not result or not result.chain: + return + if result.result_content_type in { + ResultContentType.LLM_RESULT, + ResultContentType.STREAMING_RESULT, + ResultContentType.STREAMING_FINISH, + }: + return + + datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + rendered_text = f"[You/{datetime_str}]: {result.get_plain_text(True)}" + await self._record_flow_message( + event, + rendered_text, + role="bot", + content=await self._components_to_dict(result.chain), + ) diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index c1500a5d1f..9c5e5ce503 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -238,5 +238,7 @@ async def after_message_sent(self, event: AstrMessageEvent) -> None: clean_session = event.get_extra("_clean_ltm_session", False) if clean_session: await self.ltm.remove_session(event) + else: + await self.ltm.record_bot_message(event) except Exception as e: logger.error(f"ltm: {e}") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ce79559bd6..0110f55b53 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -217,7 +217,12 @@ }, "provider_ltm_settings": { "group_icl_enable": False, + "group_context_mode": "sliding_window", "group_message_max_cnt": 300, + "group_flow_max_records": 5000, + "group_flow_max_delta_messages": 200, + "group_flow_max_message_chars": 1000, + "group_flow_record_bot_messages": False, "image_caption": False, "image_caption_provider_id": "", "active_reply": { @@ -2884,9 +2889,25 @@ "group_icl_enable": { "type": "bool", }, + "group_context_mode": { + "type": "string", + "options": ["sliding_window", "flow"], + }, "group_message_max_cnt": { "type": "int", }, + "group_flow_max_records": { + "type": "int", + }, + "group_flow_max_delta_messages": { + "type": "int", + }, + "group_flow_max_message_chars": { + "type": "int", + }, + "group_flow_record_bot_messages": { + "type": "bool", + }, "image_caption": { "type": "bool", }, @@ -4100,9 +4121,60 @@ "description": "启用群聊上下文感知", "type": "bool", }, + "provider_ltm_settings.group_context_mode": { + "description": "群聊上下文模式", + "type": "string", + "options": ["sliding_window", "flow"], + "labels": ["滑动窗口", "消息流"], + "hint": "sliding_window 保持旧的滑动窗口行为;flow 使用持久化群聊消息流和对话游标。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + }, + }, "provider_ltm_settings.group_message_max_cnt": { "description": "最大消息数量", "type": "int", + "hint": "仅用于 sliding_window 模式。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + "provider_ltm_settings.group_context_mode": "sliding_window", + }, + }, + "provider_ltm_settings.group_flow_max_records": { + "description": "群聊消息流保留数量", + "type": "int", + "hint": "仅用于 flow 模式。每个群聊消息流最多保留的历史消息数,0 表示不清理。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + "provider_ltm_settings.group_context_mode": "flow", + }, + }, + "provider_ltm_settings.group_flow_max_delta_messages": { + "description": "单次注入消息数量上限", + "type": "int", + "hint": "仅用于 flow 模式。每次只注入游标之后、当前触发消息之前的最近 N 条群聊消息;0 表示不限制。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + "provider_ltm_settings.group_context_mode": "flow", + }, + }, + "provider_ltm_settings.group_flow_max_message_chars": { + "description": "单条消息字符上限", + "type": "int", + "hint": "仅用于 flow 模式。每条注入的群聊消息最多保留前 N 个字符;0 表示不限制。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + "provider_ltm_settings.group_context_mode": "flow", + }, + }, + "provider_ltm_settings.group_flow_record_bot_messages": { + "description": "记录普通机器人消息", + "type": "bool", + "hint": "仅用于 flow 模式。LLM 本次回复始终不会写入群聊消息流;此项只影响命令或插件产生的普通机器人消息。", + "condition": { + "provider_ltm_settings.group_icl_enable": True, + "provider_ltm_settings.group_context_mode": "flow", + }, }, "provider_ltm_settings.image_caption": { "description": "自动理解图片", diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 725b170003..5151740181 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -23,6 +23,7 @@ from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron import CronJobManager from astrbot.core.db import BaseDatabase +from astrbot.core.group_message_flow_mgr import GroupMessageFlowManager from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -211,6 +212,9 @@ async def initialize(self) -> None: # 初始化平台消息历史管理器 self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) + # 初始化群聊消息流管理器 + self.group_message_flow_manager = GroupMessageFlowManager(self.db) + # 初始化知识库管理器 self.kb_manager = KnowledgeBaseManager(self.provider_manager) @@ -233,7 +237,8 @@ async def initialize(self) -> None: self.astrbot_config_mgr, self.kb_manager, self.cron_manager, - self.subagent_orchestrator, + subagent_orchestrator=self.subagent_orchestrator, + group_message_flow_manager=self.group_message_flow_manager, ) # 初始化插件管理器 diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 1800887fb0..e038569889 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -15,6 +15,8 @@ CommandConflict, ConversationV2, CronJob, + GroupMessageFlowCursor, + GroupMessageFlowRecord, Persona, PersonaFolder, PlatformMessageHistory, @@ -254,6 +256,69 @@ async def get_platform_message_history_by_id( """Get a platform message history record by its ID.""" ... + @abc.abstractmethod + async def insert_group_message_flow_record( + self, + platform_id: str, + flow_session_id: str, + content: list, + rendered_text: str, + group_id: str | None = None, + sender_id: str | None = None, + sender_name: str | None = None, + role: str = "user", + ) -> GroupMessageFlowRecord: + """Insert a persisted group message flow record.""" + ... + + @abc.abstractmethod + async def get_group_message_flow_records_after( + self, + flow_session_id: str, + after_id: int, + before_id: int | None = None, + limit: int = 0, + ) -> list[GroupMessageFlowRecord]: + """Get recent group message flow records after a cursor, ordered oldest first.""" + ... + + @abc.abstractmethod + async def get_latest_group_message_flow_record_id( + self, + flow_session_id: str, + ) -> int: + """Get the latest record ID for a group message flow.""" + ... + + @abc.abstractmethod + async def get_group_message_flow_cursor( + self, + flow_session_id: str, + conversation_id: str, + ) -> GroupMessageFlowCursor | None: + """Get a conversation cursor for a group message flow.""" + ... + + @abc.abstractmethod + async def upsert_group_message_flow_cursor( + self, + platform_id: str, + flow_session_id: str, + conversation_id: str, + last_record_id: int, + ) -> GroupMessageFlowCursor: + """Create or update a conversation cursor for a group message flow.""" + ... + + @abc.abstractmethod + async def prune_group_message_flow_records( + self, + flow_session_id: str, + max_records: int, + ) -> None: + """Keep at most max_records records for a group message flow.""" + ... + @abc.abstractmethod async def create_webchat_thread( self, diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 0d3b9822a3..5ff07e89b4 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -247,6 +247,50 @@ class PlatformMessageHistory(TimestampMixin, SQLModel, table=True): llm_checkpoint_id: str | None = Field(default=None, index=True) +class GroupMessageFlowRecord(TimestampMixin, SQLModel, table=True): + """Persisted group chat messages for long-context group flow.""" + + __tablename__: str = "group_message_flow_records" + + id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + platform_id: str = Field(nullable=False, index=True) + flow_session_id: str = Field(nullable=False, index=True) + group_id: str | None = Field(default=None, index=True) + sender_id: str | None = Field(default=None, index=True) + sender_name: str | None = Field(default=None) + role: str = Field(default="user", nullable=False, index=True) + content: list = Field(default_factory=list, sa_type=JSON, nullable=False) + rendered_text: str = Field(default="", sa_type=Text, nullable=False) + + +class GroupMessageFlowCursor(TimestampMixin, SQLModel, table=True): + """Per-conversation cursor into a group message flow.""" + + __tablename__: str = "group_message_flow_cursors" + + id: int | None = Field( + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, + ) + platform_id: str = Field(nullable=False, index=True) + flow_session_id: str = Field(nullable=False, index=True) + conversation_id: str = Field(nullable=False, index=True) + last_record_id: int = Field(default=0, nullable=False) + + __table_args__ = ( + UniqueConstraint( + "flow_session_id", + "conversation_id", + name="uix_group_message_flow_cursor", + ), + ) + + class WebChatThread(TimestampMixin, SQLModel, table=True): """A side thread created from a selected WebChat assistant response.""" diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index d79ac9d703..8ac3b0b4d1 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -17,6 +17,8 @@ CommandConflict, ConversationV2, CronJob, + GroupMessageFlowCursor, + GroupMessageFlowRecord, Persona, PersonaFolder, PlatformMessageHistory, @@ -627,6 +629,161 @@ async def get_platform_message_history_by_id( result = await session.execute(query) return result.scalar_one_or_none() + async def insert_group_message_flow_record( + self, + platform_id: str, + flow_session_id: str, + content: list, + rendered_text: str, + group_id: str | None = None, + sender_id: str | None = None, + sender_name: str | None = None, + role: str = "user", + ) -> GroupMessageFlowRecord: + """Insert a persisted group message flow record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + record = GroupMessageFlowRecord( + platform_id=platform_id, + flow_session_id=flow_session_id, + group_id=group_id, + sender_id=sender_id, + sender_name=sender_name, + role=role, + content=content, + rendered_text=rendered_text, + ) + session.add(record) + await session.flush() + await session.refresh(record) + return record + + async def get_group_message_flow_records_after( + self, + flow_session_id: str, + after_id: int, + before_id: int | None = None, + limit: int = 0, + ) -> list[GroupMessageFlowRecord]: + """Get recent group message flow records after a cursor, ordered oldest first.""" + async with self.get_db() as session: + session: AsyncSession + conditions = [ + col(GroupMessageFlowRecord.flow_session_id) == flow_session_id, + col(GroupMessageFlowRecord.id) > after_id, + ] + if before_id is not None: + conditions.append(col(GroupMessageFlowRecord.id) < before_id) + if limit and limit > 0: + query = ( + select(GroupMessageFlowRecord) + .where(*conditions) + .order_by(desc(GroupMessageFlowRecord.id)) + .limit(limit) + ) + result = await session.execute(query) + return list(reversed(result.scalars().all())) + query = ( + select(GroupMessageFlowRecord) + .where(*conditions) + .order_by(col(GroupMessageFlowRecord.id)) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def get_latest_group_message_flow_record_id( + self, + flow_session_id: str, + ) -> int: + """Get the latest record ID for a group message flow.""" + async with self.get_db() as session: + session: AsyncSession + query = select(func.max(GroupMessageFlowRecord.id)).where( + col(GroupMessageFlowRecord.flow_session_id) == flow_session_id + ) + result = await session.execute(query) + return int(result.scalar_one_or_none() or 0) + + async def get_group_message_flow_cursor( + self, + flow_session_id: str, + conversation_id: str, + ) -> GroupMessageFlowCursor | None: + """Get a conversation cursor for a group message flow.""" + async with self.get_db() as session: + session: AsyncSession + query = select(GroupMessageFlowCursor).where( + col(GroupMessageFlowCursor.flow_session_id) == flow_session_id, + col(GroupMessageFlowCursor.conversation_id) == conversation_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def upsert_group_message_flow_cursor( + self, + platform_id: str, + flow_session_id: str, + conversation_id: str, + last_record_id: int, + ) -> GroupMessageFlowCursor: + """Create or update a conversation cursor for a group message flow.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + select(GroupMessageFlowCursor).where( + col(GroupMessageFlowCursor.flow_session_id) == flow_session_id, + col(GroupMessageFlowCursor.conversation_id) == conversation_id, + ) + ) + cursor = result.scalar_one_or_none() + if cursor: + cursor.platform_id = platform_id + cursor.last_record_id = last_record_id + session.add(cursor) + else: + cursor = GroupMessageFlowCursor( + platform_id=platform_id, + flow_session_id=flow_session_id, + conversation_id=conversation_id, + last_record_id=last_record_id, + ) + session.add(cursor) + await session.flush() + await session.refresh(cursor) + return cursor + + async def prune_group_message_flow_records( + self, + flow_session_id: str, + max_records: int, + ) -> None: + """Keep at most max_records records for a group message flow.""" + if max_records <= 0: + return + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + cutoff_result = await session.execute( + select(GroupMessageFlowRecord.id) + .where( + col(GroupMessageFlowRecord.flow_session_id) == flow_session_id + ) + .order_by(desc(GroupMessageFlowRecord.id)) + .offset(max_records) + .limit(1) + ) + cutoff_id = cutoff_result.scalar_one_or_none() + if cutoff_id is None: + return + await session.execute( + delete(GroupMessageFlowRecord).where( + col(GroupMessageFlowRecord.flow_session_id) == flow_session_id, + col(GroupMessageFlowRecord.id) <= cutoff_id, + ) + ) + async def create_webchat_thread( self, creator: str, diff --git a/astrbot/core/group_message_flow_mgr.py b/astrbot/core/group_message_flow_mgr.py new file mode 100644 index 0000000000..ff92c5ec8a --- /dev/null +++ b/astrbot/core/group_message_flow_mgr.py @@ -0,0 +1,75 @@ +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import GroupMessageFlowCursor, GroupMessageFlowRecord + + +class GroupMessageFlowManager: + """Manage persisted group message flows and per-conversation cursors.""" + + def __init__(self, db: BaseDatabase) -> None: + self.db = db + + async def insert_record( + self, + platform_id: str, + flow_session_id: str, + content: list, + rendered_text: str, + group_id: str | None = None, + sender_id: str | None = None, + sender_name: str | None = None, + role: str = "user", + ) -> GroupMessageFlowRecord: + return await self.db.insert_group_message_flow_record( + platform_id=platform_id, + flow_session_id=flow_session_id, + group_id=group_id, + sender_id=sender_id, + sender_name=sender_name, + role=role, + content=content, + rendered_text=rendered_text, + ) + + async def get_records_after( + self, + flow_session_id: str, + after_id: int, + before_id: int | None = None, + limit: int = 0, + ) -> list[GroupMessageFlowRecord]: + return await self.db.get_group_message_flow_records_after( + flow_session_id=flow_session_id, + after_id=after_id, + before_id=before_id, + limit=limit, + ) + + async def get_latest_record_id(self, flow_session_id: str) -> int: + return await self.db.get_latest_group_message_flow_record_id(flow_session_id) + + async def get_cursor( + self, + flow_session_id: str, + conversation_id: str, + ) -> GroupMessageFlowCursor | None: + return await self.db.get_group_message_flow_cursor( + flow_session_id=flow_session_id, + conversation_id=conversation_id, + ) + + async def set_cursor( + self, + platform_id: str, + flow_session_id: str, + conversation_id: str, + last_record_id: int, + ) -> GroupMessageFlowCursor: + return await self.db.upsert_group_message_flow_cursor( + platform_id=platform_id, + flow_session_id=flow_session_id, + conversation_id=conversation_id, + last_record_id=last_record_id, + ) + + async def prune_records(self, flow_session_id: str, max_records: int) -> None: + await self.db.prune_group_message_flow_records(flow_session_id, max_records) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 593bad9365..abc194e081 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -15,6 +15,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase +from astrbot.core.group_message_flow_mgr import GroupMessageFlowManager from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain from astrbot.core.persona_mgr import PersonaManager @@ -80,6 +81,7 @@ def __init__( knowledge_base_manager: KnowledgeBaseManager, cron_manager: CronJobManager, subagent_orchestrator: SubAgentOrchestrator | None = None, + group_message_flow_manager: GroupMessageFlowManager | None = None, ) -> None: self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" @@ -95,6 +97,10 @@ def __init__( """会话管理器""" self.message_history_manager = message_history_manager """平台消息历史管理器""" + self.group_message_flow_manager = group_message_flow_manager or ( + GroupMessageFlowManager(db) + ) + """群聊消息流管理器""" self.persona_manager = persona_manager """人格角色设定管理器""" self.astrbot_config_mgr = astrbot_config_mgr diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 6363b71e31..b07be1ac64 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -985,8 +985,33 @@ "group_icl_enable": { "description": "Enable Group Chat Context Awareness" }, + "group_context_mode": { + "description": "Group Chat Context Mode", + "hint": "sliding_window keeps the existing sliding-window behavior; flow uses persisted group message flow and conversation cursors.", + "labels": [ + "Sliding Window", + "Message Flow" + ] + }, "group_message_max_cnt": { - "description": "Maximum Message Count" + "description": "Maximum Message Count", + "hint": "Only used by sliding_window mode." + }, + "group_flow_max_records": { + "description": "Group Message Flow Retention", + "hint": "Only used by flow mode. Maximum records retained per group message flow. 0 disables cleanup." + }, + "group_flow_max_delta_messages": { + "description": "Injected Message Count Limit", + "hint": "Only used by flow mode. Each request injects only the latest N group messages after the cursor and before the current trigger message. 0 disables this limit." + }, + "group_flow_max_message_chars": { + "description": "Per-message Character Limit", + "hint": "Only used by flow mode. Each injected group message keeps at most the first N characters. 0 disables this limit." + }, + "group_flow_record_bot_messages": { + "description": "Record General Bot Messages", + "hint": "Only used by flow mode. The current LLM reply is never written to group message flow; this only affects general bot messages from commands or plugins." }, "image_caption": { "description": "Auto-understand Images", diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 028bff8675..9eed186b36 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -986,8 +986,33 @@ "group_icl_enable": { "description": "Включить осведомленность о контексте группы" }, + "group_context_mode": { + "description": "Режим контекста группового чата", + "hint": "sliding_window сохраняет прежнее поведение скользящего окна; flow использует сохраненный поток сообщений группы и курсоры диалогов.", + "labels": [ + "Скользящее окно", + "Поток сообщений" + ] + }, "group_message_max_cnt": { - "description": "Максимальное количество сообщений" + "description": "Максимальное количество сообщений", + "hint": "Используется только в режиме sliding_window." + }, + "group_flow_max_records": { + "description": "Хранение потока сообщений группы", + "hint": "Используется только в режиме flow. Максимальное количество сообщений для каждого потока группы. 0 отключает очистку." + }, + "group_flow_max_delta_messages": { + "description": "Лимит количества добавляемых сообщений", + "hint": "Используется только в режиме flow. Каждый запрос добавляет только последние N сообщений группы после курсора и до текущего сообщения-триггера. 0 отключает этот лимит." + }, + "group_flow_max_message_chars": { + "description": "Лимит символов на сообщение", + "hint": "Используется только в режиме flow. Для каждого добавляемого сообщения группы сохраняются только первые N символов. 0 отключает этот лимит." + }, + "group_flow_record_bot_messages": { + "description": "Записывать обычные сообщения бота", + "hint": "Используется только в режиме flow. Текущий ответ LLM никогда не записывается в поток группы; настройка влияет только на обычные сообщения бота из команд или плагинов." }, "image_caption": { "description": "Автоматическое понимание изображений", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 70f4fa5c79..94f161b670 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -987,8 +987,33 @@ "group_icl_enable": { "description": "启用群聊上下文感知" }, + "group_context_mode": { + "description": "群聊上下文模式", + "hint": "sliding_window 保持旧的滑动窗口行为;flow 使用持久化群聊消息流和对话游标。", + "labels": [ + "滑动窗口", + "消息流" + ] + }, "group_message_max_cnt": { - "description": "最大消息数量" + "description": "最大消息数量", + "hint": "仅用于 sliding_window 模式。" + }, + "group_flow_max_records": { + "description": "群聊消息流保留数量", + "hint": "仅用于 flow 模式。每个群聊消息流最多保留的历史消息数,0 表示不清理。" + }, + "group_flow_max_delta_messages": { + "description": "单次注入消息数量上限", + "hint": "仅用于 flow 模式。每次只注入游标之后、当前触发消息之前的最近 N 条群聊消息;0 表示不限制。" + }, + "group_flow_max_message_chars": { + "description": "单条消息字符上限", + "hint": "仅用于 flow 模式。每条注入的群聊消息最多保留前 N 个字符;0 表示不限制。" + }, + "group_flow_record_bot_messages": { + "description": "记录普通机器人消息", + "hint": "仅用于 flow 模式。LLM 本次回复始终不会写入群聊消息流;此项只影响命令或插件产生的普通机器人消息。" }, "image_caption": { "description": "自动理解图片", diff --git a/tests/conftest.py b/tests/conftest.py index b9807c1ded..de2c3c211c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,12 @@ import sys from asyncio import Queue from pathlib import Path -from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest import pytest_asyncio # 使用 tests/fixtures/helpers.py 中的共享工具函数,避免重复定义 -from tests.fixtures.helpers import create_mock_llm_response, create_mock_message_component # 将项目根目录添加到 sys.path PROJECT_ROOT = Path(__file__).parent.parent @@ -312,6 +310,7 @@ async def mock_context( platform_manager = MagicMock() conversation_manager = MagicMock() message_history_manager = MagicMock() + group_message_flow_manager = MagicMock() persona_manager = MagicMock() persona_manager.personas_v3 = [] astrbot_config_mgr = MagicMock() @@ -332,6 +331,7 @@ async def mock_context( knowledge_base_manager, cron_manager, subagent_orchestrator, + group_message_flow_manager, ) return context diff --git a/tests/test_group_message_flow.py b/tests/test_group_message_flow.py new file mode 100644 index 0000000000..f9a455b167 --- /dev/null +++ b/tests/test_group_message_flow.py @@ -0,0 +1,271 @@ +from copy import deepcopy +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory +from astrbot.core.config.default import DEFAULT_CONFIG +from astrbot.core.db.po import Conversation +from astrbot.core.group_message_flow_mgr import GroupMessageFlowManager +from astrbot.core.message.components import Plain, Reply +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.provider.entities import LLMResponse, ProviderRequest + + +class ConcreteAstrMessageEvent(AstrMessageEvent): + async def send(self, message): + await super().send(message) + + +class SyncDictComponent: + def to_dict(self): + return {"type": "sync", "data": {"items": [Plain(text="nested sync")]}} + + +def _flow_config(overrides: dict | None = None) -> dict: + config = deepcopy(DEFAULT_CONFIG) + config["provider_ltm_settings"]["group_icl_enable"] = True + config["provider_ltm_settings"]["group_context_mode"] = "flow" + config["provider_ltm_settings"]["group_flow_max_records"] = 0 + config["provider_ltm_settings"]["image_caption"] = False + config["provider_ltm_settings"]["active_reply"]["enable"] = False + if overrides: + config["provider_ltm_settings"].update(overrides) + return config + + +def _event(text: str, sender_id: str = "user-1", sender_name: str = "Alice"): + message = AstrBotMessage() + message.type = MessageType.GROUP_MESSAGE + message.self_id = "bot-1" + message.session_id = "group-1" + message.message_id = f"msg-{text}" + message.sender = MessageMember(user_id=sender_id, nickname=sender_name) + message.group_id = "group-1" + message.message = [Plain(text=text)] + message.message_str = text + message.raw_message = None + return ConcreteAstrMessageEvent( + message_str=text, + message_obj=message, + platform_meta=PlatformMetadata( + name="aiocqhttp", + description="test", + id="default", + ), + session_id="group-1", + ) + + +def _conversation(cid: str = "conv-1") -> Conversation: + return Conversation( + platform_id="default", + user_id="default:GroupMessage:group-1", + cid=cid, + history=[], + ) + + +@pytest.fixture +def flow_context(temp_db): + return _flow_context(temp_db) + + +def _flow_context(temp_db, config: dict | None = None): + manager = GroupMessageFlowManager(temp_db) + config = config or _flow_config() + return SimpleNamespace( + get_config=lambda umo=None: config, + get_using_provider=lambda *args, **kwargs: None, + get_provider_by_id=lambda *args, **kwargs: None, + group_message_flow_manager=manager, + conversation_manager=SimpleNamespace( + get_curr_conversation_id=AsyncMock(return_value="conv-1"), + ), + ) + + +@pytest.mark.asyncio +async def test_group_flow_delta_excludes_current_message(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + previous = _event("previous message", sender_name="Alice") + trigger = _event("trigger message", sender_name="Bob") + + await ltm.handle_message(previous) + await ltm.handle_message(trigger) + + req = ProviderRequest(prompt="trigger message", conversation=_conversation()) + await ltm.on_req_llm(trigger, req) + + assert "" in req.system_prompt + assert "previous message" in req.system_prompt + assert "Alice" in req.system_prompt + assert "trigger message" not in req.system_prompt + assert "message_id" not in req.system_prompt + assert "seq_" not in req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_cursor_advances_between_turns(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + first = _event("first background") + first_trigger = _event("first trigger") + second_background = _event("second background", sender_name="Carol") + second_trigger = _event("second trigger") + + await ltm.handle_message(first) + await ltm.handle_message(first_trigger) + first_req = ProviderRequest(prompt="first trigger", conversation=_conversation()) + await ltm.on_req_llm(first_trigger, first_req) + await ltm.after_req_llm( + first_trigger, + LLMResponse(role="assistant", completion_text="first response"), + ) + + await ltm.handle_message(second_background) + await ltm.handle_message(second_trigger) + second_req = ProviderRequest(prompt="second trigger", conversation=_conversation()) + await ltm.on_req_llm(second_trigger, second_req) + + assert "first background" not in second_req.system_prompt + assert "first trigger" not in second_req.system_prompt + assert "second background" in second_req.system_prompt + assert "second trigger" not in second_req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_cursor_advances_only_after_llm_success(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + background = _event("background before first trigger") + trigger = _event("first trigger") + second_trigger = _event("second trigger") + + await ltm.handle_message(background) + await ltm.handle_message(trigger) + first_req = ProviderRequest(prompt="first trigger", conversation=_conversation()) + await ltm.on_req_llm(trigger, first_req) + + await ltm.handle_message(second_trigger) + retry_req = ProviderRequest(prompt="second trigger", conversation=_conversation()) + await ltm.on_req_llm(second_trigger, retry_req) + + assert "background before first trigger" in retry_req.system_prompt + assert "first trigger" in retry_req.system_prompt + assert "second trigger" not in retry_req.system_prompt + + await ltm.after_req_llm( + trigger, + LLMResponse(role="assistant", completion_text="first response"), + ) + committed_req = ProviderRequest( + prompt="second trigger", conversation=_conversation() + ) + await ltm.on_req_llm(second_trigger, committed_req) + + assert "background before first trigger" not in committed_req.system_prompt + assert "first trigger" not in committed_req.system_prompt + assert "second trigger" not in committed_req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_delta_limit_keeps_tail_messages(temp_db): + context = _flow_context( + temp_db, + _flow_config({"group_flow_max_delta_messages": 2}), + ) + ltm = LongTermMemory(AsyncMock(), context) + first = _event("first background") + second = _event("second background") + third = _event("third background") + trigger = _event("trigger message") + + for event in [first, second, third, trigger]: + await ltm.handle_message(event) + + req = ProviderRequest(prompt="trigger message", conversation=_conversation()) + await ltm.on_req_llm(trigger, req) + + assert "first background" not in req.system_prompt + assert "second background" in req.system_prompt + assert "third background" in req.system_prompt + assert "trigger message" not in req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_delta_truncates_each_message(temp_db): + context = _flow_context( + temp_db, + _flow_config({"group_flow_max_message_chars": 32}), + ) + ltm = LongTermMemory(AsyncMock(), context) + background = _event("hello " + ("x" * 100) + "TAIL_SHOULD_NOT_APPEAR") + trigger = _event("trigger message") + + await ltm.handle_message(background) + await ltm.handle_message(trigger) + + req = ProviderRequest(prompt="trigger message", conversation=_conversation()) + await ltm.on_req_llm(trigger, req) + + assert "hello" in req.system_prompt + assert "TAIL_SHOULD_NOT_APPEAR" not in req.system_prompt + assert "trigger message" not in req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_reset_moves_cursor_to_latest(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + before_reset = _event("before reset") + reset_command = _event("/reset") + after_reset_trigger = _event("after reset trigger") + + await ltm.handle_message(before_reset) + await ltm.handle_message(reset_command) + await ltm.remove_session(reset_command) + await ltm.handle_message(after_reset_trigger) + + req = ProviderRequest(prompt="after reset trigger", conversation=_conversation()) + await ltm.on_req_llm(after_reset_trigger, req) + + assert "before reset" not in req.system_prompt + assert "/reset" not in req.system_prompt + assert "after reset trigger" not in req.system_prompt + + +@pytest.mark.asyncio +async def test_group_flow_serializes_nested_reply_components(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + event = _event("reply wrapper") + event.message_obj.message = [ + Reply(id="quoted-1", chain=[Plain(text="quoted text")]), + Plain(text="reply wrapper"), + ] + + await ltm.handle_message(event) + + rows = await flow_context.group_message_flow_manager.get_records_after( + flow_session_id="default:GroupMessage:group-1", + after_id=0, + ) + + assert len(rows) == 1 + assert rows[0].content[0]["type"] == "reply" + assert rows[0].content[0]["data"]["chain"][0]["data"]["text"] == "quoted text" + + +@pytest.mark.asyncio +async def test_group_flow_serializes_sync_to_dict_components(temp_db, flow_context): + ltm = LongTermMemory(AsyncMock(), flow_context) + + content = await ltm._components_to_dict([SyncDictComponent()]) + + assert content == [ + { + "type": "sync", + "data": {"items": [{"type": "text", "data": {"text": "nested sync"}}]}, + } + ]