diff --git a/astrbot/builtin_stars/astrbot/group_chat_context.py b/astrbot/builtin_stars/astrbot/group_chat_context.py new file mode 100644 index 0000000000..7fee3c0df9 --- /dev/null +++ b/astrbot/builtin_stars/astrbot/group_chat_context.py @@ -0,0 +1,241 @@ +import asyncio +import datetime +import random +import uuid +from collections import defaultdict, deque + +from astrbot import logger +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent +from astrbot.api.message_components import At, Image, Plain +from astrbot.api.platform import MessageType +from astrbot.api.provider import Provider, ProviderRequest +from astrbot.core.agent.message import TextPart +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager + +""" +Group chat context awareness. +""" + +GROUP_HISTORY_HEADER = ( + "" + "You are in a group chat. " + "Belows are group chat context after your last reply:\n" + "--- BEGIN CONTEXT---\n" +) +GROUP_HISTORY_FOOTER = "\n--- END CONTEXT ---\n" +DEFAULT_GROUP_MESSAGE_MAX_CNT = 300 + + +class GroupChatContext: + def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: + self.acm = acm + self.context = context + self._locks: dict[str, asyncio.Lock] = {} + self.raw_records: dict[str, deque[str]] = defaultdict(deque) + self._record_ids: dict[str, deque[str]] = defaultdict(deque) + + def _get_lock(self, umo: str) -> asyncio.Lock: + lock = self._locks.get(umo) + if lock is None: + lock = asyncio.Lock() + self._locks[umo] = lock + return lock + + def cfg(self, event: AstrMessageEvent): + cfg = self.context.get_config(umo=event.unified_msg_origin) + group_context_cfg = cfg["provider_ltm_settings"] + image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] + image_caption_provider_id = group_context_cfg.get("image_caption_provider_id") + image_caption = group_context_cfg["image_caption"] and bool( + image_caption_provider_id + ) + active_reply = group_context_cfg["active_reply"] + enable_active_reply = active_reply.get("enable", False) + ar_method = active_reply["method"] + ar_possibility = active_reply["possibility_reply"] + ar_prompt = active_reply.get("prompt", "") + ar_whitelist = active_reply.get("whitelist", []) + return { + "group_message_max_cnt": _positive_int( + group_context_cfg.get( + "group_message_max_cnt", + DEFAULT_GROUP_MESSAGE_MAX_CNT, + ), + DEFAULT_GROUP_MESSAGE_MAX_CNT, + ), + "image_caption": image_caption, + "image_caption_prompt": image_caption_prompt, + "image_caption_provider_id": image_caption_provider_id, + "enable_active_reply": enable_active_reply, + "ar_method": ar_method, + "ar_possibility": ar_possibility, + "ar_prompt": ar_prompt, + "ar_whitelist": ar_whitelist, + } + + async def get_image_caption( + self, + image_url: str, + image_caption_provider_id: str, + image_caption_prompt: str, + ) -> str: + if not image_caption_provider_id: + provider = self.context.get_using_provider() + else: + provider = self.context.get_provider_by_id(image_caption_provider_id) + if not provider: + raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商") + if not isinstance(provider, Provider): + raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述") + response = await provider.text_chat( + prompt=image_caption_prompt, + session_id=uuid.uuid4().hex, + image_urls=[image_url], + persist=False, + ) + return response.completion_text + + async def need_active_reply(self, event: AstrMessageEvent) -> bool: + cfg = self.cfg(event) + if not cfg["enable_active_reply"]: + return False + if event.get_message_type() != MessageType.GROUP_MESSAGE: + return False + if event.is_at_or_wake_command: + return False + if cfg["ar_whitelist"] and ( + event.unified_msg_origin not in cfg["ar_whitelist"] + and ( + event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"] + ) + ): + return False + match cfg["ar_method"]: + case "possibility_reply": + return random.random() < cfg["ar_possibility"] + return False + + async def remove_session(self, event: AstrMessageEvent) -> int: + umo = event.unified_msg_origin + lock = self._get_lock(umo) + async with lock: + cnt = len(self.raw_records.get(umo, deque())) + self.raw_records.pop(umo, None) + self._record_ids.pop(umo, None) + self._locks.pop(umo, None) + return cnt + + async def handle_message(self, event: AstrMessageEvent) -> None: + if event.get_message_type() != MessageType.GROUP_MESSAGE: + return + + umo = event.unified_msg_origin + cfg = self.cfg(event) + final_message = await self._format_message(event, cfg) + + async with self._get_lock(umo): + records = self.raw_records[umo] + record_ids = self._record_ids[umo] + record_id = uuid.uuid4().hex + records.append(final_message) + record_ids.append(record_id) + _trim_left(records, cfg["group_message_max_cnt"], record_ids) + event.set_extra("_group_context_record_id", record_id) + event.set_extra("_group_context_raw_idx", len(records) - 1) + + logger.debug(f"group_chat_context | {umo} | {final_message}") + + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: + umo = event.unified_msg_origin + record_id = event.get_extra("_group_context_record_id", None) + prompt_idx = event.get_extra("_group_context_raw_idx", -1) + if not isinstance(record_id, str) and ( + not isinstance(prompt_idx, int) or prompt_idx < 0 + ): + return + + async with self._get_lock(umo): + records = self.raw_records.get(umo) + if not records: + return + + raw_list = list(records) + id_list = list(self._record_ids.get(umo, deque())) + if isinstance(record_id, str) and record_id in id_list: + prompt_idx = id_list.index(record_id) + + if prompt_idx >= len(raw_list): + return + + records_to_inject = raw_list[:prompt_idx] + remaining = raw_list[prompt_idx + 1 :] + remaining_ids = id_list[prompt_idx + 1 :] if id_list else [] + records.clear() + records.extend(remaining) + if id_list: + record_ids = self._record_ids[umo] + record_ids.clear() + record_ids.extend(remaining_ids) + + if records_to_inject: + req.extra_user_content_parts.append( + TextPart(text=_format_group_history_block(records_to_inject)) + ) + + async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str: + datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] + + for comp in event.get_messages(): + if isinstance(comp, Plain): + parts.append(f" {comp.text}") + elif isinstance(comp, Image): + if cfg["image_caption"]: + try: + url = comp.url if comp.url else comp.file + if not url: + raise Exception("图片 URL 为空") + caption = await self.get_image_caption( + url, + cfg["image_caption_provider_id"], + cfg["image_caption_prompt"], + ) + parts.append(f" [Image: {caption}]") + except Exception as e: + logger.error(f"获取图片描述失败: {e}") + else: + parts.append(" [Image]") + elif isinstance(comp, At): + is_at_self = str(comp.qq) in ( + event.get_self_id(), + "all", + ) + if is_at_self: + parts.insert(1, "⚠️[DIRECTED AT YOU] ") + parts.append(f" [At: {comp.name}]") + + return "".join(parts) + + +def _positive_int(value, fallback: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + return fallback + return parsed if parsed > 0 else fallback + + +def _trim_left( + records: deque[str], + max_records: int, + record_ids: deque[str] | None = None, +) -> None: + while len(records) > max_records: + records.popleft() + if record_ids: + record_ids.popleft() + + +def _format_group_history_block(records: list[str]) -> str: + return GROUP_HISTORY_HEADER + "\n".join(records) + GROUP_HISTORY_FOOTER diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py deleted file mode 100644 index e08cdc5157..0000000000 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ /dev/null @@ -1,188 +0,0 @@ -import datetime -import random -import uuid -from collections import defaultdict - -from astrbot import logger -from astrbot.api import star -from astrbot.api.event import AstrMessageEvent -from astrbot.api.message_components import At, Image, Plain -from astrbot.api.platform import MessageType -from astrbot.api.provider import LLMResponse, Provider, ProviderRequest -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager - -""" -聊天记忆增强 -""" - - -class LongTermMemory: - def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: - self.acm = acm - self.context = context - self.session_chats = defaultdict(list) - """记录群成员的群聊记录""" - - def cfg(self, event: AstrMessageEvent): - cfg = self.context.get_config(umo=event.unified_msg_origin) - try: - max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"]) - except BaseException as e: - logger.error(e) - max_cnt = 300 - image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] - image_caption_provider_id = cfg["provider_ltm_settings"].get( - "image_caption_provider_id" - ) - image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool( - image_caption_provider_id - ) - active_reply = cfg["provider_ltm_settings"]["active_reply"] - enable_active_reply = active_reply.get("enable", False) - ar_method = active_reply["method"] - ar_possibility = active_reply["possibility_reply"] - ar_prompt = active_reply.get("prompt", "") - ar_whitelist = active_reply.get("whitelist", []) - ret = { - "max_cnt": max_cnt, - "image_caption": image_caption, - "image_caption_prompt": image_caption_prompt, - "image_caption_provider_id": image_caption_provider_id, - "enable_active_reply": enable_active_reply, - "ar_method": ar_method, - "ar_possibility": ar_possibility, - "ar_prompt": ar_prompt, - "ar_whitelist": ar_whitelist, - } - return ret - - async def remove_session(self, event: AstrMessageEvent) -> int: - cnt = 0 - if event.unified_msg_origin in self.session_chats: - cnt = len(self.session_chats[event.unified_msg_origin]) - del self.session_chats[event.unified_msg_origin] - return cnt - - async def get_image_caption( - self, - image_url: str, - image_caption_provider_id: str, - image_caption_prompt: str, - ) -> str: - if not image_caption_provider_id: - provider = self.context.get_using_provider() - else: - provider = self.context.get_provider_by_id(image_caption_provider_id) - if not provider: - raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商") - if not isinstance(provider, Provider): - raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述") - response = await provider.text_chat( - prompt=image_caption_prompt, - session_id=uuid.uuid4().hex, - image_urls=[image_url], - persist=False, - ) - return response.completion_text - - async def need_active_reply(self, event: AstrMessageEvent) -> bool: - cfg = self.cfg(event) - if not cfg["enable_active_reply"]: - return False - if event.get_message_type() != MessageType.GROUP_MESSAGE: - return False - - if event.is_at_or_wake_command: - # if the message is a command, let it pass - return False - - if cfg["ar_whitelist"] and ( - event.unified_msg_origin not in cfg["ar_whitelist"] - and ( - event.get_group_id() and event.get_group_id() not in cfg["ar_whitelist"] - ) - ): - return False - - match cfg["ar_method"]: - case "possibility_reply": - trig = random.random() < cfg["ar_possibility"] - return trig - - return False - - async def handle_message(self, event: AstrMessageEvent) -> None: - """仅支持群聊""" - if event.get_message_type() == MessageType.GROUP_MESSAGE: - datetime_str = datetime.datetime.now().strftime("%H:%M:%S") - - parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] - - cfg = self.cfg(event) - - for comp in event.get_messages(): - if isinstance(comp, Plain): - parts.append(f" {comp.text}") - elif isinstance(comp, Image): - if cfg["image_caption"]: - try: - url = comp.url if comp.url else comp.file - if not url: - raise Exception("图片 URL 为空") - caption = await self.get_image_caption( - url, - cfg["image_caption_provider_id"], - cfg["image_caption_prompt"], - ) - parts.append(f" [Image: {caption}]") - except Exception as e: - logger.error(f"获取图片描述失败: {e}") - else: - parts.append(" [Image]") - elif isinstance(comp, At): - parts.append(f" [At: {comp.name}]") - - final_message = "".join(parts) - logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") - self.session_chats[event.unified_msg_origin].append(final_message) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) - - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: - """当触发 LLM 请求前,调用此方法修改 req""" - if event.unified_msg_origin not in self.session_chats: - return - - chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) - - cfg = self.cfg(event) - if cfg["enable_active_reply"]: - prompt = req.prompt - req.prompt = ( - f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" - f"\nNow, a new message is coming: `{prompt}`. " - "Please react to it. Only output your response and do not output any other information. " - "You MUST use the SAME language as the chatroom is using." - ) - req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 - else: - req.system_prompt += ( - "You are now in a chatroom. The chat history is as follows: \n" - ) - req.system_prompt += chats_str - - async def after_req_llm( - self, event: AstrMessageEvent, llm_resp: LLMResponse - ) -> None: - if event.unified_msg_origin not in self.session_chats: - return - - if llm_resp.completion_text: - final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}" - logger.debug( - f"Recorded AI response: {event.unified_msg_origin} | {final_message}" - ) - self.session_chats[event.unified_msg_origin].append(final_message) - cfg = self.cfg(event) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index 3d800edd26..874a03a401 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -7,7 +7,7 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.message_components import Image, Plain -from astrbot.api.provider import LLMResponse, ProviderRequest +from astrbot.api.provider import ProviderRequest from astrbot.core import logger from astrbot.core.utils.session_waiter import ( FILTERS, @@ -17,7 +17,7 @@ session_waiter, ) -from .long_term_memory import LongTermMemory +from .group_chat_context import GroupChatContext def _iter_message_components(event: AstrMessageEvent): @@ -30,11 +30,14 @@ def _iter_message_components(event: AstrMessageEvent): class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context - self.ltm = None + self.group_chat_context = None try: - self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context) + self.group_chat_context = GroupChatContext( + self.context.astrbot_config_mgr, + self.context, + ) except BaseException as e: - logger.error(f"聊天增强 err: {e}") + logger.error(f"group chat context init failed: {e}") @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) async def handle_session_control_agent(self, event: AstrMessageEvent) -> None: @@ -133,15 +136,18 @@ async def empty_mention_waiter( except Exception as e: logger.error("handle_empty_mention error: " + str(e)) - def ltm_enabled(self, event: AstrMessageEvent): - ltmse = self.context.get_config(umo=event.unified_msg_origin)[ + def group_context_enabled(self, event: AstrMessageEvent): + group_context_settings = self.context.get_config(umo=event.unified_msg_origin)[ "provider_ltm_settings" ] - return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"] + return ( + group_context_settings["group_icl_enable"] + or group_context_settings["active_reply"]["enable"] + ) @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) async def on_message(self, event: AstrMessageEvent): - """群聊记忆增强""" + """群聊上下文感知""" message_components = _iter_message_components(event) has_image_or_plain = False for comp in message_components: @@ -149,27 +155,31 @@ async def on_message(self, event: AstrMessageEvent): has_image_or_plain = True break - if self.ltm_enabled(event) and self.ltm and has_image_or_plain: - need_active = await self.ltm.need_active_reply(event) + group_context_enabled = False + if self.group_chat_context: + try: + group_context_enabled = self.group_context_enabled(event) + except BaseException as e: + logger.error(f"group chat context: {e}") + + if group_context_enabled and self.group_chat_context and has_image_or_plain: + need_active = await self.group_chat_context.need_active_reply(event) group_icl_enable = self.context.get_config(umo=event.unified_msg_origin)[ "provider_ltm_settings" ]["group_icl_enable"] if group_icl_enable: - """记录对话""" try: - await self.ltm.handle_message(event) + await self.group_chat_context.handle_message(event) except BaseException as e: logger.error(e) if need_active: - """主动回复""" provider = self.context.get_using_provider(event.unified_msg_origin) if not provider: logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复") return try: - conv = None session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( event.unified_msg_origin, ) @@ -185,6 +195,10 @@ async def on_message(self, event: AstrMessageEvent): session_curr_cid, ) + if not conv: + logger.error("未找到对话,无法主动回复") + return + prompt = event.message_str image_urls = [] for comp in message_components: @@ -194,10 +208,6 @@ async def on_message(self, event: AstrMessageEvent): except Exception: logger.exception("主动回复处理图片失败") - if not conv: - logger.error("未找到对话,无法主动回复") - return - yield event.request_llm( prompt=prompt, session_id=event.session_id, @@ -213,30 +223,19 @@ async def decorate_llm_req( self, event: AstrMessageEvent, req: ProviderRequest ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" - if self.ltm and self.ltm_enabled(event): + if self.group_chat_context and self.group_context_enabled(event): try: - await self.ltm.on_req_llm(event, req) + await self.group_chat_context.on_req_llm(event, req) except BaseException as e: - logger.error(f"ltm: {e}") - - @filter.on_llm_response() - async def record_llm_resp_to_ltm( - self, event: AstrMessageEvent, resp: LLMResponse - ) -> None: - """在 LLM 响应后记录对话""" - if self.ltm and self.ltm_enabled(event): - try: - await self.ltm.after_req_llm(event, resp) - except Exception as e: - logger.error(f"ltm: {e}") + logger.error(f"group chat context: {e}") @filter.after_message_sent() async def after_message_sent(self, event: AstrMessageEvent) -> None: """消息发送后处理""" - if self.ltm and self.ltm_enabled(event): + if self.group_chat_context and self.group_context_enabled(event): try: - clean_session = event.get_extra("_clean_ltm_session", False) + clean_session = event.get_extra("_clean_group_context_session", False) if clean_session: - await self.ltm.remove_session(event) + await self.group_chat_context.remove_session(event) except Exception as e: - logger.error(f"ltm: {e}") + logger.error(f"group chat context: {e}") diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index 9dcf369096..42282dc500 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -189,7 +189,7 @@ async def reset(self, message: AstrMessageEvent) -> None: ret = "✅ Conversation reset successfully." - message.set_extra("_clean_ltm_session", True) + message.set_extra("_clean_group_context_session", True) message.set_result(MessageEventResult().message(ret)) @@ -243,7 +243,7 @@ async def new_conv(self, message: AstrMessageEvent) -> None: persona_id=cpersona, ) - message.set_extra("_clean_ltm_session", True) + message.set_extra("_clean_group_context_session", True) message.set_result( MessageEventResult().message( diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index d4642bc506..0e018f8543 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -96,51 +96,32 @@ async def __call__(self, messages: list[Message]) -> list[Message]: return truncated_messages -def split_history( - messages: list[Message], keep_recent: int -) -> tuple[list[Message], list[Message], list[Message]]: - """Split the message list into system messages, messages to summarize, and recent messages. - - Ensures that the split point is between complete user-assistant pairs to maintain conversation flow. - - Args: - messages: The original message list. - keep_recent: The number of latest messages to keep. - - Returns: - tuple: (system_messages, messages_to_summarize, recent_messages) - """ - # keep the system messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i +def _message_to_dict(msg: Message) -> dict: + """Convert a Message to a plain dict suitable for round splitting.""" + d = {"role": msg.role} + if msg.content is not None: + d["content"] = msg.content + if getattr(msg, "tool_calls", None): + d["tool_calls"] = msg.tool_calls + if getattr(msg, "tool_call_id", None): + d["tool_call_id"] = msg.tool_call_id + return d + + +def _dict_to_message(d: dict) -> Message: + """Convert a plain dict back to a Message.""" + return Message(**d) + + +def _extract_system_messages(messages: list[Message]) -> list[Message]: + """Return the leading system messages from a message list.""" + result = [] + for msg in messages: + if msg.role == "system": + result.append(msg) + else: break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] - - if len(non_system_messages) <= keep_recent: - return system_messages, [], non_system_messages - - # Find the split point, ensuring recent_messages starts with a user message - # This maintains complete conversation turns - split_index = len(non_system_messages) - keep_recent - - # Search backward from split_index to find the first user message - # This ensures recent_messages starts with a user message (complete turn) - while split_index > 0 and non_system_messages[split_index].role != "user": - # TODO: +=1 or -=1 ? calculate by tokens - split_index -= 1 - - # If we couldn't find a user message, keep all messages as recent - if split_index == 0: - return system_messages, [], non_system_messages - - messages_to_summarize = non_system_messages[:split_index] - recent_messages = non_system_messages[split_index:] - - return system_messages, messages_to_summarize, recent_messages + return result class LLMSummaryCompressor: @@ -166,6 +147,7 @@ def __init__( self.provider = provider self.keep_recent = keep_recent self.compression_threshold = compression_threshold + self.existing_summary: str = "" self.instruction_text = instruction_text or ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" @@ -196,28 +178,44 @@ def should_compress( async def __call__(self, messages: list[Message]) -> list[Message]: """Use LLM to generate a summary of the conversation history. - Process: - 1. Divide messages: keep the system message and the latest N messages. - 2. Send the old messages + the instruction message to the LLM. - 3. Reconstruct the message list: [system message, summary message, latest messages]. + Uses round-based splitting to preserve user-assistant turn boundaries. + On LLM failure, returns the original messages unchanged (caller should + fall back to truncation). """ - if len(messages) <= self.keep_recent + 1: + from .round_utils import rounds_to_text, split_into_rounds + + # Convert messages to dict list for round splitting + msg_dicts = [_message_to_dict(m) for m in messages] + rounds = split_into_rounds(msg_dicts) + + if len(rounds) <= self.keep_recent: return messages - system_messages, messages_to_summarize, recent_messages = split_history( - messages, self.keep_recent - ) + old_rounds = rounds[: -self.keep_recent] + recent_rounds = rounds[-self.keep_recent :] - if not messages_to_summarize: + if not old_rounds: return messages - # build payload - instruction_message = Message(role="user", content=self.instruction_text) - llm_payload = messages_to_summarize + [instruction_message] + # Build LLM payload + old_text = rounds_to_text(old_rounds) + existing_note = "" + if self.existing_summary: + existing_note = ( + "\nExisting memory summary (merge with old rounds above):\n" + f"{self.existing_summary}\n" + ) + prompt = ( + f"{self.instruction_text}\n\n" + "--- BEGIN CONVERSATION ROUNDS TO SUMMARIZE ---\n" + f"{old_text}\n" + "--- END CONVERSATION ROUNDS ---" + f"{existing_note}" + ) - # generate summary + # Generate summary try: - response = await self.provider.text_chat(contexts=llm_payload) + response = await self.provider.text_chat(prompt=prompt) summary_content = (response.completion_text or "").strip() except Exception as e: logger.error(f"Failed to generate summary: {e}") @@ -227,9 +225,8 @@ async def __call__(self, messages: list[Message]) -> list[Message]: logger.warning("LLM context compression returned an empty summary.") return messages - # build result - result = [] - result.extend(system_messages) + # Build result: system messages + summary pair + recent rounds + result = _extract_system_messages(messages) result.append( Message( @@ -244,6 +241,9 @@ async def __call__(self, messages: list[Message]) -> list[Message]: ) ) - result.extend(recent_messages) + # Flatten recent rounds back to message list + for rnd in recent_rounds: + for seg in rnd: + result.append(_dict_to_message(seg)) return result diff --git a/astrbot/core/agent/context/round_utils.py b/astrbot/core/agent/context/round_utils.py new file mode 100644 index 0000000000..20c2f5711f --- /dev/null +++ b/astrbot/core/agent/context/round_utils.py @@ -0,0 +1,38 @@ +"""Round-based utilities shared by LTM compaction and LLMSummaryCompressor.""" + +import json +from typing import Any + + +def split_into_rounds( + contexts: list[dict[str, Any]], +) -> list[list[dict[str, Any]]]: + """Split a flat contexts list into logical rounds. + + A round begins at a ``user`` segment and includes all subsequent + ``assistant`` / ``tool`` segments until the next ``user`` segment. + """ + rounds: list[list[dict[str, Any]]] = [] + current: list[dict[str, Any]] = [] + for seg in contexts: + if seg.get("role") == "user" and current: + rounds.append(current) + current = [] + current.append(seg) + if current: + rounds.append(current) + return rounds + + +def rounds_to_text(rounds: list[list[dict[str, Any]]]) -> str: + """Render rounds into a plain-text string for LLM summarisation.""" + lines: list[str] = [] + for i, rnd in enumerate(rounds, 1): + lines.append(f"--- Round {i} ---") + for seg in rnd: + role = seg.get("role", "?") + content = seg.get("content") or seg.get("tool_calls") or "" + if isinstance(content, list): + content = json.dumps(content, ensure_ascii=False) + lines.append(f"[{role}] {content}") + return "\n".join(lines) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 2da36fda2b..417f090ea6 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -241,13 +241,10 @@ async def reset( self.tool_result_overflow_dir = tool_result_overflow_dir self.read_tool = read_tool self._tool_result_token_counter = EstimateTokenCounter() - # we will do compress when: - # 1. before requesting LLM - # TODO: 2. after LLM output a tool call - self.context_config = ContextConfig( - # <=0 will never do compress + self.request_context_manager_config = ContextConfig( + # <=0 disables token-based guarding. max_context_tokens=provider.provider_config.get("max_context_tokens", 0), - # enforce max turns before compression + # Enforce max turns before token-based guarding. enforce_max_turns=self.enforce_max_turns, truncate_turns=self.truncate_turns, llm_compress_instruction=self.llm_compress_instruction, @@ -256,7 +253,9 @@ async def reset( custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) - self.context_manager = ContextManager(self.context_config) + self.request_context_manager = ContextManager( + self.request_context_manager_config + ) self.provider = provider self.fallback_providers: list[Provider] = [] @@ -459,8 +458,11 @@ async def _iter_llm_responses( self, *, include_model: bool = True ) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" + messages_for_provider = getattr( + self, "_provider_messages", self.run_context.messages + ) payload = { - "contexts": self._sanitize_contexts_for_provider(self.run_context.messages), + "contexts": self._sanitize_contexts_for_provider(messages_for_provider), "func_tool": self._func_tool_for_provider(), "session_id": self.req.session_id, "extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart] @@ -704,10 +706,13 @@ async def step(self): self._transition_state(AgentState.RUNNING) llm_resp_result = None - # do truncate and compress + # Process request-time context on a copy so the runner's canonical + # messages are never mutated. The processed result is only used for this + # provider call. Persistent compaction is owned by the conversation / + # memory layer. token_usage = self.req.conversation.token_usage if self.req.conversation else 0 self._simple_print_message_role("[BefCompact]") - self.run_context.messages = await self.context_manager.process( + self._provider_messages = await self.request_context_manager.process( self.run_context.messages, trusted_token_usage=token_usage ) self._simple_print_message_role("[AftCompact]") @@ -1403,6 +1408,9 @@ async def _iter_tool_executor_results( self, executor: AsyncIterator[ToolExecutorResultT], ) -> T.AsyncGenerator[ToolExecutorResultT, None]: + async def _next_executor_result() -> ToolExecutorResultT: + return await anext(executor) + while True: if self._is_stop_requested(): await self._close_executor(executor) @@ -1410,7 +1418,7 @@ async def _iter_tool_executor_results( "Tool execution interrupted before reading the next tool result." ) - next_result_task = asyncio.create_task(anext(executor)) + next_result_task = asyncio.create_task(_next_executor_result()) abort_task = asyncio.create_task(self._abort_signal.wait()) try: done, _ = await asyncio.wait( diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index e522ce5453..8fc6ade63f 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -150,14 +150,14 @@ class MainAgentBuildConfig: """The strategy to handle context length limit reached.""" llm_compress_instruction: str = "" """The instruction for compression in llm_compress strategy.""" - llm_compress_keep_recent: int = 6 + llm_compress_keep_recent: int = 10 """The number of most recent turns to keep during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context compression.""" - max_context_length: int = -1 + max_context_length: int = 50 """The maximum number of turns to keep in context. -1 means no limit. This enforce max turns before compression""" - dequeue_context_length: int = 1 + dequeue_context_length: int = 10 """The number of oldest turns to remove when context length limit is reached.""" fallback_max_context_tokens: int = 128000 """Fallback max context tokens. When max_context_tokens is 0 and the model is not in LLM_METADATAS, use this value.""" @@ -1135,26 +1135,27 @@ async def _apply_web_search_tools( def _get_compress_provider( - config: MainAgentBuildConfig, plugin_context: Context + config: MainAgentBuildConfig, + plugin_context: Context, + event: AstrMessageEvent | None = None, ) -> Provider | None: - if not config.llm_compress_provider_id: - return None if config.context_limit_reached_strategy != "llm_compress": return None - provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) - if provider is None: - logger.warning( - "未找到指定的上下文压缩模型 %s,将跳过压缩。", - config.llm_compress_provider_id, - ) - return None - if not isinstance(provider, Provider): + if config.llm_compress_provider_id: + provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) + if provider and isinstance(provider, Provider): + return provider logger.warning( - "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", + "指定的上下文压缩模型 %s 不可用", config.llm_compress_provider_id, ) - return None - return provider + # fallback: use current chat provider for this session + if event: + try: + return plugin_context.get_using_provider(umo=event.unified_msg_origin) + except ValueError: + pass + return None def _get_fallback_chat_providers( @@ -1470,9 +1471,8 @@ async def build_main_agent( streaming=config.streaming_response, llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, - llm_compress_provider=_get_compress_provider(config, plugin_context), + llm_compress_provider=_get_compress_provider(config, plugin_context, event), truncate_turns=config.dequeue_context_length, - enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, fallback_providers=_get_fallback_chat_providers( provider, plugin_context, config.provider_settings diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index dec98692bc..35594b8708 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -120,7 +120,7 @@ "default_personality": "default", "persona_pool": ["*"], "prompt_prefix": "{{prompt}}", - "context_limit_reached_strategy": "truncate_by_turns", # or llm_compress + "context_limit_reached_strategy": "llm_compress", # or truncate_by_turns "llm_compress_instruction": ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n" @@ -128,10 +128,10 @@ "3. If there was an initial user goal, state it first and describe the current progress/status.\n" "4. Write the summary in the user's language.\n" ), - "llm_compress_keep_recent": 6, + "llm_compress_keep_recent": 10, "llm_compress_provider_id": "", - "max_context_length": -1, - "dequeue_context_length": 1, + "max_context_length": 50, + "dequeue_context_length": 10, "streaming_response": False, "show_tool_use_status": False, "show_tool_call_result": False, @@ -3505,30 +3505,30 @@ "type": "object", "items": { "provider_settings.max_context_length": { - "description": "最多携带对话轮数", + "description": "压缩前最多保留对话轮数", "type": "int", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", + "hint": "普通会话历史超过该轮数后,才会按下方策略进行持久化截断或 LLM 压缩;请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。", "condition": { "provider_settings.agent_runner_type": "local", }, }, "provider_settings.dequeue_context_length": { - "description": "丢弃对话轮数", + "description": "轮次超限时一次丢弃轮数", "type": "int", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", + "hint": "当超过“压缩前最多保留对话轮数”且无法使用 LLM 压缩时,一次丢弃多少轮旧对话;请求期截断也会复用该值。", "condition": { "provider_settings.agent_runner_type": "local", }, }, "provider_settings.context_limit_reached_strategy": { - "description": "超出模型上下文窗口时的处理方式", + "description": "历史超限或上下文接近上限时的处理方式", "type": "string", "options": ["truncate_by_turns", "llm_compress"], "labels": ["按对话轮数截断", "由 LLM 压缩上下文"], "condition": { "provider_settings.agent_runner_type": "local", }, - "hint": "", + "hint": "普通会话历史仅在超过“压缩前最多保留对话轮数”后执行该策略;请求发送前也会在上下文 token 接近模型窗口时使用同一策略保护本次请求。", }, "provider_settings.llm_compress_instruction": { "description": "上下文压缩提示词", @@ -3552,7 +3552,7 @@ "description": "用于上下文压缩的模型提供商 ID", "type": "string", "_special": "select_provider", - "hint": "留空时将降级为“按对话轮数截断”的策略。", + "hint": "留空时使用当前聊天模型进行压缩;如果模型不可用或压缩失败,将回退为“按对话轮数截断”的策略。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 2c200ec262..e8b0459fb6 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -97,7 +97,9 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_instruction: str = settings.get( "llm_compress_instruction", "" ) - self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4) + self.llm_compress_keep_recent: int = settings.get( + "llm_compress_keep_recent", 10 + ) self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" ) @@ -476,6 +478,18 @@ async def _save_to_history( continue if message.role in ["assistant", "user"] and message._no_save: continue + # Truncate long tool results before persisting (8192 chars) + if ( + message.role == "tool" + and isinstance(message.content, str) + and len(message.content) > 8192 + ): + message = Message( + role="tool", + tool_call_id=message.tool_call_id, + content=message.content[:8192] + + f"\n...[truncated {len(message.content) - 8192} chars]", + ) messages_to_save.append(message) checkpoint_id = event.get_extra("llm_checkpoint_id") diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 6363b71e31..284da40097 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -247,19 +247,19 @@ "provider_settings": { "max_context_length": { "description": "Max Turns Before Compression", - "hint": "Limits history turns before any compression strategy is applied; -1 means no turn-based limit" + "hint": "Persistent conversation history is truncated or LLM-compressed by the strategy below only after it exceeds this many turns. Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit." }, "dequeue_context_length": { "description": "Turns to Discard When Limit Exceeded", - "hint": "Number of old conversation turns to discard at once when the turn limit is exceeded; also used as fallback when compression is unavailable" + "hint": "When history exceeds 'Max Turns Before Compression' and LLM compression is unavailable, discard this many oldest turns at once. Request-time truncation also reuses this value." }, "context_limit_reached_strategy": { - "description": "Handling When Context Approaches Model Limit", + "description": "Handling for History Limits or Context Window Pressure", "labels": [ "Truncate by Turns", "Compress by LLM" ], - "hint": "This strategy only triggers after turn-based limiting, when context tokens approach the model's window limit. When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Turns to Discard When Limit Exceeded' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression." + "hint": "Persistent conversation history uses this strategy only after exceeding 'Max Turns Before Compression'. Before each request, the same strategy may also protect the in-flight context when tokens approach the model window." }, "llm_compress_instruction": { "description": "Context Compression Instruction", @@ -271,7 +271,7 @@ }, "llm_compress_provider_id": { "description": "Model Provider ID for Context Compression", - "hint": "When left empty, will fall back to the 'Truncate by Turns' strategy." + "hint": "When left empty, the current chat model will be used for compression. If the model is unavailable or compression fails, AstrBot falls back to the 'Truncate by Turns' strategy." }, "fallback_max_context_tokens": { "description": "Fallback context window size", diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 028bff8675..207e2eeda7 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -246,20 +246,20 @@ "description": "Стратегия управления контекстом", "provider_settings": { "max_context_length": { - "description": "Макс. количество раундов диалога", - "hint": "При превышении удаляются старые сообщения. 1 раунд = 1 пара запрос-ответ. -1 означает без ограничений." + "description": "Макс. раундов перед сжатием", + "hint": "Постоянная история диалога обрезается или сжимается LLM по стратегии ниже только после превышения этого числа раундов. Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам." }, "dequeue_context_length": { - "description": "Кол-во удаляемых раундов", - "hint": "Сколько раундов удалять за один раз при достижении лимита." + "description": "Раундов для удаления при превышении лимита", + "hint": "Когда история превышает лимит раундов и LLM-сжатие недоступно, за один раз удаляется это число самых старых раундов. Обрезка перед запросом также использует это значение." }, "context_limit_reached_strategy": { - "description": "Действие при переполнении окна контекста", + "description": "Действие при лимите истории или давлении окна контекста", "labels": [ "Обрезать по раундам", "Сжать с помощью LLM" ], - "hint": "При выборе 'Обрезать' удаляются старые сообщения. При выборе 'Сжать' используется модель для суммаризации контекста." + "hint": "Постоянная история диалога использует эту стратегию только после превышения лимита раундов. Перед каждым запросом та же стратегия может защищать текущий контекст, когда токены приближаются к окну модели." }, "llm_compress_instruction": { "description": "Инструкция для сжатия контекста", @@ -271,7 +271,7 @@ }, "llm_compress_provider_id": { "description": "Модель для сжатия контекста", - "hint": "Если не выбрано, произойдет откат к стратегии удаления сообщений." + "hint": "Если не выбрано, для сжатия используется текущая модель чата. Если модель недоступна или сжатие завершается ошибкой, AstrBot откатывается к обрезке по раундам." }, "fallback_max_context_tokens": { "description": "Запасной размер окна контекста", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 70f4fa5c79..96fdf07d73 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -249,19 +249,19 @@ "provider_settings": { "max_context_length": { "description": "压缩前最多保留对话轮数", - "hint": "无论选择截断还是 LLM 压缩,都会先按该值限制历史轮数;-1 表示不按轮数限制" + "hint": "普通会话历史超过该轮数后,才会按下方策略进行持久化截断或 LLM 压缩;请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。" }, "dequeue_context_length": { "description": "轮次超限时一次丢弃轮数", - "hint": "当超过\"压缩前最多保留对话轮数\"时,一次丢弃多少轮旧对话;同时也可能作为压缩不可用时的回退截断参数" + "hint": "当超过\"压缩前最多保留对话轮数\"且无法使用 LLM 压缩时,一次丢弃多少轮旧对话;请求期截断也会复用该值。" }, "context_limit_reached_strategy": { - "description": "模型上下文接近上限后的处理方式", + "description": "历史超限或上下文接近上限时的处理方式", "labels": [ "按对话轮数截断", "由 LLM 压缩上下文" ], - "hint": "该策略只会在完成轮次限制后,且上下文 token 接近模型窗口上限时触发。当按对话轮数截断时,会根据上面\"轮次超限时一次丢弃轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。" + "hint": "普通会话历史仅在超过\"压缩前最多保留对话轮数\"后执行该策略;请求发送前也会在上下文 token 接近模型窗口时使用同一策略保护本次请求。" }, "llm_compress_instruction": { "description": "上下文压缩提示词", @@ -273,7 +273,7 @@ }, "llm_compress_provider_id": { "description": "用于上下文压缩的模型提供商 ID", - "hint": "留空时将降级为\"按对话轮数截断\"的策略。" + "hint": "留空时使用当前聊天模型进行压缩;如果模型不可用或压缩失败,将回退为\"按对话轮数截断\"的策略。" }, "fallback_max_context_tokens": { "description": "上下文窗口兜底值", diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 1685947b0a..a14b49691a 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -703,91 +703,61 @@ async def test_llm_compression_with_mock_provider(self): # Should have been compressed assert len(result) <= len(messages) - # ==================== split_history Tests ==================== + # ==================== split_into_rounds Tests ==================== - def test_split_history_ensures_user_start(self): - """Test split_history ensures recent_messages starts with user message.""" - from astrbot.core.agent.context.compressor import split_history + def test_split_rounds_ensures_user_start(self): + """Test split_into_rounds preserves user-assistant round boundaries.""" + from astrbot.core.agent.context.round_utils import split_into_rounds - # Create alternating messages: user, assistant, user, assistant, user, assistant + # First round may begin with system messages; subsequent rounds must start with user messages = [ - self.create_message("system", "System prompt"), - self.create_message("user", "msg1"), - self.create_message("assistant", "msg2"), - self.create_message("user", "msg3"), - self.create_message("assistant", "msg4"), - self.create_message("user", "msg5"), - self.create_message("assistant", "msg6"), + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "msg2"}, + {"role": "user", "content": "msg3"}, + {"role": "assistant", "content": "msg4"}, + {"role": "tool", "content": "tool result"}, + {"role": "user", "content": "msg5"}, ] - # Keep recent 3 messages - should adjust to start with user - system, to_summarize, recent = split_history(messages, keep_recent=3) + rounds = split_into_rounds(messages) - # recent_messages should start with user message - assert len(recent) > 0 - assert recent[0].role == "user" + # Subsequent rounds (after the first) must start with user + for rnd in rounds[1:]: + assert rnd[0]["role"] == "user" - # messages_to_summarize should end with assistant (complete turn) - if len(to_summarize) > 0: - assert to_summarize[-1].role == "assistant" + assert len(rounds) >= 2 - def test_split_history_handles_assistant_at_split_point(self): - """Test split_history when assistant message is at the intended split point.""" - from astrbot.core.agent.context.compressor import split_history + def test_split_rounds_single_round(self): + """A single user-assistant pair is one round.""" + from astrbot.core.agent.context.round_utils import split_into_rounds messages = [ - self.create_message("user", "msg1"), - self.create_message("assistant", "msg2"), - self.create_message("user", "msg3"), - self.create_message("assistant", "msg4"), # <- intended split here - self.create_message("user", "msg5"), - self.create_message("assistant", "msg6"), + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, ] + rounds = split_into_rounds(messages) + assert len(rounds) == 1 + assert rounds[0][0]["role"] == "user" - # keep_recent=2 would normally split at index 4 (assistant msg4) - # Should move back to include from msg5 (user) - system, to_summarize, recent = split_history(messages, keep_recent=2) - - # recent should start with user message - assert recent[0].role == "user" - assert recent[0].content == "msg5" - - def test_split_history_all_assistant_messages(self): - """Test split_history when there are consecutive assistant messages.""" - from astrbot.core.agent.context.compressor import split_history + def test_split_rounds_multi_tool(self): + """Tool calls/results within a round are kept together.""" + from astrbot.core.agent.context.round_utils import split_into_rounds messages = [ - self.create_message("user", "msg1"), - self.create_message("assistant", "msg2"), - self.create_message("assistant", "msg3"), - self.create_message("assistant", "msg4"), + {"role": "user", "content": "search"}, + {"role": "assistant", "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": "result1"}, + {"role": "tool", "tool_call_id": "c1", "content": "result2"}, + {"role": "assistant", "content": "done"}, ] - - system, to_summarize, recent = split_history(messages, keep_recent=2) - - # Should find the user message and keep from there - if len(recent) > 0: - # Find first user message backwards - assert any(m.role == "user" for m in messages) - - def test_split_history_with_system_messages(self): - """Test split_history preserves system messages separately.""" - from astrbot.core.agent.context.compressor import split_history - - messages = [ - self.create_message("system", "System 1"), - self.create_message("system", "System 2"), - self.create_message("user", "msg1"), - self.create_message("assistant", "msg2"), - self.create_message("user", "msg3"), - ] - - system, to_summarize, recent = split_history(messages, keep_recent=2) - - # System messages should be separate - assert len(system) == 2 - assert all(m.role == "system" for m in system) - - # Recent should start with user - if len(recent) > 0: - assert recent[0].role == "user" + rounds = split_into_rounds(messages) + # One round with 5 segments + assert len(rounds) == 1 + assert len(rounds[0]) == 5 + + def test_split_rounds_empty(self): + """Empty list returns no rounds.""" + from astrbot.core.agent.context.round_utils import split_into_rounds + rounds = split_into_rounds([]) + assert len(rounds) == 0 diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 74d0691085..7404d56680 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -260,6 +260,33 @@ async def text_chat(self, **kwargs) -> LLMResponse: ) +class CapturingToolLoopProvider(MockProvider): + def __init__(self, tool_name: str): + super().__init__() + self.tool_name = tool_name + self.received_contexts = [] + + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + self.received_contexts.append(list(kwargs.get("contexts") or [])) + func_tool = kwargs.get("func_tool") + if func_tool is None or self.call_count > 1: + return LLMResponse( + role="assistant", + completion_text="最终回复", + usage=TokenUsage(input_other=10, output=5), + ) + + return LLMResponse( + role="assistant", + completion_text="", + tools_call_name=[self.tool_name], + tools_call_args=[{"query": "test"}], + tools_call_ids=["call_context_refresh"], + usage=TokenUsage(input_other=10, output=5), + ) + + class SequentialToolProvider(MockProvider): def __init__(self, tool_sequence: list[str]): super().__init__() @@ -450,6 +477,68 @@ async def test_max_step_limit_functionality( assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答" +@pytest.mark.asyncio +async def test_max_step_final_request_includes_limit_prompt( + runner, provider_request, mock_tool_executor, mock_hooks +): + """The forced final step must use contexts recomputed after max-step prompt.""" + provider = CapturingToolLoopProvider("test_tool") + + await runner.reset( + provider=provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async def snapshot_context_manager(messages, trusted_token_usage=0): + return list(messages) + + runner.request_context_manager.process = snapshot_context_manager + + async for _ in runner.step_until_done(1): + pass + + assert provider.call_count == 2 + final_contexts = provider.received_contexts[-1] + assert final_contexts[-1].role == "user" + assert final_contexts[-1].content == runner.MAX_STEPS_REACHED_PROMPT + + +@pytest.mark.asyncio +async def test_tool_loop_next_request_includes_tool_result( + runner, provider_request, mock_tool_executor, mock_hooks +): + """Tool-loop provider contexts must be recomputed after tool results append.""" + provider = CapturingToolLoopProvider("test_tool") + + await runner.reset( + provider=provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async def snapshot_context_manager(messages, trusted_token_usage=0): + return list(messages) + + runner.request_context_manager.process = snapshot_context_manager + + async for _ in runner.step_until_done(3): + pass + + assert provider.call_count == 2 + second_contexts = provider.received_contexts[1] + tool_messages = [msg for msg in second_contexts if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_context_refresh" + assert "工具执行结果" in tool_messages[0].content + + @pytest.mark.asyncio async def test_normal_completion_without_max_step( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks diff --git a/tests/unit/test_group_chat_context_wiring.py b/tests/unit/test_group_chat_context_wiring.py new file mode 100644 index 0000000000..a452434ca3 --- /dev/null +++ b/tests/unit/test_group_chat_context_wiring.py @@ -0,0 +1,120 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.api.message_components import Plain +from astrbot.builtin_stars.astrbot.main import Main + + +def make_main_with_conversation_manager(conv_mgr): + main = Main.__new__(Main) + main.context = MagicMock() + main.context.conversation_manager = conv_mgr + return main + + +def make_event(umo: str = "aiocqhttp:GroupMessage:user_123_group_456"): + event = MagicMock() + event.unified_msg_origin = umo + event.get_platform_id.return_value = "aiocqhttp" + event.message_obj = SimpleNamespace(message=[Plain("hello")]) + event.message_str = "hello" + event.session_id = "session-1" + return event + + +@pytest.mark.asyncio +async def test_active_reply_does_not_create_conversation_when_current_missing(): + conv_mgr = SimpleNamespace( + get_curr_conversation_id=AsyncMock(return_value=None), + new_conversation=AsyncMock(), + get_conversation=AsyncMock(), + ) + main = make_main_with_conversation_manager(conv_mgr) + main.context.get_config.return_value = { + "provider_ltm_settings": { + "group_icl_enable": False, + "active_reply": {"enable": True}, + }, + } + main.context.get_using_provider.return_value = object() + main.group_chat_context = SimpleNamespace( + need_active_reply=AsyncMock(return_value=True), + handle_message=AsyncMock(), + ) + event = make_event() + + results = [item async for item in main.on_message(event)] + + assert results == [] + conv_mgr.get_curr_conversation_id.assert_awaited_once_with(event.unified_msg_origin) + conv_mgr.new_conversation.assert_not_called() + conv_mgr.get_conversation.assert_not_called() + event.request_llm.assert_not_called() + + +@pytest.mark.asyncio +async def test_active_reply_reuses_current_umo_conversation(): + conv = SimpleNamespace(cid="cid-1") + conv_mgr = SimpleNamespace( + get_curr_conversation_id=AsyncMock(return_value="cid-1"), + new_conversation=AsyncMock(), + get_conversation=AsyncMock(return_value=conv), + ) + main = make_main_with_conversation_manager(conv_mgr) + main.context.get_config.return_value = { + "provider_ltm_settings": { + "group_icl_enable": False, + "active_reply": {"enable": True}, + }, + } + main.context.get_using_provider.return_value = object() + main.group_chat_context = SimpleNamespace( + need_active_reply=AsyncMock(return_value=True), + handle_message=AsyncMock(), + ) + event = make_event("aiocqhttp:GroupMessage:user_999_group_456") + llm_request = object() + event.request_llm.return_value = llm_request + + results = [item async for item in main.on_message(event)] + + assert results == [llm_request] + conv_mgr.get_curr_conversation_id.assert_awaited_once_with(event.unified_msg_origin) + conv_mgr.new_conversation.assert_not_called() + conv_mgr.get_conversation.assert_awaited_once_with( + event.unified_msg_origin, + "cid-1", + ) + event.request_llm.assert_called_once_with( + prompt="hello", + session_id="session-1", + image_urls=[], + conversation=conv, + ) + + +@pytest.mark.asyncio +async def test_on_message_does_not_clear_group_context_on_first_enabled_message(): + main = Main.__new__(Main) + main.context = MagicMock() + main.context.get_config.return_value = { + "provider_ltm_settings": { + "group_icl_enable": True, + "active_reply": {"enable": False}, + }, + } + main.group_chat_context = SimpleNamespace( + need_active_reply=AsyncMock(return_value=False), + handle_message=AsyncMock(), + remove_session=AsyncMock(), + ) + event = make_event() + + async for _ in main.on_message(event): + pass + + main.group_chat_context.need_active_reply.assert_awaited_once_with(event) + main.group_chat_context.handle_message.assert_awaited_once_with(event) + main.group_chat_context.remove_session.assert_not_called()