From b4eaa7a9e089e98a57365952af06ae148bfbe81e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AF=92=E5=85=89?= <2510399607@qq.com> Date: Thu, 9 Apr 2026 23:01:05 +0800 Subject: [PATCH 1/3] feat(state): Implement LangGraph checkpoint management with OTS integration - Added checkpoint, checkpoint_writes, and checkpoint_blobs tables to support LangGraph functionality. - Introduced asynchronous methods for initializing and managing checkpoint tables in OTSBackend and SessionStore. - Enhanced SessionStore with methods for checkpoint CRUD operations. - Updated README and conversation_design.md to document new checkpoint features and usage examples. - Refactored utils to build OTS clients independently of code generation templates. This update enables persistent storage of LangGraph checkpoints, enhancing the overall functionality of the conversation service. --- .gitignore | 9 +- agentrun/conversation_service/README.md | 80 +- .../__ots_backend_async_template.py | 541 ++++++++ .../__session_store_async_template.py | 159 ++- .../conversation_service/adapters/__init__.py | 9 + .../adapters/langgraph_adapter.py | 718 ++++++++++ .../conversation_design.md | 134 +- agentrun/conversation_service/model.py | 5 + agentrun/conversation_service/ots_backend.py | 1194 ++++++++++++++++- .../conversation_service/session_store.py | 328 ++++- agentrun/conversation_service/utils.py | 37 + agentrun/integration/utils/tool.py | 4 +- agentrun/utils/log.py | 65 +- examples/conversation_service.md | 722 ++++++++++ .../conversation_service_langchain_server.py | 181 +++ .../conversation_service_langgraph_server.py | 203 +++ .../test_langchain_agui_integration.py | 8 +- tests/unittests/toolset/api/test_openapi.py | 24 +- 18 files changed, 4266 insertions(+), 155 deletions(-) create mode 100644 agentrun/conversation_service/adapters/langgraph_adapter.py create mode 100644 examples/conversation_service.md create mode 100644 examples/conversation_service_langchain_server.py create mode 100644 
examples/conversation_service_langgraph_server.py diff --git a/.gitignore b/.gitignore index 7a2da69..096c88e 100644 --- a/.gitignore +++ b/.gitignore @@ -106,11 +106,4 @@ uv.lock coverage.json coverage.json -# examples -examples/conversation_service_adk_example.py -examples/conversation_service_adk_data.py -examples/conversation_service_langchain_example.py -examples/conversation_service_langchain_data.py -examples/conversation_service_verify.py -examples/Langchain_His_example.py -examples/agent-quickstart-langchain/ \ No newline at end of file +local diff --git a/agentrun/conversation_service/README.md b/agentrun/conversation_service/README.md index da51c0f..1ec6b16 100644 --- a/agentrun/conversation_service/README.md +++ b/agentrun/conversation_service/README.md @@ -11,7 +11,7 @@ ADK Agent ──→ OTSSessionService ──┐ │ ┌─────────────┐ ┌─────────┐ LangChain ──→ OTSChatMessageHistory ──→│ SessionStore │───→│ OTS │ │ │ (业务逻辑层) │───→│ Tables │ -LangGraph ──→ (LG Adapter) ─────┘ └─────────────┘ └─────────┘ +LangGraph ──→ OTSCheckpointSaver ─┘ └─────────────┘ └─────────┘ │ OTSBackend (存储操作层) @@ -78,9 +78,11 @@ store.init_tables() | `init_core_tables()` | Conversation + Event + 二级索引 | 所有框架 | | `init_state_tables()` | State + App_state + User_state | ADK 三级 State | | `init_search_index()` | 多元索引(conversation_search_index) | 需要搜索/过滤 | -| `init_tables()` | 以上全部 | 快速开发 | +| `init_checkpoint_tables()` | checkpoint + checkpoint_writes + checkpoint_blobs | LangGraph | +| `init_tables()` | 核心表 + State 表 + 多元索引(不含 checkpoint 表) | 快速开发 | > 多元索引创建耗时较长(数秒级),建议与核心表创建分离,不阻塞核心流程。 +> checkpoint 表仅在使用 LangGraph 时需要,需显式调用 `init_checkpoint_tables()`。 ## 使用示例 @@ -143,6 +145,57 @@ for msg in history.messages: print(f"{msg.type}: {msg.content}") ``` +### LangGraph 集成 + +```python +import asyncio +from langgraph.graph import StateGraph, START, END +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSCheckpointSaver + +# 初始化 +store = 
SessionStore.from_memory_collection("my-collection") +store.init_core_tables() # conversation 表(会话同步需要) +store.init_checkpoint_tables() # checkpoint 相关表 + +# 创建 checkpointer(指定 agent_id 后自动同步 conversation 记录) +checkpointer = OTSCheckpointSaver( + store, agent_id="my_agent", user_id="default_user" +) + +# 构建 Graph +graph = StateGraph(MyState) +graph.add_node("step", my_node) +graph.add_edge(START, "step") +graph.add_edge("step", END) +app = graph.compile(checkpointer=checkpointer) + +# 对话(自动持久化 checkpoint 到 OTS + 同步 conversation 记录) +async def chat(): + config = { + "configurable": {"thread_id": "thread-1"}, + "metadata": {"user_id": "user_1"}, # 可选,覆盖默认 user_id + } + result = await app.ainvoke({"messages": [...]}, config=config) + # 再次调用同一 thread_id 会自动恢复状态 + result2 = await app.ainvoke({"messages": [...]}, config=config) + +asyncio.run(chat()) +``` + +> **会话同步**:指定 `agent_id` 后,每次 `put()` 会自动在 conversation 表中创建/更新会话记录(`session_id = thread_id`,`framework = "langgraph"`)。这使得外部服务可以通过 `agent_id / user_id` 查询到 LangGraph 的所有会话。 + +### 跨语言查询 LangGraph 状态 + +外部服务(如 Go 后端)可直接通过 OTS SDK 查询 LangGraph 会话状态: + +1. **列出会话**:查询 conversation 表(按 `agent_id/user_id`,过滤 `framework = "langgraph"`) +2. **读取最新 checkpoint**:用 `session_id`(即 `thread_id`)查询 checkpoint 表(GetRange BACKWARD limit=1) +3. **解析数据**:`checkpoint_data` 和 `blob_data` 为 `base64(msgpack)` 格式,Go 使用 msgpack 库(如 `github.com/vmihailenco/msgpack/v5`)解码 +4. 
**注意**:对于包含 LangChain 对象(HumanMessage 等)的 blob,msgpack 中包含 ext type,需要自定义 decoder 提取 kwargs + +详细序列化格式说明和 Go 伪代码见 [conversation_design.md](./conversation_design.md#跨语言查询-checkpoint-状态)。 + ### 直接使用 SessionStore ```python @@ -195,10 +248,11 @@ store.delete_session("agent_1", "user_1", "sess_1") | 方法 | 说明 | |------|------| -| `init_tables()` | 创建所有表和索引 | +| `init_tables()` | 创建所有表和索引(不含 checkpoint) | | `init_core_tables()` | 创建核心表 + 二级索引 | | `init_state_tables()` | 创建三张 State 表 | | `init_search_index()` | 创建多元索引 | +| `init_checkpoint_tables()` | 创建 LangGraph checkpoint 表 | **Session 管理** @@ -230,12 +284,26 @@ store.delete_session("agent_1", "user_1", "sess_1") | `get_user_state / update_user_state` | 用户级状态读写 | | `get_merged_state(agent_id, user_id, session_id)` | 三级状态浅合并 | +**Checkpoint 管理(LangGraph)** + +| 方法 | 说明 | +|------|------| +| `put_checkpoint(thread_id, checkpoint_ns, checkpoint_id, ...)` | 写入 checkpoint | +| `get_checkpoint(thread_id, checkpoint_ns, checkpoint_id)` | 读取 checkpoint | +| `list_checkpoints(thread_id, checkpoint_ns, *, limit, before)` | 列出 checkpoint | +| `put_checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id, writes)` | 批量写入 writes | +| `get_checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id)` | 读取 writes | +| `put_checkpoint_blob(thread_id, checkpoint_ns, channel, version, ...)` | 写入 blob | +| `get_checkpoint_blobs(thread_id, checkpoint_ns, channel_versions)` | 批量读取 blobs | +| `delete_thread_checkpoints(thread_id)` | 删除 thread 全部 checkpoint 数据 | + ### 框架适配器 | 适配器 | 框架 | 基类 | |--------|------|------| | `OTSSessionService` | Google ADK | `BaseSessionService` | | `OTSChatMessageHistory` | LangChain | `BaseChatMessageHistory` | +| `OTSCheckpointSaver` | LangGraph | `BaseCheckpointSaver` | ### 领域模型 @@ -248,7 +316,7 @@ store.delete_session("agent_1", "user_1", "sess_1") ## OTS 表结构 -共五张表 + 一个二级索引 + 一个多元索引: +共八张表 + 一个二级索引 + 两个多元索引: | 表名 | 主键 | 用途 | |------|------|------| @@ -257,6 +325,9 @@ store.delete_session("agent_1", "user_1", 
"sess_1") | `state` | agent_id, user_id, session_id | 会话级状态 | | `app_state` | agent_id | 应用级状态 | | `user_state` | agent_id, user_id | 用户级状态 | +| `checkpoint` | thread_id, checkpoint_ns, checkpoint_id | LangGraph checkpoint | +| `checkpoint_writes` | thread_id, checkpoint_ns, checkpoint_id, task_idx | LangGraph 中间写入 | +| `checkpoint_blobs` | thread_id, checkpoint_ns, channel, version | LangGraph 通道值快照 | | `conversation_secondary_index` | agent_id, user_id, updated_at, session_id | 二级索引(list 热路径) | | `conversation_search_index` | 多元索引 | 全文搜索 / 标签过滤 / 组合查询 | @@ -271,6 +342,7 @@ store.delete_session("agent_1", "user_1", "sess_1") | [`conversation_service_adk_data.py`](../../examples/conversation_service_adk_data.py) | ADK 模拟数据填充 + 多元索引搜索验证 | | [`conversation_service_langchain_example.py`](../../examples/conversation_service_langchain_example.py) | LangChain 消息历史读写验证 | | [`conversation_service_langchain_data.py`](../../examples/conversation_service_langchain_data.py) | LangChain 模拟数据填充 | +| [`conversation_service_langgraph_example.py`](../../examples/conversation_service_langgraph_example.py) | LangGraph checkpoint 持久化示例 | | [`conversation_service_verify.py`](../../examples/conversation_service_verify.py) | 端到端 CRUD 验证脚本 | ## 环境变量 diff --git a/agentrun/conversation_service/__ots_backend_async_template.py b/agentrun/conversation_service/__ots_backend_async_template.py index a2bca0c..5bd9a13 100644 --- a/agentrun/conversation_service/__ots_backend_async_template.py +++ b/agentrun/conversation_service/__ots_backend_async_template.py @@ -38,6 +38,9 @@ ConversationEvent, ConversationSession, DEFAULT_APP_STATE_TABLE, + DEFAULT_CHECKPOINT_BLOBS_TABLE, + DEFAULT_CHECKPOINT_TABLE, + DEFAULT_CHECKPOINT_WRITES_TABLE, DEFAULT_CONVERSATION_SEARCH_INDEX, DEFAULT_CONVERSATION_SECONDARY_INDEX, DEFAULT_CONVERSATION_TABLE, @@ -100,6 +103,15 @@ def __init__( ) self._state_search_index = f"{table_prefix}{DEFAULT_STATE_SEARCH_INDEX}" + # LangGraph checkpoint 表 + self._checkpoint_table = 
f"{table_prefix}{DEFAULT_CHECKPOINT_TABLE}" + self._checkpoint_writes_table = ( + f"{table_prefix}{DEFAULT_CHECKPOINT_WRITES_TABLE}" + ) + self._checkpoint_blobs_table = ( + f"{table_prefix}{DEFAULT_CHECKPOINT_BLOBS_TABLE}" + ) + # ----------------------------------------------------------------------- # 建表(异步)/ Table creation (async) # ----------------------------------------------------------------------- @@ -162,6 +174,116 @@ async def init_search_index_async(self) -> None: await self._create_conversation_search_index_async() await self._create_state_search_index_async() + async def init_checkpoint_tables_async(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(异步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + await self._create_checkpoint_table_async() + await self._create_checkpoint_writes_table_async() + await self._create_checkpoint_blobs_table_async() + + async def _create_checkpoint_table_async(self) -> None: + """创建 checkpoint 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), checkpoint_id (STRING) + """ + table_meta = TableMeta( + self._checkpoint_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_table, + ) + else: + raise + + async def _create_checkpoint_writes_table_async(self) -> None: + """创建 checkpoint_writes 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), + checkpoint_id (STRING), task_idx (STRING) + """ + table_meta = TableMeta( + self._checkpoint_writes_table, + [ + 
("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ("task_idx", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_writes_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_writes_table, + ) + else: + raise + + async def _create_checkpoint_blobs_table_async(self) -> None: + """创建 checkpoint_blobs 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), + channel (STRING), version (STRING) + """ + table_meta = TableMeta( + self._checkpoint_blobs_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("channel", "STRING"), + ("version", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_blobs_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_blobs_table, + ) + else: + raise + async def _create_conversation_table_async(self) -> None: """创建 Conversation 表 + 二级索引(异步)。""" table_meta = TableMeta( @@ -1205,6 +1327,425 @@ async def delete_state_row_async( condition = Condition(RowExistenceExpectation.IGNORE) await self._async_client.delete_row(table_name, row, condition) + # ----------------------------------------------------------------------- + # Checkpoint CRUD(LangGraph)(异步) + # 
----------------------------------------------------------------------- + + async def put_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(异步)。""" + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + attribute_columns = [ + ("checkpoint_type", checkpoint_type), + ("checkpoint_data", checkpoint_data), + ("metadata", metadata_json), + ("parent_checkpoint_id", parent_checkpoint_id), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + await self._async_client.put_row(self._checkpoint_table, row, condition) + + async def get_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(异步)。 + + 若 checkpoint_id 为 None,使用 GetRange 获取最新的(按 checkpoint_id 倒序)。 + + Returns: + 包含 checkpoint 字段的字典,或 None。 + """ + if checkpoint_id is not None: + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + _, row, _ = await self._async_client.get_row( + self._checkpoint_table, primary_key, max_version=1 + ) + if row is None or row.primary_key is None: + return None + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + # checkpoint_id 为 None -> 取最新 + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MAX), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + _, _, rows, _ = await self._async_client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + inclusive_start, + exclusive_end, + max_version=1, + limit=1, + ) + if 
not rows: + return None + row = rows[0] + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + async def list_checkpoints_async( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """按 checkpoint_id 倒序列出 checkpoint(异步)。 + + Args: + thread_id: 线程 ID。 + checkpoint_ns: checkpoint 命名空间。 + limit: 最多返回条数。 + before: 仅返回 checkpoint_id < before 的记录。 + """ + if before is not None: + start_id: Any = before + else: + start_id = INF_MAX + + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", start_id), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while len(results) < limit: + _, next_token, rows, _ = await self._async_client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + next_start, + exclusive_end, + max_version=1, + limit=limit - len(results), + ) + + for row in rows: + pk = self._pk_to_dict(row.primary_key) + # 如果 before 指定了精确值,跳过它本身 + if before is not None and pk.get("checkpoint_id") == before: + continue + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + if len(results) >= limit: + break + + if next_token is None: + break + next_start = next_token + + return results + + async def put_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(异步)。 + + Args: + writes: 每个元素是 dict,包含 task_idx, task_id, task_path, + channel, value_type, value_data 字段。 + """ + if not writes: + return + + for i in range(0, len(writes), _BATCH_WRITE_LIMIT): + batch = writes[i : i + _BATCH_WRITE_LIMIT] + from tablestore import PutRowItem # type: ignore[import-untyped] + + put_items = [] + for 
w in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", w["task_idx"]), + ] + attrs = [ + ("task_id", w["task_id"]), + ("task_path", w.get("task_path", "")), + ("channel", w["channel"]), + ("value_type", w["value_type"]), + ("value_data", w["value_data"]), + ] + row = Row(pk, attrs) + condition = Condition(RowExistenceExpectation.IGNORE) + put_items.append(PutRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem( + self._checkpoint_writes_table, put_items + ) + ) + await self._async_client.batch_write_row(request) + + async def get_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(异步)。""" + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MIN), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MAX), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = await self._async_client.get_range( + self._checkpoint_writes_table, + Direction.FORWARD, + next_start, + exclusive_end, + max_version=1, + ) + for row in rows: + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + + if next_token is None: + break + next_start = next_token + + return results + + async def put_checkpoint_blob_async( + self, + thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(异步)。""" + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", channel), + ("version", version), + ] + attribute_columns = [ + ("blob_type", 
blob_type), + ("blob_data", blob_data), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + await self._async_client.put_row( + self._checkpoint_blobs_table, row, condition + ) + + async def get_checkpoint_blobs_async( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, dict[str, str]]: + """批量读取 checkpoint blobs(异步)。 + + Args: + channel_versions: {channel: version} 映射。 + + Returns: + {channel: {"blob_type": ..., "blob_data": ...}} 映射。 + """ + if not channel_versions: + return {} + + from tablestore import ( # type: ignore[import-untyped] + BatchGetRowRequest, + TableInBatchGetRowItem, + ) + + results: dict[str, dict[str, str]] = {} + items = list(channel_versions.items()) + + # OTS BatchGetRow 每次最多 100 行 + batch_limit = 100 + for i in range(0, len(items), batch_limit): + batch = items[i : i + batch_limit] + rows_to_get = [] + for ch, ver in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", ch), + ("version", str(ver)), + ] + rows_to_get.append(pk) + + table_item = TableInBatchGetRowItem( + self._checkpoint_blobs_table, + rows_to_get, + max_version=1, + ) + request = BatchGetRowRequest() + request.add(table_item) + response = await self._async_client.batch_get_row(request) + + table_results = response.get_result_by_table( + self._checkpoint_blobs_table + ) + for item in table_results: + if not item.is_ok or item.row is None: + continue + pk = self._pk_to_dict(item.row.primary_key) + attrs = self._attrs_to_dict(item.row.attribute_columns) + channel_name = pk.get("channel", "") + results[channel_name] = { + "blob_type": attrs.get("blob_type", ""), + "blob_data": attrs.get("blob_data", ""), + } + + return results + + async def delete_thread_checkpoints_async( + self, + thread_id: str, + ) -> None: + """删除指定 thread_id 的所有 checkpoint 相关数据(异步)。 + + 扫描并删除 checkpoint、checkpoint_writes、checkpoint_blobs 三张表中 + 所有以 thread_id 
为分区键的行。 + """ + await self._scan_and_delete_async( + self._checkpoint_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ], + ) + await self._scan_and_delete_async( + self._checkpoint_writes_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + ("task_idx", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ("task_idx", INF_MAX), + ], + ) + await self._scan_and_delete_async( + self._checkpoint_blobs_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("channel", INF_MIN), + ("version", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("channel", INF_MAX), + ("version", INF_MAX), + ], + ) + + async def _scan_and_delete_async( + self, + table_name: str, + inclusive_start: list[Any], + exclusive_end: list[Any], + ) -> None: + """通用扫描删除:GetRange 扫描 PK 后 BatchWriteRow 删除(异步)。""" + all_pks: list[Any] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = await self._async_client.get_range( + table_name, + Direction.FORWARD, + next_start, + exclusive_end, + columns_to_get=[], + max_version=1, + ) + for row in rows: + all_pks.append(row.primary_key) + if next_token is None: + break + next_start = next_token + + if not all_pks: + return + + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add(TableInBatchWriteRowItem(table_name, delete_items)) + await self._async_client.batch_write_row(request) + # ----------------------------------------------------------------------- # 内部辅助方法(I/O 相关,异步) # 
----------------------------------------------------------------------- diff --git a/agentrun/conversation_service/__session_store_async_template.py b/agentrun/conversation_service/__session_store_async_template.py index 81ae8ca..5f17f86 100644 --- a/agentrun/conversation_service/__session_store_async_template.py +++ b/agentrun/conversation_service/__session_store_async_template.py @@ -56,6 +56,148 @@ async def init_search_index_async(self) -> None: """ await self._backend.init_search_index_async() + async def init_checkpoint_tables_async(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(异步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + await self._backend.init_checkpoint_tables_async() + + async def init_langchain_tables_async(self) -> None: + """创建 LangChain 所需的全部表和索引(异步)。 + + 包含核心表(Conversation + Event + 二级索引)和多元索引。 + 表或索引已存在时跳过,可重复调用。 + """ + await self._backend.init_core_tables_async() + await self._backend.init_search_index_async() + + async def init_langgraph_tables_async(self) -> None: + """创建 LangGraph 所需的全部表和索引(异步)。 + + 包含核心表(Conversation + Event + 二级索引)、多元索引 + 以及 checkpoint 相关的 3 张表(checkpoint / checkpoint_writes / checkpoint_blobs)。 + 表或索引已存在时跳过,可重复调用。 + """ + await self._backend.init_core_tables_async() + await self._backend.init_search_index_async() + await self._backend.init_checkpoint_tables_async() + + # ------------------------------------------------------------------- + # Checkpoint 管理(LangGraph)(异步) + # ------------------------------------------------------------------- + + async def put_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(异步)。""" + await self._backend.put_checkpoint_async( + thread_id, + checkpoint_ns, + checkpoint_id, + checkpoint_type=checkpoint_type, + checkpoint_data=checkpoint_data, + 
metadata_json=metadata_json, + parent_checkpoint_id=parent_checkpoint_id, + ) + + async def get_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(异步)。 + + 若 checkpoint_id 为 None,返回最新的 checkpoint。 + """ + return await self._backend.get_checkpoint_async( + thread_id, checkpoint_ns, checkpoint_id + ) + + async def list_checkpoints_async( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """列出 checkpoint(按 checkpoint_id 倒序)(异步)。""" + return await self._backend.list_checkpoints_async( + thread_id, checkpoint_ns, limit=limit, before=before + ) + + async def put_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(异步)。""" + await self._backend.put_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id, writes + ) + + async def get_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(异步)。""" + return await self._backend.get_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id + ) + + async def put_checkpoint_blob_async( + self, + thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(异步)。""" + await self._backend.put_checkpoint_blob_async( + thread_id, + checkpoint_ns, + channel, + version, + blob_type=blob_type, + blob_data=blob_data, + ) + + async def get_checkpoint_blobs_async( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, dict[str, str]]: + """批量读取 checkpoint blobs(异步)。""" + return await self._backend.get_checkpoint_blobs_async( + thread_id, checkpoint_ns, channel_versions + ) + + async 
def delete_thread_checkpoints_async( + self, + thread_id: str, + ) -> None: + """删除指定 thread 的所有 checkpoint 相关数据(异步)。""" + await self._backend.delete_thread_checkpoints_async(thread_id) + # ------------------------------------------------------------------- # Session 管理(异步)/ Session management (async) # ------------------------------------------------------------------- @@ -695,11 +837,8 @@ async def from_memory_collection_async( "agentrun 主包未安装。请先安装: pip install agentrun" ) from e - from tablestore import AsyncOTSClient # type: ignore[import-untyped] - from tablestore import OTSClient # type: ignore[import-untyped] - from tablestore import WriteRetryPolicy - from agentrun.conversation_service.utils import ( + build_ots_clients, convert_vpc_endpoint_to_public, ) @@ -745,21 +884,13 @@ async def from_memory_collection_async( sts_token = security_token if security_token else None # 4. 构建 OTSClient + AsyncOTSClient 和 OTSBackend - ots_client = OTSClient( - endpoint, - access_key_id, - access_key_secret, - instance_name, - sts_token=sts_token, - retry_policy=WriteRetryPolicy(), - ) - async_ots_client = AsyncOTSClient( + # 使用 utils.build_ots_clients 避免 codegen 替换 AsyncOTSClient + ots_client, async_ots_client = build_ots_clients( endpoint, access_key_id, access_key_secret, instance_name, sts_token=sts_token, - retry_policy=WriteRetryPolicy(), ) backend = OTSBackend( diff --git a/agentrun/conversation_service/adapters/__init__.py b/agentrun/conversation_service/adapters/__init__.py index a67d248..eff597b 100644 --- a/agentrun/conversation_service/adapters/__init__.py +++ b/agentrun/conversation_service/adapters/__init__.py @@ -15,7 +15,16 @@ except ImportError: pass +# LangGraph adapter 依赖 langgraph,仅在安装了 langgraph 时可用 +try: + from agentrun.conversation_service.adapters.langgraph_adapter import ( + OTSCheckpointSaver, + ) +except ImportError: + pass + __all__ = [ "OTSChatMessageHistory", "OTSSessionService", + "OTSCheckpointSaver", ] diff --git 
a/agentrun/conversation_service/adapters/langgraph_adapter.py b/agentrun/conversation_service/adapters/langgraph_adapter.py new file mode 100644 index 0000000..fc02f10 --- /dev/null +++ b/agentrun/conversation_service/adapters/langgraph_adapter.py @@ -0,0 +1,718 @@ +"""LangGraph BaseCheckpointSaver 适配器。 + +基于 OTS 的 LangGraph checkpoint 持久化实现, +通过 SessionStore 层访问 checkpoint/checkpoint_writes/checkpoint_blobs 三张表。 +""" + +from __future__ import annotations + +import base64 +from collections.abc import AsyncIterator, Iterator, Sequence +import json +import logging +import random +from typing import Any, Optional + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + get_checkpoint_id, + get_checkpoint_metadata, + WRITES_IDX_MAP, +) + +from agentrun.conversation_service.session_store import SessionStore + +logger = logging.getLogger(__name__) + + +def _b64_encode(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +def _b64_decode(data: str) -> bytes: + return base64.b64decode(data) + + +class OTSCheckpointSaver(BaseCheckpointSaver[str]): + """基于 OTS 的 LangGraph checkpoint saver。 + + 将 LangGraph 的 checkpoint 数据持久化到阿里云 TableStore, + 遵循三层架构通过 SessionStore 访问底层存储。 + + 当指定 ``agent_id`` 时,每次 ``put()`` 会自动在 conversation 表中 + 创建/更新会话记录(``session_id = thread_id``,``framework = "langgraph"``), + 使得外部服务可以通过 ``agent_id / user_id`` 查询到 LangGraph 会话列表。 + + Args: + session_store: SessionStore 实例(需已完成 init_checkpoint_tables)。 + agent_id: 智能体 ID。设置后 put 时自动同步 conversation 记录。 + user_id: 默认用户 ID。可通过 ``config["metadata"]["user_id"]`` + 在每次调用时覆盖(优先级更高)。 + + Example:: + + store = await SessionStore.from_memory_collection_async("my-mc") + await store.init_checkpoint_tables_async() + checkpointer = OTSCheckpointSaver( + store, agent_id="my_agent", user_id="default_user" + ) + # 传入 LangGraph 的 
StateGraph.compile(checkpointer=checkpointer) + """ + + store: SessionStore + agent_id: str + user_id: str + + def __init__( + self, + session_store: SessionStore, + *, + agent_id: str = "", + user_id: str = "", + ) -> None: + super().__init__() + self.store = session_store + self.agent_id = agent_id + self.user_id = user_id + + # ------------------------------------------------------------------ + # Version + # ------------------------------------------------------------------ + + def get_next_version(self, current: str | None, channel: None) -> str: + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = random.random() + return f"{next_v:032}.{next_h:016}" + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _dump_typed_b64(self, value: Any) -> tuple[str, str]: + """序列化值并 base64 编码 bytes 部分。""" + type_str, data = self.serde.dumps_typed(value) + return type_str, _b64_encode(data) + + def _load_typed_b64(self, type_str: str, data_b64: str) -> Any: + """从 base64 编码的字符串反序列化值。""" + data = _b64_decode(data_b64) + return self.serde.loads_typed((type_str, data)) + + def _load_blobs( + self, + blob_map: dict[str, dict[str, str]], + ) -> dict[str, Any]: + """从 blob 数据重建 channel_values。""" + channel_values: dict[str, Any] = {} + for channel, blob_info in blob_map.items(): + blob_type = blob_info.get("blob_type", "") + blob_data = blob_info.get("blob_data", "") + if blob_type and blob_type != "empty": + channel_values[channel] = self._load_typed_b64( + blob_type, blob_data + ) + return channel_values + + def _build_checkpoint_tuple( + self, + thread_id: str, + checkpoint_ns: str, + row: dict[str, Any], + blob_map: dict[str, dict[str, str]], + writes_rows: list[dict[str, Any]], + config: Optional[RunnableConfig] = None, + ) -> CheckpointTuple: + 
"""从存储行数据构建 CheckpointTuple。""" + checkpoint_id = row["checkpoint_id"] + checkpoint_type = row.get("checkpoint_type", "") + checkpoint_data = row.get("checkpoint_data", "") + metadata_json = row.get("metadata", "{}") + parent_checkpoint_id = row.get("parent_checkpoint_id", "") + + checkpoint: Checkpoint = self._load_typed_b64( + checkpoint_type, checkpoint_data + ) + checkpoint["channel_values"] = self._load_blobs(blob_map) + + metadata: CheckpointMetadata = json.loads(metadata_json) + + pending_writes: list[tuple[str, str, Any]] = [] + for w in writes_rows: + task_id = w.get("task_id", "") + channel = w.get("channel", "") + value_type = w.get("value_type", "") + value_data = w.get("value_data", "") + if value_type: + value = self._load_typed_b64(value_type, value_data) + else: + value = None + pending_writes.append((task_id, channel, value)) + + result_config: RunnableConfig = config or { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + parent_config: Optional[RunnableConfig] = None + if parent_checkpoint_id: + parent_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": parent_checkpoint_id, + } + } + + return CheckpointTuple( + config=result_config, + checkpoint=checkpoint, + metadata=metadata, + parent_config=parent_config, + pending_writes=pending_writes if pending_writes else None, + ) + + # ------------------------------------------------------------------ + # Session sync helpers + # ------------------------------------------------------------------ + + def _resolve_user_id(self, config: RunnableConfig) -> str: + """从 config.metadata 或构造器参数中提取 user_id。 + + 优先级:config["metadata"]["user_id"] > self.user_id > "default" + """ + md = config.get("metadata") or {} + if isinstance(md, dict): + uid = md.get("user_id") + if uid: + return str(uid) + return self.user_id or "default" + + def _sync_session(self, thread_id: str, user_id: 
str) -> None: + """同步创建/更新 conversation 表中的会话记录(同步)。""" + if not self.agent_id: + return + try: + existing = self.store.get_session(self.agent_id, user_id, thread_id) + if existing is None: + self.store.create_session( + agent_id=self.agent_id, + user_id=user_id, + session_id=thread_id, + framework="langgraph", + ) + else: + self.store.update_session( + self.agent_id, + user_id, + thread_id, + version=existing.version, + ) + except Exception: + logger.warning( + "Failed to sync conversation record for " + "agent_id=%s, user_id=%s, thread_id=%s", + self.agent_id, + user_id, + thread_id, + exc_info=True, + ) + + async def _sync_session_async(self, thread_id: str, user_id: str) -> None: + """同步创建/更新 conversation 表中的会话记录(异步)。""" + if not self.agent_id: + return + try: + existing = await self.store.get_session_async( + self.agent_id, user_id, thread_id + ) + if existing is None: + await self.store.create_session_async( + agent_id=self.agent_id, + user_id=user_id, + session_id=thread_id, + framework="langgraph", + ) + else: + await self.store.update_session_async( + self.agent_id, + user_id, + thread_id, + version=existing.version, + ) + except Exception: + logger.warning( + "Failed to sync conversation record for " + "agent_id=%s, user_id=%s, thread_id=%s", + self.agent_id, + user_id, + thread_id, + exc_info=True, + ) + + # ------------------------------------------------------------------ + # Core: get_tuple + # ------------------------------------------------------------------ + + def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = get_checkpoint_id(config) + + row = self.store.get_checkpoint(thread_id, checkpoint_ns, checkpoint_id) + if row is None: + return None + + actual_id = row["checkpoint_id"] + checkpoint_type = row.get("checkpoint_type", "") + checkpoint_data = row.get("checkpoint_data", "") + cp: 
Checkpoint = self._load_typed_b64(checkpoint_type, checkpoint_data) + + blob_map = self.store.get_checkpoint_blobs( + thread_id, checkpoint_ns, cp.get("channel_versions", {}) + ) + + writes_rows = self.store.get_checkpoint_writes( + thread_id, checkpoint_ns, actual_id + ) + + result_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": actual_id, + } + } + + return self._build_checkpoint_tuple( + thread_id, + checkpoint_ns, + row, + blob_map, + writes_rows, + config=result_config, + ) + + # ------------------------------------------------------------------ + # Core: list + # ------------------------------------------------------------------ + + def list( + self, + config: RunnableConfig | None, + *, + filter: dict[str, Any] | None = None, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + if config is None: + return + + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "") + + before_id: Optional[str] = None + if before: + before_id = get_checkpoint_id(before) + + fetch_limit = limit if limit is not None else 100 + rows = self.store.list_checkpoints( + thread_id, + checkpoint_ns, + limit=fetch_limit, + before=before_id, + ) + + yielded = 0 + for row in rows: + if limit is not None and yielded >= limit: + break + + checkpoint_id = row["checkpoint_id"] + checkpoint_type = row.get("checkpoint_type", "") + checkpoint_data = row.get("checkpoint_data", "") + metadata_json = row.get("metadata", "{}") + + metadata: CheckpointMetadata = json.loads(metadata_json) + if filter and not all( + query_value == metadata.get(query_key) + for query_key, query_value in filter.items() + ): + continue + + cp: Checkpoint = self._load_typed_b64( + checkpoint_type, checkpoint_data + ) + + blob_map = self.store.get_checkpoint_blobs( + thread_id, checkpoint_ns, cp.get("channel_versions", {}) + ) + + 
writes_rows = self.store.get_checkpoint_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + + yield self._build_checkpoint_tuple( + thread_id, checkpoint_ns, row, blob_map, writes_rows + ) + yielded += 1 + + # ------------------------------------------------------------------ + # Core: put + # ------------------------------------------------------------------ + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + parent_checkpoint_id = config["configurable"].get("checkpoint_id", "") + + c = checkpoint.copy() + channel_values: dict[str, Any] = c.pop("channel_values") # type: ignore[misc] + + for channel, version in new_versions.items(): + if channel in channel_values: + blob_type, blob_data = self._dump_typed_b64( + channel_values[channel] + ) + else: + blob_type, blob_data = "empty", "" + + self.store.put_checkpoint_blob( + thread_id, + checkpoint_ns, + channel, + str(version), + blob_type=blob_type, + blob_data=blob_data, + ) + + cp_type, cp_data = self._dump_typed_b64(c) + final_metadata = get_checkpoint_metadata(config, metadata) + metadata_json = json.dumps(final_metadata, ensure_ascii=False) + + self.store.put_checkpoint( + thread_id, + checkpoint_ns, + checkpoint["id"], + checkpoint_type=cp_type, + checkpoint_data=cp_data, + metadata_json=metadata_json, + parent_checkpoint_id=parent_checkpoint_id or "", + ) + + self._sync_session(thread_id, self._resolve_user_id(config)) + + return { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + # ------------------------------------------------------------------ + # Core: put_writes + # ------------------------------------------------------------------ + + def put_writes( + self, + config: RunnableConfig, + writes: 
Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = config["configurable"]["checkpoint_id"] + + existing_writes = self.store.get_checkpoint_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + existing_keys: set[str] = set() + for w in existing_writes: + existing_keys.add(w.get("task_idx", "")) + + write_rows: list[dict[str, Any]] = [] + for idx, (channel, value) in enumerate(writes): + mapped_idx = WRITES_IDX_MAP.get(channel, idx) + task_idx = f"{task_id}:{mapped_idx}" + + if mapped_idx >= 0 and task_idx in existing_keys: + continue + + value_type, value_data = self._dump_typed_b64(value) + + write_rows.append({ + "task_idx": task_idx, + "task_id": task_id, + "task_path": task_path, + "channel": channel, + "value_type": value_type, + "value_data": value_data, + }) + + if write_rows: + self.store.put_checkpoint_writes( + thread_id, checkpoint_ns, checkpoint_id, write_rows + ) + + # ------------------------------------------------------------------ + # Core: delete_thread + # ------------------------------------------------------------------ + + def delete_thread(self, thread_id: str) -> None: + self.store.delete_thread_checkpoints(thread_id) + if self.agent_id: + user_id = self.user_id or "default" + try: + self.store.delete_session(self.agent_id, user_id, thread_id) + except Exception: + logger.warning( + "Failed to delete conversation record for " + "agent_id=%s, user_id=%s, thread_id=%s", + self.agent_id, + user_id, + thread_id, + exc_info=True, + ) + + # ------------------------------------------------------------------ + # Async versions + # ------------------------------------------------------------------ + + async def aget_tuple( + self, config: RunnableConfig + ) -> CheckpointTuple | None: + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = 
config["configurable"].get("checkpoint_ns", "") + checkpoint_id = get_checkpoint_id(config) + + row = await self.store.get_checkpoint_async( + thread_id, checkpoint_ns, checkpoint_id + ) + if row is None: + return None + + actual_id = row["checkpoint_id"] + checkpoint_type = row.get("checkpoint_type", "") + checkpoint_data = row.get("checkpoint_data", "") + cp: Checkpoint = self._load_typed_b64(checkpoint_type, checkpoint_data) + + blob_map = await self.store.get_checkpoint_blobs_async( + thread_id, checkpoint_ns, cp.get("channel_versions", {}) + ) + + writes_rows = await self.store.get_checkpoint_writes_async( + thread_id, checkpoint_ns, actual_id + ) + + result_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": actual_id, + } + } + + return self._build_checkpoint_tuple( + thread_id, + checkpoint_ns, + row, + blob_map, + writes_rows, + config=result_config, + ) + + async def alist( + self, + config: RunnableConfig | None, + *, + filter: dict[str, Any] | None = None, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> AsyncIterator[CheckpointTuple]: + if config is None: + return + + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "") + + before_id: Optional[str] = None + if before: + before_id = get_checkpoint_id(before) + + fetch_limit = limit if limit is not None else 100 + rows = await self.store.list_checkpoints_async( + thread_id, + checkpoint_ns, + limit=fetch_limit, + before=before_id, + ) + + yielded = 0 + for row in rows: + if limit is not None and yielded >= limit: + break + + checkpoint_id = row["checkpoint_id"] + checkpoint_type = row.get("checkpoint_type", "") + checkpoint_data = row.get("checkpoint_data", "") + metadata_json = row.get("metadata", "{}") + + metadata: CheckpointMetadata = json.loads(metadata_json) + if filter and not all( + query_value == metadata.get(query_key) + for 
query_key, query_value in filter.items() + ): + continue + + cp: Checkpoint = self._load_typed_b64( + checkpoint_type, checkpoint_data + ) + + blob_map = await self.store.get_checkpoint_blobs_async( + thread_id, checkpoint_ns, cp.get("channel_versions", {}) + ) + + writes_rows = await self.store.get_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id + ) + + yield self._build_checkpoint_tuple( + thread_id, checkpoint_ns, row, blob_map, writes_rows + ) + yielded += 1 + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + parent_checkpoint_id = config["configurable"].get("checkpoint_id", "") + + c = checkpoint.copy() + channel_values: dict[str, Any] = c.pop("channel_values") # type: ignore[misc] + + for channel, version in new_versions.items(): + if channel in channel_values: + blob_type, blob_data = self._dump_typed_b64( + channel_values[channel] + ) + else: + blob_type, blob_data = "empty", "" + + await self.store.put_checkpoint_blob_async( + thread_id, + checkpoint_ns, + channel, + str(version), + blob_type=blob_type, + blob_data=blob_data, + ) + + cp_type, cp_data = self._dump_typed_b64(c) + final_metadata = get_checkpoint_metadata(config, metadata) + metadata_json = json.dumps(final_metadata, ensure_ascii=False) + + await self.store.put_checkpoint_async( + thread_id, + checkpoint_ns, + checkpoint["id"], + checkpoint_type=cp_type, + checkpoint_data=cp_data, + metadata_json=metadata_json, + parent_checkpoint_id=parent_checkpoint_id or "", + ) + + await self._sync_session_async(thread_id, self._resolve_user_id(config)) + + return { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + async def aput_writes( + self, + config: RunnableConfig, + writes: 
Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = config["configurable"]["checkpoint_id"] + + existing_writes = await self.store.get_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id + ) + existing_keys: set[str] = set() + for w in existing_writes: + existing_keys.add(w.get("task_idx", "")) + + write_rows: list[dict[str, Any]] = [] + for idx, (channel, value) in enumerate(writes): + mapped_idx = WRITES_IDX_MAP.get(channel, idx) + task_idx = f"{task_id}:{mapped_idx}" + + if mapped_idx >= 0 and task_idx in existing_keys: + continue + + value_type, value_data = self._dump_typed_b64(value) + + write_rows.append({ + "task_idx": task_idx, + "task_id": task_id, + "task_path": task_path, + "channel": channel, + "value_type": value_type, + "value_data": value_data, + }) + + if write_rows: + await self.store.put_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id, write_rows + ) + + async def adelete_thread(self, thread_id: str) -> None: + await self.store.delete_thread_checkpoints_async(thread_id) + if self.agent_id: + user_id = self.user_id or "default" + try: + await self.store.delete_session_async( + self.agent_id, user_id, thread_id + ) + except Exception: + logger.warning( + "Failed to delete conversation record for " + "agent_id=%s, user_id=%s, thread_id=%s", + self.agent_id, + user_id, + thread_id, + exc_info=True, + ) diff --git a/agentrun/conversation_service/conversation_design.md b/agentrun/conversation_service/conversation_design.md index fecfbec..52b7a1e 100644 --- a/agentrun/conversation_service/conversation_design.md +++ b/agentrun/conversation_service/conversation_design.md @@ -175,6 +175,51 @@ Attributes: 说明:三级 State 是 ADK 的概念,其他框架按需使用 +### Checkpoint 表(LangGraph) + +Checkpoint 表 +PK: + thread_id (String, 分区键) + checkpoint_ns (String) + checkpoint_id (String) + 
+Attributes: + checkpoint_type : String -- serde dumps_typed 返回的类型标识 + checkpoint_data : String -- serde 序列化后的 checkpoint(不含 channel_values),base64 编码 + metadata : String -- JSON 序列化的 CheckpointMetadata + parent_checkpoint_id : String -- 父 checkpoint ID + +### Checkpoint_writes 表(LangGraph) + +Checkpoint_writes 表 +PK: + thread_id (String, 分区键) + checkpoint_ns (String) + checkpoint_id (String) + task_idx (String) -- 格式: "{task_id}:{idx}" + +Attributes: + task_id : String -- 任务标识 + task_path : String -- 任务路径 + channel : String -- 写入通道名 + value_type : String -- serde 类型标识 + value_data : String -- base64 编码的序列化数据 + +### Checkpoint_blobs 表(LangGraph) + +Checkpoint_blobs 表 +PK: + thread_id (String, 分区键) + checkpoint_ns (String) + channel (String) + version (String) + +Attributes: + blob_type : String -- serde 类型标识("empty" 表示空通道) + blob_data : String -- base64 编码的序列化数据 + +说明:使用 base64 编码而非 OTS Binary 类型,与现有 conversation_service 的 String 存储方式保持一致。 + ## 初始化策略 表和索引按用途分组创建,避免为未使用的框架创建不必要的表: @@ -184,8 +229,10 @@ Attributes: init_core_tables() Conversation + Event + 二级索引 所有框架 init_state_tables() State + App_state + User_state ADK 三级 State init_search_index() conversation_search_index (多元索引) 需要搜索/过滤 -init_tables() 以上全部(向后兼容) 快速开发 +init_checkpoint_tables() checkpoint + checkpoint_writes + checkpoint_blobs LangGraph +init_tables() 以上全部(不含 checkpoint 表,向后兼容) 快速开发 多元索引创建耗时较长(数秒级),建议与核心表创建分离,不阻塞核心流程。 +checkpoint 表仅在使用 LangGraph 时需要,需显式调用 init_checkpoint_tables()。 ## 分层架构 @@ -220,6 +267,7 @@ init_tables() 以上全部(向后兼容) │ │ init_core_tables() # 核心表 + 二级索引 │ │ │ │ init_state_tables() # 三级 State 表 │ │ │ │ init_search_index() # 多元索引(按需) │ │ +│ │ init_checkpoint_tables() # LangGraph checkpoint│ │ │ │ │ │ │ │ # Session CRUD │ │ │ │ create_session(...) 
→ ConversationSession │ │ @@ -241,6 +289,13 @@ init_tables() 以上全部(向后兼容) │ │ get_app_state / update_app_state │ │ │ │ get_user_state / update_user_state │ │ │ │ get_merged_state(...)→ dict # 三级浅合并 │ │ +│ │ │ │ +│ │ # Checkpoint 管理(LangGraph) │ │ +│ │ put_checkpoint / get_checkpoint │ │ +│ │ list_checkpoints │ │ +│ │ put_checkpoint_writes / get_checkpoint_writes │ │ +│ │ put_checkpoint_blob / get_checkpoint_blobs │ │ +│ │ delete_thread_checkpoints │ │ │ └──────────────┬──────────────────────────────────┘ │ │ │ │ ├─────────────────┼───────────────────────────────────────┤ @@ -259,4 +314,79 @@ init_tables() 以上全部(向后兼容) │ → OTSClient → OTSBackend │ │ │ │ 也可手动传入 OTSClient 构建 OTSBackend(向后兼容) │ -└─────────────────────────────────────────────────────────┘ \ No newline at end of file +└─────────────────────────────────────────────────────────┘ + +## LangGraph 会话同步 + +OTSCheckpointSaver 在指定 agent_id 后,每次 put() 会自动在 conversation 表 +中创建/更新会话记录: + + session_id = thread_id + framework = "langgraph" + +这使得外部服务(包括非 Python 服务)可以通过标准 OTS 查询: + + 1. conversation 表: GetRange(agent_id, user_id) → 列出所有 LangGraph 会话 + 2. 二级索引: 按 updated_at 排序 + 3. 
多元索引: 按 framework="langgraph" 过滤 + +### 跨语言查询 checkpoint 状态 + +#### 序列化格式 + +LangGraph 的 JsonPlusSerializer 使用 **msgpack**(非 JSON)序列化数据。 +存储到 OTS 时经过 base64 编码,因此 OTS 列中的数据格式为 base64(msgpack)。 + +checkpoint_type / blob_type 列的值通常为 "msgpack"。 + +#### 数据分类 + + 简单类型(dict/list/str/int/float): + msgpack 标准编码,任何语言的 msgpack 库可直接解码为原生结构。 + Go: base64.Decode → msgpack.Unmarshal → map[string]interface{} + + LangChain 对象(HumanMessage/AIMessage 等): + 编码为 msgpack Extension Type,内部嵌套 msgpack 数据: + ext(type=N, data=msgpack([module, class_name, kwargs_dict])) + 其中 kwargs_dict 包含实际字段(content, type, name 等),是普通 dict。 + + Go 处理 ext type 的方式(以 vmihailenco/msgpack/v5 为例): + 注册 ext type decoder,将嵌套 msgpack 解码为 [module, class, kwargs] 数组, + 取 kwargs(第 3 个元素)即可获取对象的实际属性值。 + +#### 查询步骤 + + Step 1: checkpoint 表 GetRange + PK: (thread_id, checkpoint_ns="", checkpoint_id=INF_MAX→INF_MIN) + Direction: BACKWARD, Limit: 1 + → 拿到最新行的 checkpoint_type, checkpoint_data, metadata + + Step 2: base64 解码 + msgpack 解析 + checkpoint_data: base64 decode → msgpack unmarshal + 结果为 map: { v, id, ts, channel_versions, versions_seen, ... 
} + 注意: channel_values 不在此表中,存储在 checkpoint_blobs 表 + metadata 列是 JSON 字符串,可直接 json.Unmarshal + + Step 3: checkpoint_blobs 表 BatchGetRow + 从 Step 2 的 channel_versions 中提取 {channel: version}: + PK: (thread_id, checkpoint_ns="", channel, version) + → 拿到 blob_type, blob_data + + Step 4: 解析 blob 数据 + blob_data: base64 decode → msgpack unmarshal + - 简单 state 字段(str/int/list 等):直接得到原生值 + - LangChain Message 字段:得到 ext type,需自定义 decoder 提取 kwargs + +#### Go 示例伪代码 + + // 简单 state(无 LangChain 对象) + rawBytes, _ := base64.StdEncoding.DecodeString(blobData) + var value interface{} + _ = msgpack.Unmarshal(rawBytes, &value) // 直接可用 + + // 包含 LangChain 对象的 state + // 需要注册 ext type handler: + dec := msgpack.NewDecoder(bytes.NewReader(rawBytes)) + dec.SetCustomStructTag("json") + // 对于 ext type 5 (Pydantic V2): 解码内部 [module, class, kwargs, method] + // 取 kwargs 即可拿到 {content: "...", type: "human", ...} \ No newline at end of file diff --git a/agentrun/conversation_service/model.py b/agentrun/conversation_service/model.py index 4866611..c7b344e 100644 --- a/agentrun/conversation_service/model.py +++ b/agentrun/conversation_service/model.py @@ -23,6 +23,11 @@ DEFAULT_CONVERSATION_SEARCH_INDEX = "conversation_search_index" DEFAULT_STATE_SEARCH_INDEX = "state_search_index" +# LangGraph checkpoint 表 +DEFAULT_CHECKPOINT_TABLE = "checkpoint" +DEFAULT_CHECKPOINT_WRITES_TABLE = "checkpoint_writes" +DEFAULT_CHECKPOINT_BLOBS_TABLE = "checkpoint_blobs" + # --------------------------------------------------------------------------- # 枚举 diff --git a/agentrun/conversation_service/ots_backend.py b/agentrun/conversation_service/ots_backend.py index 3ccb0b9..0135718 100644 --- a/agentrun/conversation_service/ots_backend.py +++ b/agentrun/conversation_service/ots_backend.py @@ -48,6 +48,9 @@ ConversationEvent, ConversationSession, DEFAULT_APP_STATE_TABLE, + DEFAULT_CHECKPOINT_BLOBS_TABLE, + DEFAULT_CHECKPOINT_TABLE, + DEFAULT_CHECKPOINT_WRITES_TABLE, DEFAULT_CONVERSATION_SEARCH_INDEX, 
DEFAULT_CONVERSATION_SECONDARY_INDEX, DEFAULT_CONVERSATION_TABLE, @@ -110,6 +113,15 @@ def __init__( ) self._state_search_index = f"{table_prefix}{DEFAULT_STATE_SEARCH_INDEX}" + # LangGraph checkpoint 表 + self._checkpoint_table = f"{table_prefix}{DEFAULT_CHECKPOINT_TABLE}" + self._checkpoint_writes_table = ( + f"{table_prefix}{DEFAULT_CHECKPOINT_WRITES_TABLE}" + ) + self._checkpoint_blobs_table = ( + f"{table_prefix}{DEFAULT_CHECKPOINT_BLOBS_TABLE}" + ) + # ----------------------------------------------------------------------- # 建表(异步)/ Table creation (async) # ----------------------------------------------------------------------- @@ -230,6 +242,226 @@ def init_search_index(self) -> None: self._create_conversation_search_index() self._create_state_search_index() + async def init_checkpoint_tables_async(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(异步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + await self._create_checkpoint_table_async() + await self._create_checkpoint_writes_table_async() + await self._create_checkpoint_blobs_table_async() + + def init_checkpoint_tables(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(同步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + self._create_checkpoint_table() + self._create_checkpoint_writes_table() + self._create_checkpoint_blobs_table() + + async def _create_checkpoint_table_async(self) -> None: + """创建 checkpoint 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), checkpoint_id (STRING) + """ + table_meta = TableMeta( + self._checkpoint_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_table) + except OTSServiceError as e: + 
if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_table, + ) + else: + raise + + def _create_checkpoint_table(self) -> None: + """创建 checkpoint 表(同步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), checkpoint_id (STRING) + """ + table_meta = TableMeta( + self._checkpoint_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + self._client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_table, + ) + else: + raise + + async def _create_checkpoint_writes_table_async(self) -> None: + """创建 checkpoint_writes 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), + checkpoint_id (STRING), task_idx (STRING) + """ + table_meta = TableMeta( + self._checkpoint_writes_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ("task_idx", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_writes_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_writes_table, + ) + else: + raise + + def _create_checkpoint_writes_table(self) -> None: + """创建 checkpoint_writes 表(同步)。 
+ + PK: thread_id (STRING), checkpoint_ns (STRING), + checkpoint_id (STRING), task_idx (STRING) + """ + table_meta = TableMeta( + self._checkpoint_writes_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("checkpoint_id", "STRING"), + ("task_idx", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + self._client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_writes_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_writes_table, + ) + else: + raise + + async def _create_checkpoint_blobs_table_async(self) -> None: + """创建 checkpoint_blobs 表(异步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), + channel (STRING), version (STRING) + """ + table_meta = TableMeta( + self._checkpoint_blobs_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("channel", "STRING"), + ("version", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + await self._async_client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_blobs_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_blobs_table, + ) + else: + raise + + def _create_checkpoint_blobs_table(self) -> None: + """创建 checkpoint_blobs 表(同步)。 + + PK: thread_id (STRING), checkpoint_ns (STRING), + channel (STRING), version (STRING) + """ + table_meta = TableMeta( + self._checkpoint_blobs_table, + [ + ("thread_id", "STRING"), + ("checkpoint_ns", "STRING"), + ("channel", 
"STRING"), + ("version", "STRING"), + ], + ) + table_options = TableOptions() + reserved_throughput = ReservedThroughput(CapacityUnit(0, 0)) + + try: + self._client.create_table( + table_meta, table_options, reserved_throughput + ) + logger.info("Created table: %s", self._checkpoint_blobs_table) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Table %s already exists, skipping.", + self._checkpoint_blobs_table, + ) + else: + raise + async def _create_conversation_table_async(self) -> None: """创建 Conversation 表 + 二级索引(异步)。""" table_meta = TableMeta( @@ -581,12 +813,14 @@ async def _create_conversation_search_index_async(self) -> None: else: raise - async def _create_state_search_index_async(self) -> None: - """创建 State 表的多元索引(异步)。 + def _create_conversation_search_index(self) -> None: + """创建 Conversation 表的多元索引(同步)。 - 支持按 session_id 独立精确匹配查询,不受主键前缀限制。 + 多元索引支持全文检索 summary、精确匹配过滤 labels/framework/is_pinned、 + 范围查询 updated_at/created_at、跨 user 查询等场景。 索引已存在时跳过。 """ + from tablestore import AnalyzerType # type: ignore[import-untyped] from tablestore import FieldType # type: ignore[import-untyped] from tablestore import IndexSetting # type: ignore[import-untyped] from tablestore import SortOrder # type: ignore[import-untyped] @@ -617,17 +851,41 @@ async def _create_state_search_index_async(self) -> None: enable_sort_and_agg=True, ), FieldSchema( - "created_at", + "updated_at", FieldType.LONG, index=True, enable_sort_and_agg=True, ), FieldSchema( - "updated_at", + "created_at", FieldType.LONG, index=True, enable_sort_and_agg=True, ), + FieldSchema( + "is_pinned", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "framework", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "summary", + FieldType.TEXT, + index=True, + analyzer=AnalyzerType.SINGLEWORD, + ), + FieldSchema( + "labels", + 
FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), ] index_setting = IndexSetting(routing_fields=["agent_id"]) @@ -641,15 +899,15 @@ async def _create_state_search_index_async(self) -> None: ) try: - await self._async_client.create_search_index( - self._state_table, - self._state_search_index, + self._client.create_search_index( + self._conversation_table, + self._conversation_search_index, index_meta, ) logger.info( "Created search index: %s on table: %s", - self._state_search_index, - self._state_table, + self._conversation_search_index, + self._conversation_table, ) except OTSServiceError as e: if "already exist" in str(e).lower() or ( @@ -657,23 +915,17 @@ async def _create_state_search_index_async(self) -> None: ): logger.warning( "Search index %s already exists, skipping.", - self._state_search_index, + self._conversation_search_index, ) else: raise - # ----------------------------------------------------------------------- - # Session CRUD(异步)/ Session CRUD (async) - # ----------------------------------------------------------------------- - - def _create_conversation_search_index(self) -> None: - """创建 Conversation 表的多元索引(同步)。 + async def _create_state_search_index_async(self) -> None: + """创建 State 表的多元索引(异步)。 - 多元索引支持全文检索 summary、精确匹配过滤 labels/framework/is_pinned、 - 范围查询 updated_at/created_at、跨 user 查询等场景。 + 支持按 session_id 独立精确匹配查询,不受主键前缀限制。 索引已存在时跳过。 """ - from tablestore import AnalyzerType # type: ignore[import-untyped] from tablestore import FieldType # type: ignore[import-untyped] from tablestore import IndexSetting # type: ignore[import-untyped] from tablestore import SortOrder # type: ignore[import-untyped] @@ -703,12 +955,6 @@ def _create_conversation_search_index(self) -> None: index=True, enable_sort_and_agg=True, ), - FieldSchema( - "updated_at", - FieldType.LONG, - index=True, - enable_sort_and_agg=True, - ), FieldSchema( "created_at", FieldType.LONG, @@ -716,26 +962,8 @@ def _create_conversation_search_index(self) -> None: 
enable_sort_and_agg=True, ), FieldSchema( - "is_pinned", - FieldType.KEYWORD, - index=True, - enable_sort_and_agg=True, - ), - FieldSchema( - "framework", - FieldType.KEYWORD, - index=True, - enable_sort_and_agg=True, - ), - FieldSchema( - "summary", - FieldType.TEXT, - index=True, - analyzer=AnalyzerType.SINGLEWORD, - ), - FieldSchema( - "labels", - FieldType.KEYWORD, + "updated_at", + FieldType.LONG, index=True, enable_sort_and_agg=True, ), @@ -752,27 +980,39 @@ def _create_conversation_search_index(self) -> None: ) try: - self._client.create_search_index( - self._conversation_table, - self._conversation_search_index, + await self._async_client.create_search_index( + self._state_table, + self._state_search_index, index_meta, ) logger.info( "Created search index: %s on table: %s", - self._conversation_search_index, - self._conversation_table, + self._state_search_index, + self._state_table, ) except OTSServiceError as e: - if "already exist" in str(e).lower() or ( + err_str = str(e).lower() + if "already exist" in err_str or ( hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" ): logger.warning( "Search index %s already exists, skipping.", - self._conversation_search_index, + self._state_search_index, + ) + elif "does not exist" in err_str and "table" in err_str: + logger.warning( + "Table %s does not exist, skipping search index creation" + " for %s.", + self._state_table, + self._state_search_index, ) else: raise + # ----------------------------------------------------------------------- + # Session CRUD(异步)/ Session CRUD (async) + # ----------------------------------------------------------------------- + def _create_state_search_index(self) -> None: """创建 State 表的多元索引(同步)。 @@ -844,13 +1084,21 @@ def _create_state_search_index(self) -> None: self._state_table, ) except OTSServiceError as e: - if "already exist" in str(e).lower() or ( + err_str = str(e).lower() + if "already exist" in err_str or ( hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" ): 
logger.warning( "Search index %s already exists, skipping.", self._state_search_index, ) + elif "does not exist" in err_str and "table" in err_str: + logger.warning( + "Table %s does not exist, skipping search index creation" + " for %s.", + self._state_table, + self._state_search_index, + ) else: raise @@ -2294,7 +2542,7 @@ async def delete_state_row_async( await self._async_client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # 内部辅助方法(I/O 相关,异步) + # Checkpoint CRUD(LangGraph)(异步) # ----------------------------------------------------------------------- def delete_state_row( @@ -2312,6 +2560,842 @@ def delete_state_row( condition = Condition(RowExistenceExpectation.IGNORE) self._client.delete_row(table_name, row, condition) + # ----------------------------------------------------------------------- + # Checkpoint CRUD(LangGraph)(同步) + # ----------------------------------------------------------------------- + + async def put_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(异步)。""" + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + attribute_columns = [ + ("checkpoint_type", checkpoint_type), + ("checkpoint_data", checkpoint_data), + ("metadata", metadata_json), + ("parent_checkpoint_id", parent_checkpoint_id), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + await self._async_client.put_row(self._checkpoint_table, row, condition) + + def put_checkpoint( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(同步)。""" + primary_key 
= [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + attribute_columns = [ + ("checkpoint_type", checkpoint_type), + ("checkpoint_data", checkpoint_data), + ("metadata", metadata_json), + ("parent_checkpoint_id", parent_checkpoint_id), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + self._client.put_row(self._checkpoint_table, row, condition) + + async def get_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(异步)。 + + 若 checkpoint_id 为 None,使用 GetRange 获取最新的(按 checkpoint_id 倒序)。 + + Returns: + 包含 checkpoint 字段的字典,或 None。 + """ + if checkpoint_id is not None: + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + _, row, _ = await self._async_client.get_row( + self._checkpoint_table, primary_key, max_version=1 + ) + if row is None or row.primary_key is None: + return None + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + # checkpoint_id 为 None -> 取最新 + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MAX), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + _, _, rows, _ = await self._async_client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + inclusive_start, + exclusive_end, + max_version=1, + limit=1, + ) + if not rows: + return None + row = rows[0] + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + def get_checkpoint( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(同步)。 + + 若 checkpoint_id 为 None,使用 
GetRange 获取最新的(按 checkpoint_id 倒序)。 + + Returns: + 包含 checkpoint 字段的字典,或 None。 + """ + if checkpoint_id is not None: + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ] + _, row, _ = self._client.get_row( + self._checkpoint_table, primary_key, max_version=1 + ) + if row is None or row.primary_key is None: + return None + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + # checkpoint_id 为 None -> 取最新 + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MAX), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + _, _, rows, _ = self._client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + inclusive_start, + exclusive_end, + max_version=1, + limit=1, + ) + if not rows: + return None + row = rows[0] + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + return {**pk, **attrs} + + async def list_checkpoints_async( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """按 checkpoint_id 倒序列出 checkpoint(异步)。 + + Args: + thread_id: 线程 ID。 + checkpoint_ns: checkpoint 命名空间。 + limit: 最多返回条数。 + before: 仅返回 checkpoint_id < before 的记录。 + """ + if before is not None: + start_id: Any = before + else: + start_id = INF_MAX + + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", start_id), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while len(results) < limit: + _, next_token, rows, _ = await self._async_client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + next_start, + 
exclusive_end, + max_version=1, + limit=limit - len(results), + ) + + for row in rows: + pk = self._pk_to_dict(row.primary_key) + # 如果 before 指定了精确值,跳过它本身 + if before is not None and pk.get("checkpoint_id") == before: + continue + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + if len(results) >= limit: + break + + if next_token is None: + break + next_start = next_token + + return results + + def list_checkpoints( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """按 checkpoint_id 倒序列出 checkpoint(同步)。 + + Args: + thread_id: 线程 ID。 + checkpoint_ns: checkpoint 命名空间。 + limit: 最多返回条数。 + before: 仅返回 checkpoint_id < before 的记录。 + """ + if before is not None: + start_id: Any = before + else: + start_id = INF_MAX + + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", start_id), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", INF_MIN), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while len(results) < limit: + _, next_token, rows, _ = self._client.get_range( + self._checkpoint_table, + Direction.BACKWARD, + next_start, + exclusive_end, + max_version=1, + limit=limit - len(results), + ) + + for row in rows: + pk = self._pk_to_dict(row.primary_key) + # 如果 before 指定了精确值,跳过它本身 + if before is not None and pk.get("checkpoint_id") == before: + continue + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + if len(results) >= limit: + break + + if next_token is None: + break + next_start = next_token + + return results + + async def put_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(异步)。 + + Args: + writes: 每个元素是 dict,包含 task_idx, task_id, task_path, + channel, 
value_type, value_data 字段。 + """ + if not writes: + return + + for i in range(0, len(writes), _BATCH_WRITE_LIMIT): + batch = writes[i : i + _BATCH_WRITE_LIMIT] + from tablestore import PutRowItem # type: ignore[import-untyped] + + put_items = [] + for w in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", w["task_idx"]), + ] + attrs = [ + ("task_id", w["task_id"]), + ("task_path", w.get("task_path", "")), + ("channel", w["channel"]), + ("value_type", w["value_type"]), + ("value_data", w["value_data"]), + ] + row = Row(pk, attrs) + condition = Condition(RowExistenceExpectation.IGNORE) + put_items.append(PutRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem( + self._checkpoint_writes_table, put_items + ) + ) + await self._async_client.batch_write_row(request) + + def put_checkpoint_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(同步)。 + + Args: + writes: 每个元素是 dict,包含 task_idx, task_id, task_path, + channel, value_type, value_data 字段。 + """ + if not writes: + return + + for i in range(0, len(writes), _BATCH_WRITE_LIMIT): + batch = writes[i : i + _BATCH_WRITE_LIMIT] + from tablestore import PutRowItem # type: ignore[import-untyped] + + put_items = [] + for w in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", w["task_idx"]), + ] + attrs = [ + ("task_id", w["task_id"]), + ("task_path", w.get("task_path", "")), + ("channel", w["channel"]), + ("value_type", w["value_type"]), + ("value_data", w["value_data"]), + ] + row = Row(pk, attrs) + condition = Condition(RowExistenceExpectation.IGNORE) + put_items.append(PutRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add( + TableInBatchWriteRowItem( + self._checkpoint_writes_table, put_items + 
) + ) + self._client.batch_write_row(request) + + async def get_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(异步)。""" + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MIN), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MAX), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = await self._async_client.get_range( + self._checkpoint_writes_table, + Direction.FORWARD, + next_start, + exclusive_end, + max_version=1, + ) + for row in rows: + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + + if next_token is None: + break + next_start = next_token + + return results + + def get_checkpoint_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(同步)。""" + inclusive_start = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MIN), + ] + exclusive_end = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("checkpoint_id", checkpoint_id), + ("task_idx", INF_MAX), + ] + + results: list[dict[str, Any]] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = self._client.get_range( + self._checkpoint_writes_table, + Direction.FORWARD, + next_start, + exclusive_end, + max_version=1, + ) + for row in rows: + pk = self._pk_to_dict(row.primary_key) + attrs = self._attrs_to_dict(row.attribute_columns) + results.append({**pk, **attrs}) + + if next_token is None: + break + next_start = next_token + + return results + + async def put_checkpoint_blob_async( + self, 
+ thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(异步)。""" + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", channel), + ("version", version), + ] + attribute_columns = [ + ("blob_type", blob_type), + ("blob_data", blob_data), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + await self._async_client.put_row( + self._checkpoint_blobs_table, row, condition + ) + + def put_checkpoint_blob( + self, + thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(同步)。""" + primary_key = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", channel), + ("version", version), + ] + attribute_columns = [ + ("blob_type", blob_type), + ("blob_data", blob_data), + ] + row = Row(primary_key, attribute_columns) + condition = Condition(RowExistenceExpectation.IGNORE) + self._client.put_row(self._checkpoint_blobs_table, row, condition) + + async def get_checkpoint_blobs_async( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, dict[str, str]]: + """批量读取 checkpoint blobs(异步)。 + + Args: + channel_versions: {channel: version} 映射。 + + Returns: + {channel: {"blob_type": ..., "blob_data": ...}} 映射。 + """ + if not channel_versions: + return {} + + from tablestore import ( # type: ignore[import-untyped] + BatchGetRowRequest, + TableInBatchGetRowItem, + ) + + results: dict[str, dict[str, str]] = {} + items = list(channel_versions.items()) + + # OTS BatchGetRow 每次最多 100 行 + batch_limit = 100 + for i in range(0, len(items), batch_limit): + batch = items[i : i + batch_limit] + rows_to_get = [] + for ch, ver in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", ch), + ("version", 
str(ver)), + ] + rows_to_get.append(pk) + + table_item = TableInBatchGetRowItem( + self._checkpoint_blobs_table, + rows_to_get, + max_version=1, + ) + request = BatchGetRowRequest() + request.add(table_item) + response = await self._async_client.batch_get_row(request) + + table_results = response.get_result_by_table( + self._checkpoint_blobs_table + ) + for item in table_results: + if not item.is_ok or item.row is None: + continue + pk = self._pk_to_dict(item.row.primary_key) + attrs = self._attrs_to_dict(item.row.attribute_columns) + channel_name = pk.get("channel", "") + results[channel_name] = { + "blob_type": attrs.get("blob_type", ""), + "blob_data": attrs.get("blob_data", ""), + } + + return results + + def get_checkpoint_blobs( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, dict[str, str]]: + """批量读取 checkpoint blobs(同步)。 + + Args: + channel_versions: {channel: version} 映射。 + + Returns: + {channel: {"blob_type": ..., "blob_data": ...}} 映射。 + """ + if not channel_versions: + return {} + + from tablestore import ( # type: ignore[import-untyped] + BatchGetRowRequest, + TableInBatchGetRowItem, + ) + + results: dict[str, dict[str, str]] = {} + items = list(channel_versions.items()) + + # OTS BatchGetRow 每次最多 100 行 + batch_limit = 100 + for i in range(0, len(items), batch_limit): + batch = items[i : i + batch_limit] + rows_to_get = [] + for ch, ver in batch: + pk = [ + ("thread_id", thread_id), + ("checkpoint_ns", checkpoint_ns), + ("channel", ch), + ("version", str(ver)), + ] + rows_to_get.append(pk) + + table_item = TableInBatchGetRowItem( + self._checkpoint_blobs_table, + rows_to_get, + max_version=1, + ) + request = BatchGetRowRequest() + request.add(table_item) + response = self._client.batch_get_row(request) + + table_results = response.get_result_by_table( + self._checkpoint_blobs_table + ) + for item in table_results: + if not item.is_ok or item.row is None: + continue + pk = 
self._pk_to_dict(item.row.primary_key) + attrs = self._attrs_to_dict(item.row.attribute_columns) + channel_name = pk.get("channel", "") + results[channel_name] = { + "blob_type": attrs.get("blob_type", ""), + "blob_data": attrs.get("blob_data", ""), + } + + return results + + async def delete_thread_checkpoints_async( + self, + thread_id: str, + ) -> None: + """删除指定 thread_id 的所有 checkpoint 相关数据(异步)。 + + 扫描并删除 checkpoint、checkpoint_writes、checkpoint_blobs 三张表中 + 所有以 thread_id 为分区键的行。 + """ + await self._scan_and_delete_async( + self._checkpoint_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ], + ) + await self._scan_and_delete_async( + self._checkpoint_writes_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + ("task_idx", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ("task_idx", INF_MAX), + ], + ) + await self._scan_and_delete_async( + self._checkpoint_blobs_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("channel", INF_MIN), + ("version", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("channel", INF_MAX), + ("version", INF_MAX), + ], + ) + + def delete_thread_checkpoints( + self, + thread_id: str, + ) -> None: + """删除指定 thread_id 的所有 checkpoint 相关数据(同步)。 + + 扫描并删除 checkpoint、checkpoint_writes、checkpoint_blobs 三张表中 + 所有以 thread_id 为分区键的行。 + """ + self._scan_and_delete( + self._checkpoint_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ], + ) + self._scan_and_delete( + self._checkpoint_writes_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("checkpoint_id", INF_MIN), + 
("task_idx", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("checkpoint_id", INF_MAX), + ("task_idx", INF_MAX), + ], + ) + self._scan_and_delete( + self._checkpoint_blobs_table, + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MIN), + ("channel", INF_MIN), + ("version", INF_MIN), + ], + [ + ("thread_id", thread_id), + ("checkpoint_ns", INF_MAX), + ("channel", INF_MAX), + ("version", INF_MAX), + ], + ) + + async def _scan_and_delete_async( + self, + table_name: str, + inclusive_start: list[Any], + exclusive_end: list[Any], + ) -> None: + """通用扫描删除:GetRange 扫描 PK 后 BatchWriteRow 删除(异步)。""" + all_pks: list[Any] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = await self._async_client.get_range( + table_name, + Direction.FORWARD, + next_start, + exclusive_end, + columns_to_get=[], + max_version=1, + ) + for row in rows: + all_pks.append(row.primary_key) + if next_token is None: + break + next_start = next_token + + if not all_pks: + return + + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add(TableInBatchWriteRowItem(table_name, delete_items)) + await self._async_client.batch_write_row(request) + + # ----------------------------------------------------------------------- + # 内部辅助方法(I/O 相关,异步) + # ----------------------------------------------------------------------- + + def _scan_and_delete( + self, + table_name: str, + inclusive_start: list[Any], + exclusive_end: list[Any], + ) -> None: + """通用扫描删除:GetRange 扫描 PK 后 BatchWriteRow 删除(同步)。""" + all_pks: list[Any] = [] + next_start = inclusive_start + + while True: + _, next_token, rows, _ = self._client.get_range( + table_name, + Direction.FORWARD, + next_start, + exclusive_end, + 
columns_to_get=[], + max_version=1, + ) + for row in rows: + all_pks.append(row.primary_key) + if next_token is None: + break + next_start = next_token + + if not all_pks: + return + + for i in range(0, len(all_pks), _BATCH_WRITE_LIMIT): + batch = all_pks[i : i + _BATCH_WRITE_LIMIT] + delete_items = [] + for pk in batch: + row = Row(pk) + condition = Condition(RowExistenceExpectation.IGNORE) + delete_items.append(DeleteRowItem(row, condition)) + + request = BatchWriteRowRequest() + request.add(TableInBatchWriteRowItem(table_name, delete_items)) + self._client.batch_write_row(request) + # ----------------------------------------------------------------------- # 内部辅助方法(I/O 相关,同步) # ----------------------------------------------------------------------- diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py index 48378f8..9d24807 100644 --- a/agentrun/conversation_service/session_store.py +++ b/agentrun/conversation_service/session_store.py @@ -81,10 +81,6 @@ async def init_search_index_async(self) -> None: """ await self._backend.init_search_index_async() - # ------------------------------------------------------------------- - # Session 管理(异步)/ Session management (async) - # ------------------------------------------------------------------- - def init_search_index(self) -> None: """创建 Conversation 和 State 多元索引(同步)。 @@ -92,6 +88,294 @@ def init_search_index(self) -> None: """ self._backend.init_search_index() + async def init_checkpoint_tables_async(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(异步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + await self._backend.init_checkpoint_tables_async() + + def init_checkpoint_tables(self) -> None: + """创建 LangGraph checkpoint 相关的 3 张表(同步)。 + + 包含 checkpoint、checkpoint_writes、checkpoint_blobs 表。 + 表已存在时跳过,可重复调用。 + """ + self._backend.init_checkpoint_tables() + + async def init_langchain_tables_async(self) -> None: + """创建 
LangChain 所需的全部表和索引(异步)。 + + 包含核心表(Conversation + Event + 二级索引)和多元索引。 + 表或索引已存在时跳过,可重复调用。 + """ + await self._backend.init_core_tables_async() + await self._backend.init_search_index_async() + + def init_langchain_tables(self) -> None: + """创建 LangChain 所需的全部表和索引(同步)。 + + 包含核心表(Conversation + Event + 二级索引)和多元索引。 + 表或索引已存在时跳过,可重复调用。 + """ + self._backend.init_core_tables() + self._backend.init_search_index() + + async def init_langgraph_tables_async(self) -> None: + """创建 LangGraph 所需的全部表和索引(异步)。 + + 包含核心表(Conversation + Event + 二级索引)、多元索引 + 以及 checkpoint 相关的 3 张表(checkpoint / checkpoint_writes / checkpoint_blobs)。 + 表或索引已存在时跳过,可重复调用。 + """ + await self._backend.init_core_tables_async() + await self._backend.init_search_index_async() + await self._backend.init_checkpoint_tables_async() + + # ------------------------------------------------------------------- + # Checkpoint 管理(LangGraph)(异步) + # ------------------------------------------------------------------- + + def init_langgraph_tables(self) -> None: + """创建 LangGraph 所需的全部表和索引(同步)。 + + 包含核心表(Conversation + Event + 二级索引)、多元索引 + 以及 checkpoint 相关的 3 张表(checkpoint / checkpoint_writes / checkpoint_blobs)。 + 表或索引已存在时跳过,可重复调用。 + """ + self._backend.init_core_tables() + self._backend.init_search_index() + self._backend.init_checkpoint_tables() + + # ------------------------------------------------------------------- + # Checkpoint 管理(LangGraph)(同步) + # ------------------------------------------------------------------- + + async def put_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(异步)。""" + await self._backend.put_checkpoint_async( + thread_id, + checkpoint_ns, + checkpoint_id, + checkpoint_type=checkpoint_type, + checkpoint_data=checkpoint_data, + metadata_json=metadata_json, + parent_checkpoint_id=parent_checkpoint_id, + ) + + def 
put_checkpoint( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + *, + checkpoint_type: str, + checkpoint_data: str, + metadata_json: str, + parent_checkpoint_id: str = "", + ) -> None: + """写入/覆盖 checkpoint 行(同步)。""" + self._backend.put_checkpoint( + thread_id, + checkpoint_ns, + checkpoint_id, + checkpoint_type=checkpoint_type, + checkpoint_data=checkpoint_data, + metadata_json=metadata_json, + parent_checkpoint_id=parent_checkpoint_id, + ) + + async def get_checkpoint_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(异步)。 + + 若 checkpoint_id 为 None,返回最新的 checkpoint。 + """ + return await self._backend.get_checkpoint_async( + thread_id, checkpoint_ns, checkpoint_id + ) + + def get_checkpoint( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: Optional[str] = None, + ) -> Optional[dict[str, Any]]: + """读取单条 checkpoint(同步)。 + + 若 checkpoint_id 为 None,返回最新的 checkpoint。 + """ + return self._backend.get_checkpoint( + thread_id, checkpoint_ns, checkpoint_id + ) + + async def list_checkpoints_async( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """列出 checkpoint(按 checkpoint_id 倒序)(异步)。""" + return await self._backend.list_checkpoints_async( + thread_id, checkpoint_ns, limit=limit, before=before + ) + + def list_checkpoints( + self, + thread_id: str, + checkpoint_ns: str, + *, + limit: int = 10, + before: Optional[str] = None, + ) -> list[dict[str, Any]]: + """列出 checkpoint(按 checkpoint_id 倒序)(同步)。""" + return self._backend.list_checkpoints( + thread_id, checkpoint_ns, limit=limit, before=before + ) + + async def put_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(异步)。""" + await self._backend.put_checkpoint_writes_async( + thread_id, 
checkpoint_ns, checkpoint_id, writes + ) + + def put_checkpoint_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + ) -> None: + """批量写入 checkpoint writes(同步)。""" + self._backend.put_checkpoint_writes( + thread_id, checkpoint_ns, checkpoint_id, writes + ) + + async def get_checkpoint_writes_async( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(异步)。""" + return await self._backend.get_checkpoint_writes_async( + thread_id, checkpoint_ns, checkpoint_id + ) + + def get_checkpoint_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + ) -> list[dict[str, Any]]: + """读取指定 checkpoint 的所有 writes(同步)。""" + return self._backend.get_checkpoint_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + + async def put_checkpoint_blob_async( + self, + thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(异步)。""" + await self._backend.put_checkpoint_blob_async( + thread_id, + checkpoint_ns, + channel, + version, + blob_type=blob_type, + blob_data=blob_data, + ) + + def put_checkpoint_blob( + self, + thread_id: str, + checkpoint_ns: str, + channel: str, + version: str, + *, + blob_type: str, + blob_data: str, + ) -> None: + """写入/覆盖 checkpoint blob 行(同步)。""" + self._backend.put_checkpoint_blob( + thread_id, + checkpoint_ns, + channel, + version, + blob_type=blob_type, + blob_data=blob_data, + ) + + async def get_checkpoint_blobs_async( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, dict[str, str]]: + """批量读取 checkpoint blobs(异步)。""" + return await self._backend.get_checkpoint_blobs_async( + thread_id, checkpoint_ns, channel_versions + ) + + def get_checkpoint_blobs( + self, + thread_id: str, + checkpoint_ns: str, + channel_versions: dict[str, str], + ) -> dict[str, 
dict[str, str]]: + """批量读取 checkpoint blobs(同步)。""" + return self._backend.get_checkpoint_blobs( + thread_id, checkpoint_ns, channel_versions + ) + + async def delete_thread_checkpoints_async( + self, + thread_id: str, + ) -> None: + """删除指定 thread 的所有 checkpoint 相关数据(异步)。""" + await self._backend.delete_thread_checkpoints_async(thread_id) + + # ------------------------------------------------------------------- + # Session 管理(异步)/ Session management (async) + # ------------------------------------------------------------------- + + def delete_thread_checkpoints( + self, + thread_id: str, + ) -> None: + """删除指定 thread 的所有 checkpoint 相关数据(同步)。""" + self._backend.delete_thread_checkpoints(thread_id) + # ------------------------------------------------------------------- # Session 管理(同步)/ Session management (async) # ------------------------------------------------------------------- @@ -1307,11 +1591,8 @@ async def from_memory_collection_async( "agentrun 主包未安装。请先安装: pip install agentrun" ) from e - from tablestore import AsyncOTSClient # type: ignore[import-untyped] - from tablestore import OTSClient # type: ignore[import-untyped] - from tablestore import WriteRetryPolicy - from agentrun.conversation_service.utils import ( + build_ots_clients, convert_vpc_endpoint_to_public, ) @@ -1357,21 +1638,13 @@ async def from_memory_collection_async( sts_token = security_token if security_token else None # 4. 
构建 OTSClient + AsyncOTSClient 和 OTSBackend - ots_client = OTSClient( - endpoint, - access_key_id, - access_key_secret, - instance_name, - sts_token=sts_token, - retry_policy=WriteRetryPolicy(), - ) - async_ots_client = AsyncOTSClient( + # 使用 utils.build_ots_clients 避免 codegen 替换 AsyncOTSClient + ots_client, async_ots_client = build_ots_clients( endpoint, access_key_id, access_key_secret, instance_name, sts_token=sts_token, - retry_policy=WriteRetryPolicy(), ) backend = OTSBackend( @@ -1424,11 +1697,8 @@ def from_memory_collection( "agentrun 主包未安装。请先安装: pip install agentrun" ) from e - from tablestore import AsyncOTSClient # type: ignore[import-untyped] - from tablestore import OTSClient # type: ignore[import-untyped] - from tablestore import WriteRetryPolicy - from agentrun.conversation_service.utils import ( + build_ots_clients, convert_vpc_endpoint_to_public, ) @@ -1471,22 +1741,14 @@ def from_memory_collection( security_token = effective_config.get_security_token() sts_token = security_token if security_token else None - # 4. 构建 OTSClient + AsyncOTSClient 和 OTSBackend - ots_client = OTSClient( - endpoint, - access_key_id, - access_key_secret, - instance_name, - sts_token=sts_token, - retry_policy=WriteRetryPolicy(), - ) - async_ots_client = AsyncOTSClient( + # 4. 
构建 OTSClient + OTSClient 和 OTSBackend + # 使用 utils.build_ots_clients 避免 codegen 替换 OTSClient + ots_client, async_ots_client = build_ots_clients( endpoint, access_key_id, access_key_secret, instance_name, sts_token=sts_token, - retry_policy=WriteRetryPolicy(), ) backend = OTSBackend( diff --git a/agentrun/conversation_service/utils.py b/agentrun/conversation_service/utils.py index 70ac3b8..87d3384 100644 --- a/agentrun/conversation_service/utils.py +++ b/agentrun/conversation_service/utils.py @@ -97,3 +97,40 @@ def from_chunks(chunks: list[str]) -> str: 拼接后的完整字符串。 """ return "".join(chunks) + + +def build_ots_clients( + endpoint: str, + access_key_id: str, + access_key_secret: str, + instance_name: str, + *, + sts_token: str | None = None, +) -> tuple[Any, Any]: + """构建 OTSClient 和 AsyncOTSClient 实例。 + + 独立于 codegen 模板,避免 AsyncOTSClient 被替换为 OTSClient。 + + Returns: + (ots_client, async_ots_client) 二元组。 + """ + from tablestore import AsyncOTSClient # type: ignore[import-untyped] + from tablestore import OTSClient, WriteRetryPolicy + + ots_client = OTSClient( + endpoint, + access_key_id, + access_key_secret, + instance_name, + sts_token=sts_token, + retry_policy=WriteRetryPolicy(), + ) + async_ots_client = AsyncOTSClient( + endpoint, + access_key_id, + access_key_secret, + instance_name, + sts_token=sts_token, + retry_policy=WriteRetryPolicy(), + ) + return ots_client, async_ots_client diff --git a/agentrun/integration/utils/tool.py b/agentrun/integration/utils/tool.py index bde72a5..5b0e874 100644 --- a/agentrun/integration/utils/tool.py +++ b/agentrun/integration/utils/tool.py @@ -1562,8 +1562,8 @@ def _build_openapi_schema( if isinstance(schema, dict): properties[name] = { **schema, - "description": param.get("description") or schema.get( - "description", "" + "description": ( + param.get("description") or schema.get("description", "") ), } if param.get("required"): diff --git a/agentrun/utils/log.py b/agentrun/utils/log.py index 26af5f2..9f40999 100644 --- 
a/agentrun/utils/log.py +++ b/agentrun/utils/log.py @@ -19,38 +19,45 @@ class CustomFormatter(logging.Formatter): Provides colorful log output format. """ - FORMATS = { - "DEBUG": ( - "\n\x1b[1;36m%(levelname)s\x1b[0m \x1b[36m[%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n\x1b[2m%(message)s\x1b[0m\n" - ), - "INFO": ( - "\n\x1b[1;34m%(levelname)s\x1b[0m \x1b[34m[%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n%(message)s\n" - ), - "WARNING": ( - "\n\x1b[1;33m%(levelname)s\x1b[0m \x1b[33m[%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n%(message)s\n" - ), - "ERROR": ( - "\n\x1b[1;31m%(levelname)s\x1b[0m \x1b[31m[%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n%(message)s\n" - ), - "CRITICAL": ( - "\n\x1b[1;31m%(levelname)s\x1b[0m \x1b[31m[%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n%(message)s\n" - ), - "DEFAULT": ( - "\n%(levelname)s [%(name)s] %(asctime)s" - " \x1b[2;3m%(pathname)s:%(lineno)s\x1b[0m\n%(message)s\n" - ), + COLORS: dict[str, str] = { + "DEBUG": "\x1b[36m", + "INFO": "\x1b[34m", + "WARNING": "\x1b[33m", + "ERROR": "\x1b[31m", + "CRITICAL": "\x1b[1;31m", } + RESET = "\x1b[0m" + DIM = "\x1b[2;3m" + DIM_ONLY = "\x1b[2m" + + def __init__(self) -> None: + super().__init__() + self._formatters: dict[str, logging.Formatter] = {} + for level, color in self.COLORS.items(): + if level == "DEBUG": + fmt = ( + f"\n{color}%(levelname)s{self.RESET} {color}[%(name)s]" + " %(asctime)s" + f" {self.DIM}%(pathname)s:%(lineno)s{self.RESET}" + f"\n{self.DIM_ONLY}%(message)s{self.RESET}" + ) + else: + fmt = ( + f"\n{color}%(levelname)s{self.RESET} {color}[%(name)s]" + " %(asctime)s" + f" {self.DIM}%(pathname)s:%(lineno)s{self.RESET}" + "\n%(message)s" + ) + self._formatters[level] = logging.Formatter(fmt) + self._default = logging.Formatter( + "\n%(levelname)s [%(name)s] %(asctime)s" + " %(pathname)s:%(lineno)s\n%(message)s" + ) - def format(self, record): - formatter 
= logging.Formatter( - self.FORMATS.get(record.levelname, self.FORMATS["DEFAULT"]) + def format(self, record: logging.LogRecord) -> str: + return self._formatters.get(record.levelname, self._default).format( + record ) - return formatter.format(record) logger = logging.getLogger("agentrun-logger") diff --git a/examples/conversation_service.md b/examples/conversation_service.md new file mode 100644 index 0000000..1763ebf --- /dev/null +++ b/examples/conversation_service.md @@ -0,0 +1,722 @@ +**Agent 的本质是对无状态 LLM 进行有状态的精细化 Context 管理**。会话(Session)与状态(State)是 LLM Context 的核心来源。因此,构建一个健壮的会话管理系统,不仅能显著提升开发者的体验,更是 Agent 运行平台的核心竞争力。 + +本文将介绍使用不同 Agent 开发框架如何接入、使用 AgentRun 提供的会话状态持久化能力。 + +## 1. 概述 +AgentRun 提供了**会话状态持久化服务**,为 AI Agent 应用提供开箱即用的会话管理能力。它将会话元数据、对话事件流和多级状态统一持久化到阿里云 TableStore(OTS),让 Agent 具备**跨请求、跨重启**的记忆能力。 + +### 核心能力 +| 能力 | 说明 | +| --- | --- | +| 会话管理 | 创建、查询、列出、删除会话,支持按时间排序和多元索引搜索 | +| 事件流持久化 | 自动持久化对话中的每一轮交互(用户消息、Agent 回复、工具调用等) | +| 三级状态管理 | 支持 App 级、User 级、Session 级三层状态,自动合并返回 | +| 多元索引搜索 | 按关键词、标签、时间范围等条件搜索会话 | +| 多框架适配 | 通过薄适配层对接不同 Agent 开发框架,应用代码无需感知底层存储 | + + +### 框架支持状态 +| 框架 | 适配器 | 状态 | +| --- | --- | --- | +| Google ADK | `OTSSessionService` | 已支持 | +| LangChain | `OTSChatMessageHistory` | 即将推出 | +| LangGraph | - | 即将推出 | + + +--- + +## 2. 
前置条件 +### 2.1 创建 MemoryCollection +在 AgentRun 平台上创建一个 MemoryCollection 资源。MemoryCollection 内部包含了 TableStore 实例的连接信息(endpoint、instance_name),Conversation Service 会自动从中读取这些配置。 + +创建方式请参考 [AgentRun 官方文档](https://help.aliyun.com/zh/functioncompute/fc/memory-storage?spm=a2c4g.11186623.help-menu-2508973.d_3_11.3e076abaLO38Z2)。 + +### 2.2 配置环境变量 +在运行应用前,请设置以下环境变量: + +```bash +# 必填:MemoryCollection 名称(在 AgentRun 平台上创建的资源名称) +export MEMORY_COLLECTION_NAME="your-memory-collection-name" +``` + +也可以使用 `.env` 文件配合 `python-dotenv` 加载: + +```plain +MEMORY_COLLECTION_NAME=your-memory-collection-name +``` + +> Conversation Service 也支持备选环境变量 `ALIBABA_CLOUD_ACCESS_KEY_ID` / `ALIBABA_CLOUD_ACCESS_KEY_SECRET`,SDK 会按优先级自动查找。 +> + +### 2.3 Python 环境 ++ Python 3.10 及以上版本 + +--- + +## 3. 安装 +```bash +pip install agentrun-sdk +``` + +如果需要使用 Google ADK 集成,还需安装 ADK 及模型调用依赖: + +```bash +pip install google-adk litellm +``` + +--- + +## 4. 快速开始(Google ADK) +以下是一个最小可运行的示例,展示如何用 5 步将 Google ADK Agent 的会话持久化到 OTS。 + +> 示例中使用 DashScope 的 OpenAI 兼容接口,需要设置环境变量 `DASHSCOPE_API_KEY`。 +> + +```python +import asyncio +import os + +from google.adk.agents import Agent +from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner +from google.genai import types + +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSSessionService + +# ── Step 1: 初始化 SessionStore ────────────────────────────── +store = SessionStore.from_memory_collection( + os.environ["MEMORY_COLLECTION_NAME"] +) +store.init_tables() + +# ── Step 2: 创建 OTSSessionService ────────────────────────── +session_service = OTSSessionService(session_store=store) + +# ── Step 3: 创建 Agent + Runner ────────────────────────────── +model = LiteLlm( + model="openai/qwen3-max", + api_key=os.environ["DASHSCOPE_API_KEY"], + api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +agent = Agent( + name="assistant", + model=model, + instruction="你是一个友好的中文智能助手。", +) +runner
= Runner( + agent=agent, + app_name="my_app", + session_service=session_service, +) + + +# ── Step 4: 对话(自动持久化到 OTS) ──────────────────────── +async def main(): + # 创建会话 + session = await session_service.create_session( + app_name="my_app", + user_id="user_1", + ) + print(f"会话已创建: {session.id}") + + # 发送消息 + content = types.Content( + role="user", + parts=[types.Part(text="你好,介绍一下你自己")], + ) + async for event in runner.run_async( + user_id="user_1", + session_id=session.id, + new_message=content, + ): + if ( + event.is_final_response() + and event.content + and event.content.parts + ): + for part in event.content.parts: + if part.text: + print(f"Agent: {part.text}") + + # ── Step 5: 验证持久化 ─────────────────────────────────── + # 重新从 OTS 加载会话,确认事件已持久化 + loaded = await session_service.get_session( + app_name="my_app", + user_id="user_1", + session_id=session.id, + ) + print(f"持久化事件数: {len(loaded.events)}") + + +asyncio.run(main()) +``` + +运行后,您会看到 Agent 的回复以及持久化的事件数量。即使程序重启,再次通过 `get_session` 加载同一个 `session_id`,历史对话仍然存在。 + +--- + +## 5. Google ADK 详细指南 +### 5.1 初始化 SessionStore +SessionStore 是 Conversation Service 的核心入口。通过 MemoryCollection 名称初始化,SDK 会自动完成以下工作: + +1. 调用 AgentRun API 获取 MemoryCollection 配置 +2. 从中提取 TableStore 的 endpoint 和 instance_name +3. 
构建 OTS 客户端 + +```python +from agentrun.conversation_service import SessionStore + +store = SessionStore.from_memory_collection("your-memory-collection-name") +``` + +初始化后,需要调用 `init_tables()` 创建所需的数据库表和索引。该方法是**幂等**的,表或索引已存在时会自动跳过,不会报错,可以安全地在每次启动时调用。 + +```python +store.init_tables() +``` + +`init_tables()` 会创建以下资源: + +| 资源 | 说明 | +| --- | --- | +| `conversation` 表 | 存储会话元信息(摘要、标签、时间戳等) | +| `event` 表 | 存储对话事件流(消息、工具调用等) | +| `state` 表 | 存储 Session 级状态 | +| `app_state` 表 | 存储 App 级状态 | +| `user_state` 表 | 存储 User 级状态 | +| `conversation_secondary_index` | 二级索引,支持按更新时间排序列出会话 | +| `conversation_search_index` | 多元索引,支持全文搜索和组合过滤 | +| `state_search_index` | 多元索引,支持按 session_id 独立查询状态 | + + +> **按需建表**:如果您的场景不需要全部功能,也可以分步创建: +> + +| 方法 | 创建的资源 | 适用场景 | +| --- | --- | --- | +| `init_core_tables()` | conversation + event + 二级索引 | 仅需会话和事件,无三级 State | +| `init_state_tables()` | state + app_state + user_state | 仅补建 State 表 | +| `init_search_index()` | conversation + state 多元索引 | 仅补建搜索索引 | +| `init_tables()` | 以上全部 | 推荐,一次创建所有资源 | + + +#### 异步初始化 +SessionStore 的所有方法均提供异步版本(方法名加 `_async` 后缀): + +```python +store = await SessionStore.from_memory_collection_async( + "your-memory-collection-name" +) +await store.init_tables_async() +``` + +#### 表名前缀 +如果多个应用共用同一个 OTS 实例,可以通过 `table_prefix` 参数隔离表名: + +```python +store = SessionStore.from_memory_collection( + "your-memory-collection-name", + table_prefix="myapp_", +) +# 创建的表名为: myapp_conversation, myapp_event, myapp_state, ... 
+``` + +### 5.2 创建 OTSSessionService +`OTSSessionService` 是 Google ADK `BaseSessionService` 的 OTS 实现。将它传给 ADK 的 `Runner`,即可让 ADK 的会话自动持久化到 OTS。 + +```python +from agentrun.conversation_service.adapters import OTSSessionService + +session_service = OTSSessionService(session_store=store) +``` + +然后将 `session_service` 传给 `Runner`: + +```python +from google.adk.runners import Runner + +runner = Runner( + agent=agent, + app_name="my_app", + session_service=session_service, +) +``` + +此后,通过 `runner.run_async()` 进行的所有对话都会自动持久化到 OTS,包括: + ++ 用户消息 ++ Agent 回复 ++ 工具调用(function_call)和工具返回(function_response) ++ State 变更(state_delta) + +### 5.3 Session 管理 +#### 创建 Session +```python +session = await session_service.create_session( + app_name="my_app", + user_id="user_1", + session_id="custom-session-id", # 可选,不传则自动生成 UUID + state={ # 可选,初始状态 + "app:model_name": "qwen-max", # app 级状态(app: 前缀) + "user:language": "zh-CN", # user 级状态(user: 前缀) + "turn_count": 0, # session 级状态(无前缀) + }, +) +print(f"Session ID: {session.id}") +``` + +**参数说明:** + +| 参数 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `app_name` | str | 是 | 应用名称,对应 OTS 中的 `agent_id` | +| `user_id` | str | 是 | 用户 ID | +| `session_id` | str | 否 | 会话 ID,不传则自动生成 UUID | +| `state` | dict | 否 | 初始状态,会根据 key 前缀自动拆分到三级 State | + + +> **session_id 的生成策略**:在 Server 场景中,通常由客户端通过 HTTP Header 传入 `session_id`,以便同一用户的多轮对话关联到同一个会话。如果不传,每次请求会创建一个新的独立会话。 +> + +#### 获取 Session +```python +session = await session_service.get_session( + app_name="my_app", + user_id="user_1", + session_id="your-session-id", +) + +if session is None: + print("会话不存在") +else: + print(f"事件数: {len(session.events)}") + print(f"当前状态: {session.state}") +``` + +返回的 `session` 对象包含: + +| 属性 | 类型 | 说明 | +| --- | --- | --- | +| `id` | str | 会话 ID | +| `app_name` | str | 应用名称 | +| `user_id` | str | 用户 ID | +| `events` | list[Event] | 完整的 ADK Event 列表(按时间正序) | +| `state` | dict | 合并后的三级状态(详见 5.4 节) | +| `last_update_time` | float | 最后更新时间(Unix 秒级时间戳) | + + +##### 控制返回的事件数量 
+当会话事件很多时,可以通过 `GetSessionConfig` 控制只返回最近 N 条事件,避免一次性加载过多数据: + +```python +from google.adk.sessions.base_session_service import GetSessionConfig + +session = await session_service.get_session( + app_name="my_app", + user_id="user_1", + session_id="your-session-id", + config=GetSessionConfig(num_recent_events=20), +) +# session.events 只包含最近 20 条事件 +``` + +也可以通过 `after_timestamp` 只返回指定时间之后的事件: + +```python +import time + +one_hour_ago = time.time() - 3600 +session = await session_service.get_session( + app_name="my_app", + user_id="user_1", + session_id="your-session-id", + config=GetSessionConfig(after_timestamp=one_hour_ago), +) +``` + +#### 列出 Session +列出指定用户的所有会话,按最后更新时间倒序排列: + +```python +response = await session_service.list_sessions( + app_name="my_app", + user_id="user_1", +) +for s in response.sessions: + print(f"Session: {s.id}, 最后更新: {s.last_update_time}") +``` + +也可以不传 `user_id`,列出该应用下所有用户的会话: + +```python +response = await session_service.list_sessions( + app_name="my_app", + user_id=None, +) +``` + +> `list_sessions` 返回的 Session 对象**不包含** events 和 state(出于性能考虑),仅包含元信息。如果需要完整数据,请对感兴趣的 Session 调用 `get_session`。 +> + +#### 删除 Session +删除会话时会**级联删除**该会话下的所有事件和 Session 级状态: + +```python +await session_service.delete_session( + app_name="my_app", + user_id="user_1", + session_id="your-session-id", +) +``` + +删除顺序为 Event -> State -> Session 元数据。如果中间步骤失败,下次重试可继续清理(幂等安全)。 + +> 删除 Session 不会影响 App 级和 User 级的状态。 +> + +#### 同步方法 +`OTSSessionService` 的所有方法都提供同步版本,方法名加 `_sync` 后缀: + +```python +session = session_service.create_session_sync( + app_name="my_app", user_id="user_1" +) +session = session_service.get_session_sync( + app_name="my_app", user_id="user_1", session_id="xxx" +) +response = session_service.list_sessions_sync( + app_name="my_app", user_id="user_1" +) +session_service.delete_session_sync( + app_name="my_app", user_id="user_1", session_id="xxx" +) +``` + +### 5.4 三级 State 机制 +Google ADK 定义了三级 State 作用域,Conversation Service 将它们分别持久化到不同的 OTS 表中: + 
+```plain +┌─────────────────────────────────────────────────────────────────────┐ +│ 合并后的 session.state │ +│ │ +│ ┌──────────────────┐ │ +│ │ App State │ app_state 表 (agent_id) │ +│ │ app:model_name │ 所有用户、所有会话共享 │ +│ └────────┬─────────┘ │ +│ │ 覆盖 │ +│ ┌────────▼─────────┐ │ +│ │ User State │ user_state 表 (agent_id, user_id) │ +│ │ user:language │ 同一用户的所有会话共享 │ +│ └────────┬─────────┘ │ +│ │ 覆盖 │ +│ ┌────────▼─────────┐ │ +│ │ Session State │ state 表 (agent_id, user_id, session_id) │ +│ │ turn_count │ 仅当前会话可见 │ +│ └──────────────────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +#### Key 前缀约定 +ADK 通过 key 的前缀来区分 State 的作用域: + +| 前缀 | 作用域 | 存储位置 | 示例 | +| --- | --- | --- | --- | +| `app:` | App 级 | `app_state` 表 | `app:model_name`、`app:total_queries` | +| `user:` | User 级 | `user_state` 表 | `user:language`、`user:preferences` | +| 无前缀 | Session 级 | `state` 表 | `turn_count`、`last_reply` | +| `temp:` | 临时状态 | 仅内存,不持久化 | `temp:processing` | + + +#### State 合并规则 +当通过 `get_session` 加载会话时,三级 State 会按 **App -> User -> Session** 的顺序浅合并(后者覆盖前者)。返回的 `session.state` 是合并后的完整字典。 + +例如,如果三级 State 分别为: + +```python +# app_state 表 +{"model_name": "qwen-max", "version": "1.0"} + +# user_state 表 +{"language": "zh-CN"} + +# state 表 (session) +{"turn_count": 3, "last_reply": "北京今天晴朗"} +``` + +则 `session.state` 的内容为: + +```python +{ + "turn_count": 3, + "last_reply": "北京今天晴朗", + "user:language": "zh-CN", + "app:model_name": "qwen-max", + "app:version": "1.0", +} +``` + +#### 通过 state_delta 更新 State +在 ADK 中,Agent 可以通过事件的 `actions.state_delta` 自动更新 State。`OTSSessionService` 会自动将 delta 按前缀拆分并持久化到对应的 State 表: + +```python +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions + +event = Event( + invocation_id="inv-001", + author="my_agent", + content=..., + actions=EventActions( + state_delta={ + "turn_count": 1, # -> state 表 + "app:total_queries": 42, # -> app_state 表 + "user:last_query_city": "北京", # -> 
user_state 表 + }, + ), +) +``` + +#### output_key 自动写入 +ADK Agent 支持 `output_key` 参数,会自动将 Agent 的最终回复写入 `session.state[output_key]`。搭配 `OTSSessionService`,这个值会自动持久化到 OTS 的 Session State 中: + +```python +agent = Agent( + name="assistant", + model=model, + instruction="你是一个智能助手。", + output_key="last_reply", # Agent 回复自动写入 state["last_reply"] +) +``` + +后续通过 `get_session` 加载会话时,可以从 `session.state["last_reply"]` 读取上一轮的 Agent 回复。 + +#### 手动更新 State +除了通过 ADK 的 `state_delta` 自动更新外,也可以直接调用 `SessionStore` 手动更新指定级别的 State。这在 Server 的 `invoke_agent` 回调中常用: + +```python +# 更新 Session 级状态 +await store.update_session_state_async( + "my_app", "user_1", "session_id", + {"turn_count": 5, "last_user_input": "今天天气如何"}, +) + +# 更新 User 级状态 +await store.update_user_state_async( + "my_app", "user_1", + {"language": "en-US"}, +) + +# 更新 App 级状态 +await store.update_app_state_async( + "my_app", + {"model_name": "qwen-turbo"}, +) +``` + +State 更新采用**浅合并**语义:只覆盖提供的 key,未提供的 key 保持不变。将值设为 `None` 可以删除对应的 key。 + +### 5.5 结合 AgentRunServer 部署 +在生产环境中,通常将 ADK Agent 部署为 HTTP 服务。以下是结合 `AgentRunServer` 的完整示例: + +```python +import os +import uuid + +from google.adk.agents import Agent +from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner +from google.genai import types + +from agentrun import AgentRequest +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSSessionService +from agentrun.server import AgentRunServer + +APP_NAME = "my_chat_server" + +# ── 初始化 ──────────────────────────────────────────────────── + +store = SessionStore.from_memory_collection( + os.environ["MEMORY_COLLECTION_NAME"] +) +store.init_tables() + +session_service = OTSSessionService(session_store=store) + +model = LiteLlm( + model="openai/qwen3-max", + api_key=os.environ["DASHSCOPE_API_KEY"], + api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +agent = Agent( + name="assistant", + model=model, + 
instruction="你是一个友好的中文智能助手。", + output_key="last_reply", +) +runner = Runner( + agent=agent, + app_name=APP_NAME, + session_service=session_service, +) + + +# ── 核心处理函数 ────────────────────────────────────────────── + +async def invoke_agent(req: AgentRequest): + # 从 HTTP Header 获取 session_id 和 user_id + headers = dict(req.raw_request.headers) if req.raw_request else {} + session_id = ( + headers.get("x-agentrun-session-id") + or f"chat_{uuid.uuid4().hex[:8]}" + ) + user_id = headers.get("x-agentrun-user-id") or "default_user" + + # 获取或创建 Session + session = await session_service.get_session( + app_name=APP_NAME, user_id=user_id, session_id=session_id, + ) + if session is None: + session = await session_service.create_session( + app_name=APP_NAME, user_id=user_id, session_id=session_id, + ) + + # 提取用户消息 + last_user_text = "" + for msg in reversed(req.messages): + if msg.role == "user": + last_user_text = msg.content or "" + break + + if not last_user_text: + yield "请输入您的问题。" + return + + # 调用 ADK Runner 流式输出 + content = types.Content( + role="user", + parts=[types.Part(text=last_user_text)], + ) + async for event in runner.run_async( + user_id=user_id, + session_id=session.id, + new_message=content, + ): + if ( + event.is_final_response() + and event.content + and event.content.parts + ): + for part in event.content.parts: + if part.text: + yield part.text + + +# ── 启动服务 ────────────────────────────────────────────────── + +if __name__ == "__main__": + server = AgentRunServer( + invoke_agent=invoke_agent, + memory_collection_name=os.environ["MEMORY_COLLECTION_NAME"], + ) + server.start(port=9000) +``` + +**客户端请求时的 Header 约定:** + +| Header | 说明 | 必填 | +| --- | --- | --- | +| `X-AgentRun-Session-ID` | 会话 ID,用于关联多轮对话 | 否(不传则自动生成新会话) | +| `X-AgentRun-User-ID` | 用户 ID | 否(默认 `default_user`) | + + +--- + +## 6. LangChain 集成 +> Coming Soon — LangChain 适配器 `OTSChatMessageHistory` 正在开发中,将支持与 `RunnableWithMessageHistory` 无缝集成。 +> + +--- + +## 7. 
LangGraph 集成 +> Coming Soon — LangGraph 适配器正在规划中。 +> + +--- + +## 8. 高级功能 +### 8.1 多元索引搜索 +Conversation Service 为会话表和状态表创建了多元索引(Search Index),支持不受主键顺序限制的灵活查询。 + +#### 搜索会话 +通过 `SessionStore.search_sessions` 可以按多种条件组合搜索会话: + +```python +results, total = store.search_sessions( + "my_app", + user_id="user_1", # 可选,精确匹配用户 + summary_keyword="天气", # 可选,全文搜索摘要 + labels='["重要"]', # 可选,精确匹配标签 + framework="adk", # 可选,精确匹配框架 + updated_after=1700000000000000, # 可选,仅返回此时间后更新的(纳秒时间戳) + updated_before=None, # 可选,仅返回此时间前更新的 + is_pinned=True, # 可选,是否置顶 + limit=20, # 每页条数,默认 20 + offset=0, # 分页偏移 +) + +print(f"共 {total} 条结果") +for session in results: + print(f" {session.session_id}: {session.summary}") +``` + +异步版本: + +```python +results, total = await store.search_sessions_async("my_app", summary_keyword="天气") +``` + +#### 按 session_id 独立查询状态 +State 表的多元索引(`state_search_index`)支持按 `session_id` 独立精确查询,不需要提供 `agent_id` 和 `user_id` 前缀。这在需要跨用户定位特定会话状态时非常有用。 + +### 8.2 表名前缀隔离 +在多租户场景或需要区分不同环境(开发/测试/生产)时,可以通过 `table_prefix` 参数为所有表名添加前缀: + +```python +# 开发环境 +dev_store = SessionStore.from_memory_collection( + "my-collection", table_prefix="dev_" +) +# 表名:dev_conversation, dev_event, dev_state, ... + +# 生产环境 +prod_store = SessionStore.from_memory_collection( + "my-collection", table_prefix="prod_" +) +# 表名:prod_conversation, prod_event, prod_state, ... +``` + +不同前缀的表完全独立,互不影响。 + +--- + +## 9. 常见问题 +### init_tables() 需要每次启动都调用吗? +可以。`init_tables()` 是**幂等**操作,表或索引已存在时会自动跳过。建议在应用启动时调用,确保所需资源就绪。 + +### 多元索引创建后多久生效? +多元索引创建后需要数秒到数十秒才能完全生效(取决于数据量)。在索引生效前,`search_sessions` 可能返回不完整的结果。建议首次创建索引后等待几秒再进行搜索操作。 + +### 为什么 list_sessions 返回的 Session 没有 events 和 state? +这是出于性能考虑。`list_sessions` 用于展示会话列表,只返回元信息(ID、更新时间等)。如果需要某个 Session 的完整事件和状态,请调用 `get_session`。 + +### Session 删除后 App 级和 User 级状态还在吗? +是的。`delete_session` 只删除 Session 本身及其关联的事件和 Session 级状态。App 级状态和 User 级状态的生命周期独立于单个 Session,不会被级联删除。 + +### 如何处理并发写入冲突? 
+State 更新使用**乐观锁**机制(version 字段)。如果两个请求同时更新同一行,后到的请求会因 version 不匹配而失败。在高并发场景下,建议在业务层实现重试逻辑。 + +### 支持哪些模型? +Conversation Service 不限制模型选择。示例中使用的是通义千问(通过 DashScope API 调用),您可以替换为任何 ADK 支持的模型(如 Gemini、OpenAI 等)。模型选择由 ADK 的 `Agent` 配置决定,与 Conversation Service 无关。 + diff --git a/examples/conversation_service_langchain_server.py b/examples/conversation_service_langchain_server.py new file mode 100644 index 0000000..214160e --- /dev/null +++ b/examples/conversation_service_langchain_server.py @@ -0,0 +1,181 @@ +"""LangChain Agent Server —— 使用 OTSChatMessageHistory 持久化消息历史。 + +集成步骤: + Step 1: 初始化 SessionStore(OTS 后端)+ 创建 LangChain 所需表和索引 + Step 2: 构建 LangChain Chain(ChatOpenAI + SystemMessage) + Step 3: 实现 invoke_agent,将 AgentRequest 转为 LangChain 调用并流式输出 + Step 4: 通过 AgentRunServer 启动 HTTP 服务 + +使用方式: + uv run --env-file .env python examples/conversation_service_langchain_server.py + + # 请求示例(curl): + curl -X POST http://localhost:9002/openai/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-AgentRun-Session-ID: my-session-1" \ + -H "X-AgentRun-User-ID: user-1" \ + -d '{"model":"qwen-max","stream":true,"messages":[{"role":"user","content":"你好"}]}' +""" + +from __future__ import annotations + +import os +import sys +import uuid + +from dotenv import load_dotenv +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI + +from agentrun import AgentRequest +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSChatMessageHistory +from agentrun.server import AgentRunServer + +load_dotenv() + +# ── 配置参数 ────────────────────────────────────────────────── +AGENT_ID = "langchain_chat_server" +MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "") +DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY", "") + +if not MEMORY_COLLECTION_NAME: + print("ERROR: 请设置环境变量 MEMORY_COLLECTION_NAME") + sys.exit(1) +if not DASHSCOPE_API_KEY: + print("ERROR: 
请设置环境变量 DASHSCOPE_API_KEY") + sys.exit(1) + + +# ── Step 1: 初始化 SessionStore + 创建 LangChain 所需表和索引 ─ + +store = SessionStore.from_memory_collection(MEMORY_COLLECTION_NAME) +store.init_langchain_tables() + +# ── Step 2: 构建 LangChain Chain ───────────────────────────── + +SYSTEM_PROMPT = "你是一个友好的中文智能助手,请简洁、准确地回答用户问题。" + +llm = ChatOpenAI( + model="qwen-max", + api_key=DASHSCOPE_API_KEY, + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + streaming=True, +) + + +# ── 辅助函数 ────────────────────────────────────────────────── + + +def _get_session_id(req: AgentRequest) -> str: + """从请求 header 中提取 session_id,没有则生成一个。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-Session-ID") + or raw_headers.get("x-agentrun-session-id") + or raw_headers.get("X-Agentrun-Session-Id") + or f"session_{uuid.uuid4().hex[:8]}" + ) + + +def _get_user_id(req: AgentRequest) -> str: + """从请求 header 中提取 user_id。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-User-ID") + or raw_headers.get("x-agentrun-user-id") + or "default_user" + ) + + +# ── Step 3: invoke_agent —— 核心 Server 处理函数 ───────────── + + +async def invoke_agent(req: AgentRequest): + """将 AgentRequest 转换为 LangChain 调用并流式输出。 + + 流程: + 1. 从 header 提取 session_id / user_id + 2. 创建 OTSChatMessageHistory(自动关联 / 创建 Session) + 3. 加载历史消息,展示 OTS 持久化状态 + 4. 提取最后一条用户消息,写入历史 + 5. 拼接 SystemMessage + 历史消息,调用 LLM 流式输出 + 6. 
将 AI 回复写入历史(自动持久化到 OTS) + """ + session_id = _get_session_id(req) + user_id = _get_user_id(req) + + # 创建消息历史(自动关联 / 创建 Session) + history = OTSChatMessageHistory( + session_store=store, + agent_id=AGENT_ID, + user_id=user_id, + session_id=session_id, + ) + + # 展示当前持久化状态 + existing_messages = history.messages + print( + f"[Session {session_id}] " + f"user={user_id}, " + f"已有 {len(existing_messages)} 条消息" + ) + + # 提取最后一条用户消息 + last_user_text = "" + for msg in reversed(req.messages): + if msg.role == "user": + last_user_text = msg.content or "" + break + + if not last_user_text: + yield "请输入您的问题。" + return + + # 将用户消息写入历史 + history.add_message(HumanMessage(content=last_user_text)) + + # 拼接完整消息列表:SystemMessage + 历史消息 + full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + history.messages + + # 调用 LLM 流式输出 + try: + full_response = "" + async for chunk in llm.astream(full_messages): + text = chunk.content + if isinstance(text, str) and text: + full_response += text + yield text + + # 将 AI 回复写入历史(持久化到 OTS) + if full_response: + history.add_message(AIMessage(content=full_response)) + + print( + f"[Session {session_id}] 回复完成,当前共" + f" {len(history.messages)} 条消息" + ) + + except Exception as e: + print(f"LangChain 执行异常: {e}") + raise Exception("Internal Error") + + +# ── Step 4: 启动 Server ────────────────────────────────────── + +if __name__ == "__main__": + server = AgentRunServer( + invoke_agent=invoke_agent, + memory_collection_name=MEMORY_COLLECTION_NAME, + ) + print(f"Agent ID: {AGENT_ID}") + print(f"Memory Collection: {MEMORY_COLLECTION_NAME}") + print("请求时通过 X-AgentRun-Session-ID header 指定会话 ID") + print("请求时通过 X-AgentRun-User-ID header 指定 user_id") + server.start(port=9002) diff --git a/examples/conversation_service_langgraph_server.py b/examples/conversation_service_langgraph_server.py new file mode 100644 index 0000000..28bf3d8 --- /dev/null +++ b/examples/conversation_service_langgraph_server.py @@ -0,0 +1,203 @@ +"""LangGraph Agent Server —— 使用 
OTSCheckpointSaver 持久化 checkpoint。 + +集成步骤: + Step 1: 初始化 SessionStore(OTS 后端)+ 创建 checkpoint 表 + Step 2: 创建 OTSCheckpointSaver + Step 3: 构建 LangGraph StateGraph + 编译(传入 checkpointer) + Step 4: 实现 invoke_agent,将 AgentRequest 转为 LangGraph 调用并流式输出 + Step 5: 通过 AgentRunServer 启动 HTTP 服务 + +使用方式: + uv run --env-file .env python examples/conversation_service_langgraph_server.py + + # 请求示例(curl): + curl -X POST http://localhost:9001/openai/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-AgentRun-Session-ID: my-thread-2222" \ + -H "X-AgentRun-User-ID: user-1" \ + -d '{"model":"qwen3-max","stream":true,"messages":[{"role":"user","content":"你好"}]}' +""" + +from __future__ import annotations + +import os +import sys +from typing import Annotated, Any +import uuid + +from dotenv import load_dotenv +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_openai import ChatOpenAI +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages + +from agentrun import AgentRequest +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSCheckpointSaver +from agentrun.server import AgentRunServer + +load_dotenv() + +# ── 配置参数 ────────────────────────────────────────────────── +AGENT_ID = "langgraph_chat_server" +MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "") +DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY", "") + +if not MEMORY_COLLECTION_NAME: + print("ERROR: 请设置环境变量 MEMORY_COLLECTION_NAME") + sys.exit(1) +if not DASHSCOPE_API_KEY: + print("ERROR: 请设置环境变量 DASHSCOPE_API_KEY") + sys.exit(1) + + +# ── 定义 State ─────────────────────────────────────────────── + +from typing import TypedDict + + +class ChatState(TypedDict): + messages: Annotated[list[BaseMessage], add_messages] + + +# ── 定义 Graph 节点 ────────────────────────────────────────── + +llm = ChatOpenAI( + model="qwen-max", + api_key=DASHSCOPE_API_KEY, + 
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) + + +def chat_node(state: ChatState) -> dict[str, Any]: + """调用 LLM 生成回复。""" + response = llm.invoke(state["messages"]) + return {"messages": [response]} + + +# ── Step 1: 初始化 SessionStore + 创建 checkpoint 表 ───────── + +store = SessionStore.from_memory_collection(MEMORY_COLLECTION_NAME) +store.init_langgraph_tables() + +# ── Step 2: 创建 OTSCheckpointSaver ────────────────────────── + +checkpointer = OTSCheckpointSaver(store, agent_id=AGENT_ID) + +# ── Step 3: 构建 Graph ─────────────────────────────────────── + +graph = StateGraph(ChatState) +graph.add_node("chat", chat_node) +graph.add_edge(START, "chat") +graph.add_edge("chat", END) +app = graph.compile(checkpointer=checkpointer) + + +# ── 辅助函数 ────────────────────────────────────────────────── + + +def _get_thread_id(req: AgentRequest) -> str: + """从请求 header 中提取 thread_id(复用 session-id header)。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-Session-ID") + or raw_headers.get("x-agentrun-session-id") + or raw_headers.get("X-Agentrun-Session-Id") + or f"thread_{uuid.uuid4().hex[:8]}" + ) + + +def _get_user_id(req: AgentRequest) -> str: + """从请求 header 中提取 user_id。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-User-ID") + or raw_headers.get("x-agentrun-user-id") + or "default_user" + ) + + +# ── Step 4: invoke_agent —— 核心 Server 处理函数 ───────────── + + +async def invoke_agent(req: AgentRequest): + """将 AgentRequest 转换为 LangGraph 调用并流式输出。 + + 流程: + 1. 从 header 提取 thread_id / user_id + 2. 用 thread_id 作为 LangGraph 的 configurable.thread_id + 3. 提取最后一条用户消息 + 4. 调用 graph.astream() 流式输出 + 5. 
checkpoint 自动持久化到 OTS(由 OTSCheckpointSaver 处理) + """ + thread_id = _get_thread_id(req) + user_id = _get_user_id(req) + config = { + "configurable": {"thread_id": thread_id}, + "metadata": {"user_id": user_id}, + } + + # 展示当前 checkpoint 状态(体现 OTS 持久化能力) + existing = await checkpointer.aget_tuple(config) + if existing: + msg_count = len( + existing.checkpoint.get("channel_values", {}).get("messages", []) + ) + print( + f"[Thread {thread_id}] " + f"user={user_id}, " + f"已有 {msg_count} 条消息, " + f"checkpoint_id={existing.checkpoint['id']}" + ) + else: + print(f"[Thread {thread_id}] user={user_id}, 新会话") + + # 提取最后一条用户消息 + last_user_text = "" + for msg in reversed(req.messages): + if msg.role == "user": + last_user_text = msg.content or "" + break + + if not last_user_text: + yield "请输入您的问题。" + return + + # 调用 LangGraph 流式输出 + try: + async for event in app.astream( + {"messages": [HumanMessage(content=last_user_text)]}, + config=config, + stream_mode="values", + ): + messages = event.get("messages", []) + if messages: + last_msg = messages[-1] + if isinstance(last_msg, AIMessage) and last_msg.content: + yield last_msg.content + + print(f"[Thread {thread_id}] 回复完成") + + except Exception as e: + print(f"LangGraph 执行异常: {e}") + raise Exception("Internal Error") + + +# ── Step 5: 启动 Server ────────────────────────────────────── + +if __name__ == "__main__": + server = AgentRunServer( + invoke_agent=invoke_agent, + memory_collection_name=MEMORY_COLLECTION_NAME, + ) + print(f"Agent ID: {AGENT_ID}") + print(f"Memory Collection: {MEMORY_COLLECTION_NAME}") + print("请求时通过 X-AgentRun-Session-ID header 指定 thread_id") + print("请求时通过 X-AgentRun-User-ID header 指定 user_id") + server.start(port=9001) diff --git a/tests/unittests/integration/test_langchain_agui_integration.py b/tests/unittests/integration/test_langchain_agui_integration.py index ef0c076..500c86c 100644 --- a/tests/unittests/integration/test_langchain_agui_integration.py +++ 
b/tests/unittests/integration/test_langchain_agui_integration.py @@ -689,7 +689,9 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", + "content": ( + "查询当前的时间,并获取天气信息,同时输出我的密钥信息" + ), }], "stream": True, }, @@ -755,7 +757,9 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", + "content": ( + "查询当前的时间,并获取天气信息,同时输出我的密钥信息" + ), }], "stream": True, }, diff --git a/tests/unittests/toolset/api/test_openapi.py b/tests/unittests/toolset/api/test_openapi.py index bb32eac..0f2e82d 100644 --- a/tests/unittests/toolset/api/test_openapi.py +++ b/tests/unittests/toolset/api/test_openapi.py @@ -548,7 +548,9 @@ def test_post_with_ref_schema(self): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateOrderRequest" + "$ref": ( + "#/components/schemas/CreateOrderRequest" + ) } } }, @@ -759,7 +761,9 @@ def test_invalid_ref_gracefully_handled(self): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/NonExistent" + "$ref": ( + "#/components/schemas/NonExistent" + ) } } } @@ -792,7 +796,9 @@ def test_external_ref_not_resolved(self): "content": { "application/json": { "schema": { - "$ref": "https://example.com/schemas/external.json" + "$ref": ( + "https://example.com/schemas/external.json" + ) } } } @@ -912,7 +918,9 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateOrderRequest" + "$ref": ( + "#/components/schemas/CreateOrderRequest" + ) } } }, @@ -948,7 +956,9 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateOrderStatusRequest" + "$ref": ( + "#/components/schemas/UpdateOrderStatusRequest" + ) } } }, @@ -1219,7 +1229,9 @@ def test_tool_schema(self): "openapi": "3.0.1", "info": {"title": "Test", "version": "1.0"}, "servers": [{ - "url": 
"https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" + "url": ( + "https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" + ) }], "paths": { "/invoke": { From 49615aeb3f97f6957ffb7fc1d739e55f2f7e3b78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AF=92=E5=85=89?= <2510399607@qq.com> Date: Fri, 10 Apr 2026 10:52:41 +0800 Subject: [PATCH 2/3] fix: address code review feedback on LangGraph checkpoint PR - Fix sync/async section comment headers in ots_backend.py and session_store.py - Add langgraph optional dependency in pyproject.toml - Improve build_ots_clients return type annotation with TYPE_CHECKING import - Handle missing state table gracefully in init_search_index Made-with: Cursor --- agentrun/conversation_service/ots_backend.py | 8 ++++++-- agentrun/conversation_service/session_store.py | 12 ++++++++---- agentrun/conversation_service/utils.py | 10 ++++++++-- pyproject.toml | 5 +++++ 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/agentrun/conversation_service/ots_backend.py b/agentrun/conversation_service/ots_backend.py index 0135718..d79f7c5 100644 --- a/agentrun/conversation_service/ots_backend.py +++ b/agentrun/conversation_service/ots_backend.py @@ -2542,7 +2542,7 @@ async def delete_state_row_async( await self._async_client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # Checkpoint CRUD(LangGraph)(异步) + # State CRUD(同步) # ----------------------------------------------------------------------- def delete_state_row( @@ -2561,7 +2561,7 @@ def delete_state_row( self._client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # Checkpoint CRUD(LangGraph)(同步) + # Checkpoint CRUD(LangGraph)(异步) # ----------------------------------------------------------------------- async def put_checkpoint_async( @@ -2591,6 +2591,10 @@ async def put_checkpoint_async( condition = 
Condition(RowExistenceExpectation.IGNORE) await self._async_client.put_row(self._checkpoint_table, row, condition) + # ----------------------------------------------------------------------- + # Checkpoint CRUD(LangGraph)(同步) + # ----------------------------------------------------------------------- + def put_checkpoint( self, thread_id: str, diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py index 9d24807..1340b67 100644 --- a/agentrun/conversation_service/session_store.py +++ b/agentrun/conversation_service/session_store.py @@ -134,7 +134,7 @@ async def init_langgraph_tables_async(self) -> None: await self._backend.init_checkpoint_tables_async() # ------------------------------------------------------------------- - # Checkpoint 管理(LangGraph)(异步) + # Checkpoint 管理(LangGraph)(同步) # ------------------------------------------------------------------- def init_langgraph_tables(self) -> None: @@ -149,7 +149,7 @@ def init_langgraph_tables(self) -> None: self._backend.init_checkpoint_tables() # ------------------------------------------------------------------- - # Checkpoint 管理(LangGraph)(同步) + # Checkpoint 管理(LangGraph)(异步) # ------------------------------------------------------------------- async def put_checkpoint_async( @@ -366,7 +366,7 @@ async def delete_thread_checkpoints_async( await self._backend.delete_thread_checkpoints_async(thread_id) # ------------------------------------------------------------------- - # Session 管理(异步)/ Session management (async) + # Checkpoint 清理(同步) # ------------------------------------------------------------------- def delete_thread_checkpoints( @@ -377,7 +377,7 @@ def delete_thread_checkpoints( self._backend.delete_thread_checkpoints(thread_id) # ------------------------------------------------------------------- - # Session 管理(同步)/ Session management (async) + # Session 管理(异步)/ Session management (async) # 
------------------------------------------------------------------- async def create_session_async( @@ -426,6 +426,10 @@ async def create_session_async( await self._backend.put_session_async(session) return session + # ------------------------------------------------------------------- + # Session 管理(同步)/ Session management (sync) + # ------------------------------------------------------------------- + def create_session( self, agent_id: str, diff --git a/agentrun/conversation_service/utils.py b/agentrun/conversation_service/utils.py index 87d3384..89d6ac1 100644 --- a/agentrun/conversation_service/utils.py +++ b/agentrun/conversation_service/utils.py @@ -7,7 +7,13 @@ import json import time -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from tablestore import ( + AsyncOTSClient, # type: ignore[import-untyped] + OTSClient, + ) # OTS 单个属性列值上限为 2MB,留 0.5MB 余量(按字符数计) MAX_COLUMN_SIZE: int = 1_500_000 # 1.5M 字符 @@ -106,7 +112,7 @@ def build_ots_clients( instance_name: str, *, sts_token: str | None = None, -) -> tuple[Any, Any]: +) -> tuple[OTSClient, AsyncOTSClient]: """构建 OTSClient 和 AsyncOTSClient 实例。 独立于 codegen 模板,避免 AsyncOTSClient 被替换为 OTSClient。 diff --git a/pyproject.toml b/pyproject.toml index e763967..f4aad0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,11 @@ tablestore = [ "tablestore>=6.1.0", ] +langgraph = [ + "langgraph>=0.2.0; python_version >= '3.10'", + "langchain-core>=0.3.0; python_version >= '3.10'", +] + [dependency-groups] dev = [ "coverage>=7.10.7", From a91c4947ab35249d041e593f131d1042e22aa95a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AF=92=E5=85=89?= <2510399607@qq.com> Date: Fri, 10 Apr 2026 13:32:19 +0800 Subject: [PATCH 3/3] fix: add RunnableConfig type annotation to fix mypy call-overload error Made-with: Cursor --- examples/conversation_service_langgraph_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/conversation_service_langgraph_server.py 
b/examples/conversation_service_langgraph_server.py index 28bf3d8..3aff60f 100644 --- a/examples/conversation_service_langgraph_server.py +++ b/examples/conversation_service_langgraph_server.py @@ -27,6 +27,7 @@ from dotenv import load_dotenv from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages @@ -138,7 +139,7 @@ async def invoke_agent(req: AgentRequest): """ thread_id = _get_thread_id(req) user_id = _get_user_id(req) - config = { + config: RunnableConfig = { "configurable": {"thread_id": thread_id}, "metadata": {"user_id": user_id}, }