diff --git a/agentrun/memory_collection/memory_conversation.py b/agentrun/memory_collection/memory_conversation.py index 64ab94f..7ee16ee 100644 --- a/agentrun/memory_collection/memory_conversation.py +++ b/agentrun/memory_collection/memory_conversation.py @@ -4,6 +4,7 @@ """ +import asyncio import json import os from typing import ( @@ -76,6 +77,7 @@ def __init__( # 延迟初始化 self._memory_store = None self._ots_client = None + self._init_lock = asyncio.Lock() @staticmethod def _default_user_id_extractor(req: Any) -> str: @@ -153,10 +155,18 @@ def _default_agent_id_extractor(req: Any) -> str: return "default_agent" async def _get_memory_store(self): - """获取或创建 AsyncMemoryStore 实例""" + """获取或创建 AsyncMemoryStore 实例(双检锁,并发安全)""" if self._memory_store is not None: return self._memory_store + async with self._init_lock: + # 拿到锁后再检查一次,防止并发请求重复初始化 + if self._memory_store is not None: + return self._memory_store + return await self._init_memory_store() + + async def _init_memory_store(self): + """内部初始化方法,由 _get_memory_store 在持锁状态下调用""" try: # 导入依赖 from tablestore_for_agent_memory.base.base_memory_store import ( @@ -228,7 +238,7 @@ async def _get_memory_store(self): ) await self._memory_store.init_table() await self._memory_store.init_search_index() - logger.info(f"Tables and indexes initialized successfully") + logger.info("Tables and indexes initialized successfully") except Exception as e: # 如果表已存在,会抛出异常,这是正常的 logger.info( @@ -384,10 +394,13 @@ async def wrap_invoke_agent( metadata={"agent_id": agent_id}, ) - try: - await memory_store.put_session(session) - except Exception as e: - logger.error(f"Failed to save session: {e}", exc_info=True) + async def _put_session_bg(): + try: + await memory_store.put_session(session) + except Exception as e: + logger.error(f"Failed to save session: {e}", exc_info=True) + + asyncio.create_task(_put_session_bg()) # 构建输入消息列表(包含所有历史消息) input_messages = [] @@ -465,57 +478,62 @@ async def wrap_invoke_agent( yield event # 保存完整的对话轮次(输入 + 输出) - # 只有当有文本内容或工具调用时才保存 + # 使用 fire-and-forget 避免阻塞流式响应关闭 if agent_response_content or tool_calls or tool_results: - try: - # 构建助手响应消息 - assistant_message: Dict[str, Any] = { - "role": "assistant", - } - - # 添加文本内容(如果有) - if agent_response_content: - assistant_message["content"] = agent_response_content - else: - # OpenAI 格式要求:如果有 tool_calls,content 可以为 null - assistant_message["content"] = None - - # 添加工具调用(如果有) - if tool_calls: - assistant_message["tool_calls"] = list( - tool_calls.values() + # 构建助手响应消息 + assistant_message: Dict[str, Any] = { + "role": "assistant", + } + + if agent_response_content: + assistant_message["content"] = agent_response_content + else: + assistant_message["content"] = None + + if tool_calls: + assistant_message["tool_calls"] = list(tool_calls.values()) + + output_messages = input_messages + [assistant_message] + + if tool_results: + output_messages.extend(tool_results) + + conversation_message = Message( + session_id=session_id, + message_id=f"msg_{uuid.uuid4().hex[:16]}", + content=json.dumps(output_messages, ensure_ascii=False), + ) + + async def _save_conversation_bg( + ms=memory_store, + msg=conversation_message, + sess=session, + n_msgs=len(output_messages), + text_len=len(agent_response_content), + n_tc=len(tool_calls), + n_tr=len(tool_results), + ): + try: + await ms.put_message(msg) + sess.update_time = microseconds_timestamp() + await ms.update_session(sess) + logger.debug( + "Saved conversation: %d messages," + " text length: %d chars," + " tool_calls: %d, tool_results: %d", + n_msgs, + text_len, + n_tc, + n_tr, + ) + except Exception as e: + logger.error( + "Failed to save conversation: %s", + e, + exc_info=True, ) - # 构建完整的消息列表 - output_messages = input_messages + [assistant_message] - - # 添加工具执行结果(如果有) - if tool_results: - output_messages.extend(tool_results) - - # 将完整的对话历史存储为一条消息 - # content 字段存储 JSON 格式的消息列表 - conversation_message = Message( - session_id=session_id, - message_id=f"msg_{uuid.uuid4().hex[:16]}", - content=json.dumps(output_messages, ensure_ascii=False), - ) - await memory_store.put_message(conversation_message) - - # 更新 Session 时间 - session.update_time = microseconds_timestamp() - await memory_store.update_session(session) - - logger.debug( - f"Saved conversation: {len(output_messages)} messages," - f" text length: {len(agent_response_content)} chars," - f" tool_calls: {len(tool_calls)}, tool_results:" - f" {len(tool_results)}" - ) - except Exception as e: - logger.error( - f"Failed to save conversation: {e}", exc_info=True - ) + asyncio.create_task(_save_conversation_bg()) except Exception as e: logger.error(f"Error in agent handler: {e}", exc_info=True) diff --git a/tests/unittests/memory_collection/test_memory_conversation.py b/tests/unittests/memory_collection/test_memory_conversation.py index 20b2592..5763bd6 100644 --- a/tests/unittests/memory_collection/test_memory_conversation.py +++ b/tests/unittests/memory_collection/test_memory_conversation.py @@ -1,5 +1,6 @@ """Tests for AgentRun Memory Conversation / AgentRun 记忆对话测试""" +import asyncio from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -8,6 +9,11 @@ from agentrun.server.model import AgentRequest, Message, MessageRole +async def _flush_bg_tasks(): + """Let fire-and-forget background tasks complete before assertions.""" + await asyncio.sleep(0.05) + + @pytest.fixture def mock_memory_collection(): """Mock MemoryCollection""" @@ -185,6 +191,9 @@ async def mock_agent(request: AgentRequest): # Verify results assert results == ["Hello", ", ", "world!"] + # Wait for fire-and-forget background tasks to complete + await _flush_bg_tasks() + # Verify memory store calls assert mock_memory_store.put_session.called assert mock_memory_store.put_message.called @@ -252,6 +261,9 @@ async def mock_agent(request: AgentRequest): async for event in memory.wrap_invoke_agent(request, mock_agent): results.append(event) + # Wait for fire-and-forget background tasks to complete + await _flush_bg_tasks() + # Verify agent still responds assert results == ["Still works!"] @@ -339,6 +351,9 @@ async def mock_agent(request: AgentRequest): assert results[0] == "Let me search for that..." assert results[3] == "Based on the search, it's sunny today." + # Wait for fire-and-forget background tasks to complete + await _flush_bg_tasks() + # Verify message was saved with tool calls assert mock_memory_store.put_message.called saved_message = mock_memory_store.put_message.call_args[0][0] @@ -437,6 +452,9 @@ async def mock_agent(request: AgentRequest): # Verify all events were passed through assert len(results) == 4 + # Wait for fire-and-forget background tasks to complete + await _flush_bg_tasks() + # Verify message was saved with accumulated tool call assert mock_memory_store.put_message.called saved_message = mock_memory_store.put_message.call_args[0][0]