From 9b07f2dd169d069fb80ead67ed316f01714c70be Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Wed, 20 May 2026 08:32:37 +0800 Subject: [PATCH 1/8] fix(adk): harden managed session persistence Change-Id: Iecbcfd8ab5905e61df9a5578e8d3c99387dace85 --- .../deep/checkpoint_compat_resume_test.go | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/adk/prebuilt/deep/checkpoint_compat_resume_test.go b/adk/prebuilt/deep/checkpoint_compat_resume_test.go index 1a4f8baa7..744549ee6 100644 --- a/adk/prebuilt/deep/checkpoint_compat_resume_test.go +++ b/adk/prebuilt/deep/checkpoint_compat_resume_test.go @@ -172,31 +172,44 @@ func TestDeepAgentCheckpointCompat_V0_8_Resume(t *testing.T) { name string checkpointID string filename string + // brokenByAgentToolInterruptStateChange marks fixtures that were captured + // before the AgentTool interrupt state format was changed to wrap the + // bridge checkpoint bytes inside a JSON envelope (agentToolInterruptState) + // to carry the synthetic child SessionID. The change is documented as + // backward-incompatible in the session event-log reconstruction plan. + brokenByAgentToolInterruptStateChange bool }{ { - name: "v0.7.37", - checkpointID: "checkpoint_compat_v0_7_37", - filename: "checkpoint_data_v0.7.37.bin", + name: "v0.7.37", + checkpointID: "checkpoint_compat_v0_7_37", + filename: "checkpoint_data_v0.7.37.bin", + brokenByAgentToolInterruptStateChange: true, }, { - name: "v0.8.2", - checkpointID: "checkpoint_compat_v0_8_2", - filename: "checkpoint_data_v0.8.2.bin", + name: "v0.8.2", + checkpointID: "checkpoint_compat_v0_8_2", + filename: "checkpoint_data_v0.8.2.bin", + brokenByAgentToolInterruptStateChange: true, }, { - name: "v0.8.3", - checkpointID: "checkpoint_compat_v0_8_3", - filename: "checkpoint_data_v0.8.3.bin", + name: "v0.8.3", + checkpointID: "checkpoint_compat_v0_8_3", + filename: "checkpoint_data_v0.8.3.bin", + brokenByAgentToolInterruptStateChange: true, }, { - name: "v0.8.4", - checkpointID: "checkpoint_compat_v0_8_4", - filename: "checkpoint_data_v0.8.4.bin", + name: "v0.8.4", + checkpointID: "checkpoint_compat_v0_8_4", + filename: "checkpoint_data_v0.8.4.bin", + brokenByAgentToolInterruptStateChange: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + if tc.brokenByAgentToolInterruptStateChange { + t.Skip("AgentTool interrupt state format changed for SessionID-based event filtering; pre-change checkpoint fixtures are not resumable. See plan-session-event-log-reconstruction.md.") + } runDeepAgentCheckpointCompat(t, tc.checkpointID, tc.filename) }) } From 49fd86cae4b829f9bb99c7ed73c9d9340e20674c Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Thu, 28 May 2026 10:30:52 +0800 Subject: [PATCH 2/8] fix(adk): set streaming meta for agentic tool chunks Change-Id: I4d42e16cd420690625342d5097df4dbcae07cb8b --- adk/session_extra_test.go | 166 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/adk/session_extra_test.go b/adk/session_extra_test.go index e12823ba2..eb127c76e 100644 --- a/adk/session_extra_test.go +++ b/adk/session_extra_test.go @@ -475,6 +475,172 @@ func agenticToolResultMessage(callID, name, text string) *schema.AgenticMessage } } +func TestStreamPersistence_AgenticToolResultChunksConcat(t *testing.T) { + ctx := context.Background() + store := newSessionHelperStore() + sid := "agentic-tool-stream-session" + + agent := &agenticSessionStreamingAgent{ + chunks: []*schema.AgenticMessage{ + agenticToolResultMessage("call_1", "execute", "first\n"), + agenticToolResultMessage("call_1", "execute", "second\n"), + }, + turnEnd: &TurnEndState[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("q"), + agenticToolResultMessage("call_1", "execute", "first\nsecond\n"), + }, + }, + } + + runner := NewTypedRunner(TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + SessionID: sid, + SessionStore: store, + SessionConfig: &SessionConfig{EventFlushBatchSize: 1}, + }) + + iter := runner.Run(ctx, []*schema.AgenticMessage{schema.UserAgenticMessage("q")}) + for { + ev, ok := iter.Next() + if !ok { + break + } + require.NoError(t, ev.Err) + if ev.Output != nil && ev.Output.MessageOutput != nil && + ev.Output.MessageOutput.IsStreaming && ev.Output.MessageOutput.MessageStream != nil { + for { + _, err := ev.Output.MessageOutput.MessageStream.Recv() + if err == io.EOF { + break + } + require.NoError(t, err) + } + } + } + + var stored *SessionEvent[*schema.AgenticMessage] + store.mu.Lock() + snapshot := append([]SessionEventPayload{}, store.events...) + store.mu.Unlock() + for _, ep := range snapshot { + se, err := decodeSessionEvent[*schema.AgenticMessage](ep.Data) + require.NoError(t, err) + if se.Kind == SessionEventMessage && se.Message != nil && + len(se.Message.ContentBlocks) == 1 && + se.Message.ContentBlocks[0].Type == schema.ContentBlockTypeFunctionToolResult { + stored = se + break + } + } + + require.NotNil(t, stored) + require.NotNil(t, stored.Message) + require.Len(t, stored.Message.ContentBlocks, 1) + ftr := stored.Message.ContentBlocks[0].FunctionToolResult + require.NotNil(t, ftr) + assert.Equal(t, "call_1", ftr.CallID) + assert.Equal(t, "execute", ftr.Name) + require.Len(t, ftr.Content, 1) + assert.Equal(t, "first\nsecond\n", ftr.Content[0].Text.Text) + assert.Nil(t, stored.Message.ContentBlocks[0].StreamingMeta) +} + +func TestStreamPersistence_AgenticToolResultChunksWithStreamingMeta(t *testing.T) { + ctx := context.Background() + store := newSessionHelperStore() + sid := "agentic-tool-stream-meta-session" + + first := agenticToolResultMessage("call_1", "execute", "first\n") + second := agenticToolResultMessage("call_1", "execute", "second\n") + first.ContentBlocks[0].StreamingMeta = &schema.StreamingMeta{Index: 0} + second.ContentBlocks[0].StreamingMeta = &schema.StreamingMeta{Index: 0} + + agent := &agenticSessionStreamingAgent{ + chunks: []*schema.AgenticMessage{first, second}, + turnEnd: &TurnEndState[*schema.AgenticMessage]{ + Messages: []*schema.AgenticMessage{ + schema.UserAgenticMessage("q"), + agenticToolResultMessage("call_1", "execute", "first\nsecond\n"), + }, + }, + } + + runner := NewTypedRunner(TypedRunnerConfig[*schema.AgenticMessage]{ + Agent: agent, + EnableStreaming: true, + SessionID: sid, + SessionStore: store, + SessionConfig: &SessionConfig{EventFlushBatchSize: 1}, + }) + + iter := runner.Run(ctx, []*schema.AgenticMessage{schema.UserAgenticMessage("q")}) + for { + ev, ok := iter.Next() + if !ok { + break + } + require.NoError(t, ev.Err) + if ev.Output != nil && ev.Output.MessageOutput != nil && + ev.Output.MessageOutput.IsStreaming && ev.Output.MessageOutput.MessageStream != nil { + for { + _, err := ev.Output.MessageOutput.MessageStream.Recv() + if err == io.EOF { + break + } + require.NoError(t, err) + } + } + } + + var stored *schema.AgenticMessage + store.mu.Lock() + snapshot := append([]SessionEventPayload{}, store.events...) + store.mu.Unlock() + for _, ep := range snapshot { + se, err := decodeSessionEvent[*schema.AgenticMessage](ep.Data) + require.NoError(t, err) + if se.Kind == SessionEventMessage && se.Message != nil && + len(se.Message.ContentBlocks) == 1 && + se.Message.ContentBlocks[0].Type == schema.ContentBlockTypeFunctionToolResult { + stored = se.Message + break + } + } + + require.NotNil(t, stored) + require.Len(t, stored.ContentBlocks, 1) + block := stored.ContentBlocks[0] + assert.Nil(t, block.StreamingMeta) + require.NotNil(t, block.FunctionToolResult) + assert.Equal(t, "call_1", block.FunctionToolResult.CallID) + assert.Equal(t, "execute", block.FunctionToolResult.Name) + require.Len(t, block.FunctionToolResult.Content, 1) + assert.Equal(t, "first\nsecond\n", block.FunctionToolResult.Content[0].Text.Text) +} + +func agenticToolResultMessage(callID, name, text string) *schema.AgenticMessage { + return &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: callID, + Name: name, + Content: []*schema.FunctionToolResultContentBlock{ + { + Type: schema.FunctionToolResultContentBlockTypeText, + Text: &schema.UserInputText{Text: text}, + }, + }, + }, + }, + }, + } +} + // TestStreamPersistence_GetMessageError_NotEnqueued verifies that a stream // materialization error sets persistErr (failing the turn commit) and does NOT // enqueue a corrupt SessionEvent. From d1947f2734e5088204118c2aff1ddfd7fedd3cbe Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Wed, 25 Mar 2026 15:36:52 +0800 Subject: [PATCH 3/8] feat(adk): auto memory middleware --- SESSION_API_DOCUMENTATION.md | 68 - V0.9_COMPATIBILITY_NOTE.md | 132 -- V0.9_RELEASE_FINDINGS.md | 90 - V0.9_RELEASE_NOTE.md | 70 - adk/chatmodel.go | 35 +- adk/handler.go | 8 +- adk/handler_test.go | 20 +- adk/interrupt_test.go | 2 +- adk/middlewares/automemory/automemory.go | 1612 +++++++++++++++++ adk/middlewares/automemory/automemory_test.go | 1005 ++++++++++ adk/middlewares/automemory/backend.go | 46 + adk/middlewares/automemory/consts.go | 55 + adk/middlewares/automemory/coordinator.go | 162 ++ adk/middlewares/automemory/dream/config.go | 181 ++ adk/middlewares/automemory/dream/dream.go | 258 +++ .../automemory/dream/dream_test.go | 434 +++++ adk/middlewares/automemory/dream/prompt.go | 137 ++ adk/middlewares/automemory/dream/session.go | 193 ++ .../automemory/dream/session_test.go | 125 ++ adk/middlewares/automemory/dream/store.go | 135 ++ .../automemory/inmemory_backend.go | 210 +++ .../automemory/internal/backend.go | 235 +++ adk/middlewares/automemory/local_backend.go | 192 ++ adk/middlewares/automemory/prompt.go | 316 ++++ .../dynamictool/toolsearch/toolsearch.go | 2 +- adk/middlewares/filesystem/filesystem.go | 2 +- adk/middlewares/filesystem/filesystem_test.go | 2 +- adk/middlewares/plantask/plantask.go | 2 +- adk/middlewares/plantask/plantask_test.go | 2 +- adk/middlewares/skill/skill.go | 2 +- adk/middlewares/skill/skill_test.go | 2 +- adk/prebuilt/deep/deep_test.go | 4 +- adk/prebuilt/deep/types.go | 2 +- adk/session_extra_test.go | 166 -- permission_middleware_comprehensive_review.md | 105 -- resume_wait_timeout_comprehensive_review.md | 152 -- 36 files changed, 5341 insertions(+), 823 deletions(-) delete mode 100644 SESSION_API_DOCUMENTATION.md delete mode 100644 V0.9_COMPATIBILITY_NOTE.md delete mode 100644 V0.9_RELEASE_FINDINGS.md delete mode 100644 V0.9_RELEASE_NOTE.md create mode 100644 adk/middlewares/automemory/automemory.go create mode 100644 adk/middlewares/automemory/automemory_test.go create mode 100644 adk/middlewares/automemory/backend.go create mode 100644 adk/middlewares/automemory/consts.go create mode 100644 adk/middlewares/automemory/coordinator.go create mode 100644 adk/middlewares/automemory/dream/config.go create mode 100644 adk/middlewares/automemory/dream/dream.go create mode 100644 adk/middlewares/automemory/dream/dream_test.go create mode 100644 adk/middlewares/automemory/dream/prompt.go create mode 100644 adk/middlewares/automemory/dream/session.go create mode 100644 adk/middlewares/automemory/dream/session_test.go create mode 100644 adk/middlewares/automemory/dream/store.go create mode 100644 adk/middlewares/automemory/inmemory_backend.go create mode 100644 adk/middlewares/automemory/internal/backend.go create mode 100644 adk/middlewares/automemory/local_backend.go create mode 100644 adk/middlewares/automemory/prompt.go delete mode 100644 permission_middleware_comprehensive_review.md delete mode 100644 resume_wait_timeout_comprehensive_review.md diff --git a/SESSION_API_DOCUMENTATION.md b/SESSION_API_DOCUMENTATION.md deleted file mode 100644 index 4f61bc8b7..000000000 --- a/SESSION_API_DOCUMENTATION.md +++ /dev/null @@ -1,68 +0,0 @@ - - -# Session API Documentation - -## Agent Interrupt Events - -`SessionEventAgentInterrupt` persists an agent-initiated business interrupt in the managed session timeline. - -```go -const SessionEventAgentInterrupt SessionEventKind = "agent.interrupt" -``` - -The event is distinct from `SessionEventUserInterrupt`: - -- `SessionEventAgentInterrupt` means the agent paused execution and needs external input before it can resume. -- `SessionEventUserInterrupt` means the user proactively cancelled execution. - -The payload is stored on `SessionEvent.AgentInterrupt`. - -```go -type AgentInterruptEvent struct { - Cause AgentInterruptCause `json:"cause,omitempty"` - CheckPointID string `json:"checkpoint_id,omitempty"` - InterruptContexts []*InterruptCtx `json:"interrupt_contexts,omitempty"` - ToolUseID string `json:"tool_use_id,omitempty"` - SpanEventID string `json:"span_event_id,omitempty"` -} -``` - -`Cause` categorizes why the interrupt happened: - -```go -const ( - AgentInterruptCauseToolPermission AgentInterruptCause = "tool_permission" - AgentInterruptCauseCustomTool AgentInterruptCause = "custom_tool" - AgentInterruptCauseGeneric AgentInterruptCause = "generic" -) -``` - -`CheckPointID` is the checkpoint key passed to `Runner.Resume` or `Runner.ResumeWithParams`. - -`InterruptContexts` uses the same public `[]*InterruptCtx` shape exposed on live `AgentAction.Interrupted` events. Root-cause `InterruptCtx.ID` values are the normal keys for `ResumeParams.Targets`. - -```go -resumeParams := &ResumeParams{ - Targets: map[string]any{ - interruptEvent.InterruptContexts[0].ID: result, - }, -} -``` - -`ToolUseID` is populated when the root-cause interrupt address contains a tool segment. The runner uses `AddressSegment.SubID` when present and falls back to `AddressSegment.ID` for compatibility with older tool-address paths. - -`SpanEventID` is a best-effort link to the related `span.tool_call_start` session event. It may be empty when the runner has not observed a matching tool span start event in the current run or resume drain loop. diff --git a/V0.9_COMPATIBILITY_NOTE.md b/V0.9_COMPATIBILITY_NOTE.md deleted file mode 100644 index 0d0018c77..000000000 --- a/V0.9_COMPATIBILITY_NOTE.md +++ /dev/null @@ -1,132 +0,0 @@ - - -# V0.9 agentic-runtime Compatibility Note - -本文列出现有用户从 V0.8.x 升级到 V0.9 `agentic-runtime` 时需要关注的 API 和语义变化。未列出的新增能力通常不影响既有 `*schema.Message` 路径。 - -## API 显式变更 - -### ChatModelAgentMiddleware 新增 AfterAgent - -`ChatModelAgentMiddleware` 新增 `AfterAgent` 方法。手写实现该接口的类型需要补充该方法,否则会编译失败。 - -推荐做法: - -- 如果 middleware 不需要特殊收尾逻辑,嵌入 `*adk.BaseChatModelAgentMiddleware`。 -- 如果 middleware 需要在 Agent 成功结束后清理状态、记录事件或补充统计,实现 `AfterAgent(ctx, state)`。 - -影响范围: - -- 仅影响显式实现 `ChatModelAgentMiddleware` 的用户代码。 -- 通过 `BaseChatModelAgentMiddleware` 组合扩展的代码可保持兼容。 - -### summarization.SummarizeMessages 被移除 - -`summarization.SummarizeMessages` 和 `summarization.SummarizeOutput` 不再导出。 - -迁移方式: - -- 构造 summarization middleware 时继续使用 `summarization.New` 或 `summarization.NewTyped`。 -- 需要主动触发同步 summarization 时,使用 `TypedMiddleware.Summarize`。 - -该调整将 summarization 的配置、状态读取和执行逻辑收敛到 middleware 内部,避免独立函数与运行时状态语义分叉。 - -## 需要关注语义变化的能力 - -### Summarization Finalize 后处理语义变化 - -V0.8.x 中,summarization middleware 会先执行默认 summary 后处理,再调用用户配置的 `Finalize`。因此自定义 `Finalize` 收到的 `summary` 已经包含 `PreserveUserMessages` 替换、`TranscriptFilePath` 注入和 summary preamble。 - -V0.9 中,如果设置了 `Config.Finalize`,middleware 会直接把模型生成的 raw summary 传给 `Finalize`,不再自动执行默认后处理。受影响的配置包括: - -- `PreserveUserMessages` -- `TranscriptFilePath` - -迁移方式: - -- 如果希望保留默认后处理,不要设置 `Finalize`,让 middleware 使用默认 finalization 路径。 -- 如果必须自定义 `Finalize`,但仍希望保留默认后处理,先通过 `DefaultFinalizer` 构造默认 finalizer,再在自定义逻辑中显式组合。 -- `DefaultFinalizer` 不会自动读取外层 `Config.PreserveUserMessages` 和 `Config.TranscriptFilePath`;需要通过 `DefaultFinalizerConfig` 显式传入。 -- 使用 `NewFinalizer().PreserveSkills(...).Build()` 的代码需要特别检查:该 finalizer 只负责 preserve skills,不会自动补上 `PreserveUserMessages` 和 `TranscriptFilePath`。 - -### 工具列表修改路径调整 - -`ModelContext.Tools` 不再是推荐的工具列表修改入口。 - -升级建议: - -- 在 `BeforeModelRewriteState` 中修改 `state.ToolInfos`。 -- 如需模型原生 deferred tool search,修改 `state.DeferredToolInfos`。 -- 不建议在 `WrapModel` 中修改工具列表;该修改只影响当前模型调用,后续 middleware、后续 turn 或 checkpoint/resume 不会继承这次修改。 - -### Model Retry 决策语义增强 - -`ModelRetryConfig` 新增 `ShouldRetry`。当 `ShouldRetry` 非空时,`IsRetryAble` 会被忽略。 - -需要注意: - -- 旧的 `IsRetryAble` 仍可用于错误维度的简单重试。 -- 使用 `ShouldRetry` 后,应显式处理成功输出但业务不接受的场景。 -- Interrupt 和 `ErrStreamCanceled` 不作为普通 retry error 处理。 - -### Cancel 错误语义 - -V0.9 引入主动取消语义后,应用需要区分主动取消、普通错误和业务 interrupt。 - -升级建议: - -- 上层应区分 `CancelError`、普通 error 和业务 interrupt。 -- 如果应用主动接入 `WithCancel`,不要把 `CancelError` 当作普通业务失败处理。 - -### AgenticMessage 迁移需要理解新的消息结构 - -`TypedChatModelAgent[*schema.AgenticMessage]` 是面向模型原生 Agentic 协议的新路径。迁移到该路径不只是把泛型参数从 `*schema.Message` 改成 `*schema.AgenticMessage`,还需要按 `AgenticMessage` 的 content block 结构处理消息内容。 - -需要注意: - -- AgenticMessage 路径使用 `AgenticModel` 与 `AgenticToolsNode` 处理工具调用。 -- 工具调用和工具结果通过 `AgenticMessage` content block 表达,尤其需要正确处理 tool call / tool result content block。 -- Agent transfer 能力不适用于 AgenticMessage 路径。 -- 既有应用如果不需要模型原生 Agentic 协议,建议继续使用默认 `*schema.Message` 路径;只有在明确要接入 `AgenticModel` 协议时再迁移。 - -### 模型适配器需要识别新增 option - -V0.9 引入 `AgenticModel` 后,模型适配器需要更严格地处理 call-time options。`AgenticModel` 是 `BaseModel[*schema.AgenticMessage]` 的别名,不再提供类似 `ToolCallingChatModel.WithTools` 的增强接口;工具绑定统一通过 `model.WithTools` 作为 `model.Option` 传入。 - -需要注意: - -- 所有支持 AgenticMessage 的模型适配器都应读取 `Options.Tools`,并将其映射到 provider 的 tool calling 协议。 -- `AgenticModel` 不应要求用户先调用某个 `WithTools` 方法得到“带工具的模型实例”;ADK 会在每次模型调用时通过 `model.WithTools` 传递当前工具列表。 -- 如果适配器只从自身 config 读取工具,而忽略 `model.WithTools`,在 ChatModelAgent / AgenticToolsNode 路径下会出现模型看不到工具或工具列表不随运行态变化的问题。 - -V0.9 还在 `model.Options` 中新增: - -- `DeferredTools` -- `ToolSearchTool` -- `AgenticToolChoice` - -现有模型适配器忽略这些 option 通常不会导致编译失败,但会导致 deferred tool search、模型原生 tool search 或 agentic tool choice 不生效。适配器维护者应按目标 provider 的协议补齐转换逻辑。 - -### ToolInfo 序列化形态变化 - -`ToolInfo` 增加显式 JSON/Gob 编解码,以保留 `ParamsOneOf`。 - -影响: - -- `ToolInfo` 进入了 `ChatModelAgentState.ToolInfos` / `DeferredToolInfos`,因此可能随 Agent state 一起进入 checkpoint。 -- 显式 JSON/Gob 编解码用于保证 `ParamsOneOf` 在 checkpoint、deep copy 和恢复过程中不会丢失。 -- 如果外部系统直接依赖旧版 `ToolInfo` JSON 形态,需要重新确认序列化兼容性。 diff --git a/V0.9_RELEASE_FINDINGS.md b/V0.9_RELEASE_FINDINGS.md deleted file mode 100644 index e58c2aaa3..000000000 --- a/V0.9_RELEASE_FINDINGS.md +++ /dev/null @@ -1,90 +0,0 @@ - - -# v0.9 Release Findings - -## Comparison Scope - -- Compared `alpha/09` against `main` using `main...alpha/09`. -- `main` is the merge base of `alpha/09`. -- Branch heads observed during analysis: - - `main`: `5e1305506c4fa89ef5d786035a947258e29a7593` - - `alpha/09`: `c39433511896d6a12e379a7958c6e5d489560b5a` -- Second validation pass confirmed `main == origin/main`, `alpha/09 == origin/alpha/09`, and `main` is the merge base. -- Diff size: `136 files changed`, `49,967 insertions`, `2,790 deletions`. -- Changed surface is concentrated in `adk`, `schema`, `components/model`, `components/prompt`, `compose`, and callback helpers. - -## Primary Features - -| Area | v0.9 feature | Direct diff validation | -| --- | --- | --- | -| Agentic message model | Adds `schema.AgenticMessage`, content-block based message schema, provider extensions, streaming metadata, MCP/server/function tool blocks, and concat support. | `A schema/agentic_message.go`; concat registration in `schema/message.go`. | -| Generic model abstraction | Introduces `model.BaseModel[M]`; keeps `BaseChatModel` as `BaseModel[*schema.Message]`; adds `AgenticModel`. | `M components/model/interface.go`. | -| Typed ADK | Adds typed agents, typed events, typed runner, typed `ChatModelAgent`, and typed message variants while preserving default `*schema.Message` aliases. | `M adk/interface.go`, `M adk/chatmodel.go`, `M adk/runner.go`. | -| Agentic ChatModelAgent path | `TypedChatModelAgent[*schema.AgenticMessage]` supports a single-shot agentic model path where tool calling is handled inside the model/message protocol. | `M adk/chatmodel.go`; `TypedChatModelAgent` and agentic ReAct path are added in the diff. | -| Cancellation | Adds `WithCancel`, `CancelMode`, safe-point cancellation, recursive cancellation, timeout escalation, `CancelHandle`, and `CancelError` with resumable interrupt contexts. | `A adk/cancel.go`; cancel integration hunks in `adk/chatmodel.go`, `adk/flow.go`, and `adk/wrappers.go`. | -| TurnLoop | Adds a push-based `TurnLoop` runtime with `Push`, non-blocking `Stop`, idle-stop, checkpoint/resume integration, and preempt handling. | `A adk/turn_loop.go`, `A adk/turn_buffer.go`. | -| Model retry | Upgrades retry from error-only retryability to `ShouldRetry(ctx, RetryContext) -> RetryDecision`, allowing output inspection, input rewrite, option rewrite, backoff override, and reject reason. | `M adk/retry_chatmodel.go`. | -| Model failover | Adds ChatModel failover with `ModelFailoverConfig`, `FailoverContext`, last-success model preference, and callback-aware proxying. | `A adk/failover_chatmodel.go`; config wiring in `adk/chatmodel.go`. | -| Tool search | Adds dynamic tool search middleware with both client-side search and model-native deferred tool search via `DeferredToolInfos`, `WithDeferredTools`, and `WithToolSearchTool`. | `M adk/middlewares/dynamictool/toolsearch/toolsearch.go`, `M components/model/option.go`. | -| Middleware modernization | Generifies summarization, reduction, skill, filesystem, plan-task, patch-tool-calls and adds `AfterAgent`; state now carries `ToolInfos` and `DeferredToolInfos` as the recommended mutable model-call surface. | Diff hunks in `adk/handler.go`, `adk/middlewares/*`, and `adk/prebuilt/deep/deep.go`. | -| Summarization API | Adds `TypedMiddleware.Summarize` and typed finalizer/customized-action paths; removes the old standalone `SummarizeMessages` / `SummarizeOutput` API in favor of middleware-owned summarization. | `M adk/middlewares/summarization/summarization.go`, `M customized_action.go`, `M finalizer_builder.go`. | -| Compose/tooling | Adds `AgenticToolsNode` and tool name/argument aliases for `ToolsNode`. | `A compose/agentic_tools_node.go`, `M compose/tool_node.go`. | -| Prompt/callback support | Adds agentic prompt templates and callback types for agentic prompt/model/tools/agent components. | `A components/prompt/agentic_chat_template.go`, `A components/*/agentic_callback_extra.go`, `M utils/callbacks/template.go`. | -| Filesystem | Adds enhanced multimodal read support and PDF page validation. | `M adk/filesystem/backend.go`, `M adk/middlewares/filesystem/filesystem.go`. | -| Agents.md | Adds `agentsmd` middleware for automatically loading and injecting `AGENTS.md`-style instructions. | `A adk/middlewares/agentsmd/agentsmd.go`, `A loader.go`. | - -## Compatibility Notes - -| Impact | Note | -| --- | --- | -| Source break for custom middleware implementers | `ChatModelAgentMiddleware` now includes `AfterAgent`. Any user-defined type that manually implements the interface must add this method or embed `BaseChatModelAgentMiddleware`. | -| Middleware tool mutation semantics | `ModelContext.Tools` is now deprecated as a mutation surface; tool list changes should happen through `state.ToolInfos` / `state.DeferredToolInfos` in `BeforeModelRewriteState`. Mutating tools in `WrapModel` only affects one model call and is explicitly discouraged. | -| Summarization standalone API removal | `summarization.SummarizeMessages` and `summarization.SummarizeOutput` are no longer exported. Use `New` / `NewTyped` to construct middleware, or call `TypedMiddleware.Summarize` when direct summarization is needed. | -| Retry behavior change | If `ShouldRetry` is set, `IsRetryAble` is ignored. In streaming mode, the full stream is consumed before the retry decision is made, although events are still emitted in real time. | -| Retry cancellation semantics | Retry now treats interrupts and `ErrStreamCanceled` as non-retryable and uses context-aware backoff rather than unconditional sleep. Users relying on retrying interrupt/cancel errors should adjust policy. | -| Cancellation error semantics | During active cancel, business interrupts are absorbed into `CancelError`; the checkpoint preserves interrupt contexts and business interrupt can re-fire on resume. Consumers should handle `CancelError` separately from ordinary business interrupts. | -| TurnLoop stop semantics | `TurnLoop.Stop` is non-blocking; use `Wait` for terminal state. Cancel-related stop options degrade to "finish current turn then exit" if the running agent does not support `WithCancel`. `UntilIdleFor` silently drops cancel options in the same call. | -| Agentic path limitations | `TypedChatModelAgent[*schema.AgenticMessage]` is not feature-equivalent to `*schema.Message`: it uses a single-shot path, does not support agent transfer, and cancel monitoring/retry on model streams are not yet wired. | -| Model adapters must honor new options | Native tool search requires model implementations to read `Options.DeferredTools`, `Options.ToolSearchTool`, and `Options.AgenticToolChoice`. Existing adapters that ignore unknown common options will compile but will not support the new behavior. | -| Serialization shape change | `ToolInfo` now has explicit JSON/Gob encoding that preserves `ParamsOneOf`. This fixes checkpoint/deep-copy loss, but external systems depending on the previous raw JSON shape should re-check serialized payloads. | -| Filesystem page validation | Multimodal read validates PDF `pages` and rejects ranges over 20 pages per request. Users passing arbitrary page ranges should handle validation errors. | -| Transfer/workflow/supervisor positioning | Agent transfer, workflow agents, and supervisor are not removed, but many APIs now carry `NOT RECOMMENDED` guidance in favor of `ChatModelAgent` + `AgentTool` or `DeepAgent`. This is a semantic/product-direction compatibility note, not a signature break. | - -## Likely Non-Breaking Alias Changes - -- `BaseChatModel` becomes an alias of `BaseModel[*schema.Message]`; existing implementations with `Generate(ctx, []*schema.Message, ...)` and `Stream(ctx, []*schema.Message, ...)` should still satisfy it. -- `Agent`, `AgentInput`, `AgentEvent`, `AgentOutput`, `ChatModelAgent`, `ChatModelAgentConfig`, `ChatModelAgentState`, `ModelContext`, and several middleware config types are preserved as `*schema.Message` aliases over typed forms. -- `ToolOutputPart`, `ToolResult`, and related tool-result types moved from `schema/message.go` to `schema/tool.go`, but remain in package `schema`, so import paths and qualified names are unchanged. - -## Validation Results - -Completed checks: - -- Direct branch-ref validation: - - Verified `main == origin/main`, `alpha/09 == origin/alpha/09`, and the merge base is `main`. - - Rechecked each retained feature row with `git diff main...alpha/09` file status or hunks. - - Removed raw AST API-diff counts from the release findings because the script over-reported generic alias refactors as removals. -- Representative downstream compatibility compile check: - - `GOWORK=off go test .` passed in a temporary external module using `replace github.com/cloudwego/eino => ..`. - - Verified that existing `BaseChatModel` implementations still compile against the `BaseModel[*schema.Message]` alias. - - Verified that `ChatModelAgentConfig`, `summarization.Config`, `reduction.Config`, `skill.Config`, `ToolResult`, and new model options are usable from downstream code. - - Verified that embedding `*adk.BaseChatModelAgentMiddleware` remains the safe compatibility path for middleware implementations. -- Negative compile check for old custom middleware: - - `GOWORK=off go test -tags=oldmiddleware .` fails as expected with: `oldStyleMiddleware does not implement adk.TypedChatModelAgentMiddleware[*schema.Message] (missing method AfterAgent)`. - - This confirms the `AfterAgent` source compatibility note for users who manually implement `ChatModelAgentMiddleware` without embedding the base middleware. -- Targeted package tests: - - `go test ./adk ./adk/middlewares/summarization ./adk/middlewares/reduction ./adk/middlewares/skill ./adk/middlewares/dynamictool/toolsearch ./components/model ./components/prompt ./compose ./schema` passed. diff --git a/V0.9_RELEASE_NOTE.md b/V0.9_RELEASE_NOTE.md deleted file mode 100644 index dc50213ef..000000000 --- a/V0.9_RELEASE_NOTE.md +++ /dev/null @@ -1,70 +0,0 @@ - - -# V0.9 agentic-runtime Release Note - -V0.9 的版本主题是 `agentic-runtime`。该版本主要围绕 ADK 的消息协议、Agent 运行控制和多轮运行时能力展开,在保留 `*schema.Message` 默认路径的同时,引入 `AgenticMessage` 及配套泛型抽象,为更丰富的模型原生 Agent 协议、服务端工具调用、运行中断与恢复打下基础。 - -## 1. AgenticMessage 与 ADK 支持 - -V0.9 新增 `schema.AgenticMessage`,用于表达比传统 `schema.Message` 更完整的 Agentic 消息结构。 - -- `AgenticMessage` 采用 content block 模型,支持文本、推理内容、工具调用、工具结果、服务端工具、MCP 工具和多模态内容等结构化片段。 -- `[]ContentBlock` 能更完整地保留不同模型协议响应中的 block 时序;新增 block 类型也更适配 OpenAI Responses API、Claude、Gemini 等协议中的 tool use、reasoning、streaming metadata 等结构。 -- `components/model` 新增 `AgenticModel` 组件,用于接入以 `AgenticMessage` 为输入输出的模型实现。 -- ADK 对 `AgenticMessage` 路径提供 typed agent、typed event、typed runner 和 typed `ChatModelAgent` 支持,使 AgenticModel 能进入 ADK 的 Agent 生命周期。 - -## 2. ChatModelAgent 能力扩展 - -V0.9 对 `ChatModelAgent` 的运行控制、模型调用可靠性和 middleware 扩展点进行了系统增强。 - -### Cancel - -- 新增 Agent Cancel 能力,用于从外部主动终止正在运行的 Agent。 -- 支持安全点取消、递归取消、取消超时升级,以及取消过程中的 checkpoint 持久化。 -- 取消期间发生的 interrupt 会统一进入取消语义,调用方可以通过 `CancelError` 区分主动取消与普通业务失败。 - -### Model Retry - -- Retry 从简单的 error retry 扩展为 `ShouldRetry(ctx, RetryContext) -> RetryDecision`。 -- Retry 决策可以读取模型输出、拒绝不满足条件的输出、修改下一次输入、追加模型 option,并覆盖 backoff。 - -### Model Failover - -- 新增 Model Failover 能力,用于在模型调用失败后切换到备用模型。 -- Failover 决策可以读取失败 attempt 的输出、错误、原始输入和 attempt 序号,并选择下一次使用的模型。 -- 支持为备用模型改写输入;也支持优先复用上一次调用成功的模型,降低每次从固定主模型开始试错的成本。 - -### Middleware 增强 - -- `ChatModelAgentMiddleware` 新增 `AfterAgent`,用于在 Agent 成功结束后执行收尾逻辑。 -- Summarization、reduction、skill、filesystem、plan-task、patch-tool-calls 等 middleware 完成泛型化,支持 `AgenticMessage` 路径。 -- Summarization middleware 新增 `TypedMiddleware.Summarize`,同步 summarization 能力从独立函数转为 middleware 内聚能力。 -- Filesystem middleware 增强多模态读取能力,并增加 PDF pages 校验。 -- 新增 `agentsmd` middleware,用于加载和注入 `AGENTS.md` 风格的项目指令。 -- `ChatModelAgentState` 增加 `ToolInfos` 和 `DeferredToolInfos`,作为 middleware 调整模型可见工具集合的主路径。 -- `ToolInfos` 表示当前模型调用直接可见的工具;`DeferredToolInfos` 表示可由模型通过工具搜索机制按需发现的候选工具。 -- Tool search middleware 支持三类工具加载方式:使用模型侧原生 tool search 能力从 deferred tools 中按需加载;按模型协议要求提供固定 schema 的 `ToolSearchTool`,由模型通过该入口搜索 deferred tools;不依赖模型侧协议,使用 Eino 提供的自定义 `tool_search` tool 检索工具,并把命中的工具追加到常规 `ToolInfos`。 -- Compose 新增 `AgenticToolsNode`,`ToolsNode` 增加 tool name 和 argument alias 支持。 - -## 3. TurnLoop - -V0.9 新增 `TurnLoop`,用于把一次性的 Agent run 提升为可持续运行、可被外部驱动的 turn 级运行时。 - -- 面向多轮运行:`TurnLoop` 持续接收外部输入,每个 turn 独立规划输入、构造 Agent、消费事件,适合长期在线的交互式 Agent。 -- 支持输入合并:`GenInput` 在 turn 边界决定本轮消费哪些输入、哪些继续等待,应用可以实现批处理、去重、合并用户连续输入等策略。 -- 支持抢占:带 preempt option 的 `Push` 会原子地写入新输入并请求取消当前 turn,使高优先级输入可以打断正在运行的 Agent。 -- 支持声明式 checkpoint/resume:恢复时,应用不需要自行还原输入队列;`TurnLoop` 会区分被中断的输入、尚未处理的输入和恢复后新到达的输入,应用只需声明这些输入如何重新进入后续 turn。 diff --git a/adk/chatmodel.go b/adk/chatmodel.go index 8e0cee42f..3e0ee2c93 100644 --- a/adk/chatmodel.go +++ b/adk/chatmodel.go @@ -883,9 +883,12 @@ type execContext struct { toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change } -func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) { - runCtx := &ChatModelAgentContext{ +func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext, agentInput *TypedAgentInput[M]) ( + context.Context, *execContext, *TypedAgentInput[M], error) { + + runCtx := &ChatModelAgentContext[M]{ Instruction: ec.instruction, + AgentInput: agentInput, Tools: cloneSlice(ec.unwrappedTools), ReturnDirectly: copyMap(ec.returnDirectly), } @@ -894,7 +897,7 @@ func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execC for i, handler := range a.handlers { ctx, runCtx, err = handler.BeforeAgent(ctx, runCtx) if err != nil { - return ctx, nil, fmt.Errorf("handler[%d] (%T) BeforeAgent failed: %w", i, handler, err) + return ctx, nil, nil, fmt.Errorf("handler[%d] (%T) BeforeAgent failed: %w", i, handler, err) } } @@ -914,12 +917,12 @@ func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execC toolInfos, err := genToolInfos(ctx, &runtimeEC.toolsNodeConf) if err != nil { - return ctx, nil, err + return ctx, nil, nil, err } runtimeEC.toolInfos = toolInfos - return ctx, runtimeEC, nil + return ctx, runtimeEC, runCtx.AgentInput, nil } func (a *TypedChatModelAgent[M]) applyAfterAgent(ctx context.Context) (context.Context, error) { @@ -1581,12 +1584,12 @@ func (a *TypedChatModelAgent[M]) buildRunFunc(ctx context.Context) typedRunFunc[ return a.run } -func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Context, typedRunFunc[M], *execContext, error) { +func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context, agentInput *TypedAgentInput[M]) (context.Context, typedRunFunc[M], *execContext, *TypedAgentInput[M], error) { defaultRun := a.buildRunFunc(ctx) bc := a.exeCtx if bc == nil { - return ctx, defaultRun, bc, nil + return ctx, defaultRun, bc, agentInput, nil } if len(a.handlers) == 0 { @@ -1596,32 +1599,32 @@ func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Contex returnDirectly: bc.returnDirectly, toolInfos: bc.toolInfos, } - return ctx, defaultRun, runtimeBC, nil + return ctx, defaultRun, runtimeBC, agentInput, nil } - ctx, runtimeBC, err := a.applyBeforeAgent(ctx, bc) + ctx, runtimeBC, agentInput, err := a.applyBeforeAgent(ctx, bc, agentInput) if err != nil { - return ctx, nil, nil, err + return ctx, nil, nil, nil, err } if !runtimeBC.rebuildGraph { - return ctx, defaultRun, runtimeBC, nil + return ctx, defaultRun, runtimeBC, agentInput, nil } var tempRun typedRunFunc[M] if len(runtimeBC.toolsNodeConf.Tools) == 0 { tempRun, err = a.buildNoToolsRunFunc(ctx) if err != nil { - return ctx, nil, nil, err + return ctx, nil, nil, nil, err } } else { tempRun, err = a.buildReActRunFunc(ctx, runtimeBC) if err != nil { - return ctx, nil, nil, err + return ctx, nil, nil, nil, err } } - return ctx, tempRun, runtimeBC, nil + return ctx, tempRun, runtimeBC, agentInput, nil } func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput[M], opts ...AgentRunOption) *AsyncIterator[*TypedAgentEvent[M]] { @@ -1630,7 +1633,7 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput o := getCommonOptions(nil, opts...) cancelCtx, cancelCtxOwned := resolveRunCancelContext(ctx, o) - ctx, run, bc, err := a.getRunFunc(ctx) + ctx, run, bc, input, err := a.getRunFunc(ctx, input) if err != nil { go func() { if cancelCtxOwned && cancelCtx != nil { @@ -1725,7 +1728,7 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o o := getCommonOptions(nil, opts...) cancelCtx, cancelCtxOwned := resolveRunCancelContext(ctx, o) - ctx, run, bc, err := a.getRunFunc(ctx) + ctx, run, bc, _, err := a.getRunFunc(ctx, nil) if err != nil { go func() { if cancelCtxOwned && cancelCtx != nil { diff --git a/adk/handler.go b/adk/handler.go index 472831da2..db7ff59b4 100644 --- a/adk/handler.go +++ b/adk/handler.go @@ -89,7 +89,7 @@ type ModelContext = TypedModelContext[*schema.Message] // Handlers can modify Instruction, Tools, and ReturnDirectly to customize agent behavior. // // This type is specific to ChatModelAgent. Other agent types may define their own context types. -type ChatModelAgentContext struct { +type ChatModelAgentContext[M MessageType] struct { // Instruction is the current instruction for the Agent execution. // It includes the instruction configured for the agent, additional instructions appended by framework // and AgentMiddleware, and modifications applied by previous BeforeAgent handlers. @@ -97,6 +97,8 @@ type ChatModelAgentContext struct { // to be (optionally) formatted with SessionValues and converted to system message. Instruction string + AgentInput *TypedAgentInput[M] + // Tools are the raw tools (without any wrapper or tool middleware) currently configured for the Agent execution. // They includes tools passed in AgentConfig, implicit tools added by framework such as transfer / exit tools, // and other tools already added by middlewares. @@ -144,7 +146,7 @@ type ChatModelAgentContext struct { type TypedChatModelAgentMiddleware[M MessageType] interface { // BeforeAgent is called before each agent run, allowing modification of // the agent's instruction and tools configuration. - BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) + BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[M]) (context.Context, *ChatModelAgentContext[M], error) // AfterAgent is called after the agent run reaches a successful terminal state. // Successful terminal states are: final answer (model response with no tool calls), @@ -301,7 +303,7 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) WrapModel(_ context.Context, m mo return m, nil } -func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[M]) (context.Context, *ChatModelAgentContext[M], error) { return ctx, runCtx, nil } diff --git a/adk/handler_test.go b/adk/handler_test.go index 811cd2b27..cd304cbce 100644 --- a/adk/handler_test.go +++ b/adk/handler_test.go @@ -37,7 +37,7 @@ type testInstructionHandler struct { text string } -func (h *testInstructionHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *testInstructionHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { if runCtx.Instruction == "" { runCtx.Instruction = h.text } else if h.text != "" { @@ -51,7 +51,7 @@ type testInstructionFuncHandler struct { fn func(ctx context.Context, instruction string) (context.Context, string, error) } -func (h *testInstructionFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *testInstructionFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { newCtx, newInstruction, err := h.fn(ctx, runCtx.Instruction) if err != nil { return ctx, runCtx, err @@ -65,7 +65,7 @@ type testToolsHandler struct { tools []tool.BaseTool } -func (h *testToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *testToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { runCtx.Tools = append(runCtx.Tools, h.tools...) return ctx, runCtx, nil } @@ -75,7 +75,7 @@ type testToolsFuncHandler struct { fn func(ctx context.Context, tools []tool.BaseTool, returnDirectly map[string]bool) (context.Context, []tool.BaseTool, map[string]bool, error) } -func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { newCtx, newTools, newReturnDirectly, err := h.fn(ctx, runCtx.Tools, runCtx.ReturnDirectly) if err != nil { return ctx, runCtx, err @@ -87,10 +87,10 @@ func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatMode type testBeforeAgentHandler struct { *BaseChatModelAgentMiddleware - fn func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) + fn func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) } -func (h *testBeforeAgentHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *testBeforeAgentHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { return h.fn(ctx, runCtx) } @@ -894,10 +894,10 @@ func TestContextPropagation(t *testing.T) { Description: "Test agent", Model: cm, Handlers: []ChatModelAgentMiddleware{ - &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { + &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { return context.WithValue(ctx, key1, "value1"), runCtx, nil }}, - &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { + &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { handler2ReceivedValue = ctx.Value(key1) return ctx, runCtx, nil }}, @@ -962,7 +962,7 @@ func TestHandlerErrorHandling(t *testing.T) { Description: "Test agent", Model: cm, Handlers: []ChatModelAgentMiddleware{ - &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { + &testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { return ctx, runCtx, assert.AnError }}, }, @@ -1042,7 +1042,7 @@ type countingHandler struct { mu sync.Mutex } -func (h *countingHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *countingHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { h.mu.Lock() h.beforeAgentCount++ h.mu.Unlock() diff --git a/adk/interrupt_test.go b/adk/interrupt_test.go index 773684010..2a26e0bea 100644 --- a/adk/interrupt_test.go +++ b/adk/interrupt_test.go @@ -59,7 +59,7 @@ func TestPreprocessADKCheckpoint(t *testing.T) { }) } -func (h *interruptTestToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) { +func (h *interruptTestToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) { runCtx.Tools = append(runCtx.Tools, h.tools...) return ctx, runCtx, nil } diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go new file mode 100644 index 000000000..ec3aeb6a2 --- /dev/null +++ b/adk/middlewares/automemory/automemory.go @@ -0,0 +1,1612 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package automemory provides middleware that injects and persists session +// memories around chat-model agent runs. +package automemory + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/slongfield/pyfmt" + "gopkg.in/yaml.v3" + + "github.com/cloudwego/eino/adk" + ainternal "github.com/cloudwego/eino/adk/middlewares/automemory/internal" + adkfs "github.com/cloudwego/eino/adk/middlewares/filesystem" + fsmw "github.com/cloudwego/eino/adk/middlewares/filesystem" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func init() { + schema.RegisterName[*memoryExtra]("_eino_adk_automemory_extra") +} + +type Config[M adk.MessageType] struct { + MemoryDirectory string + + MemoryBackend Backend + + // Model is the default tool-calling model used by topic selection and memory extraction. + // Per-read/per-write overrides can be configured in Read.Model / Write.Model. + Model model.ToolCallingChatModel + + // Read controls how memories are loaded and injected. + // Optional. Defaults to Sync load with topic selection enabled (if Model is set). + Read *ReadConfig + + // Write controls post-run memory extraction and persistence. + // Optional. Default: disabled. + Write *WriteConfig + + // Coordination controls session identity and distributed async extraction coordination. + // Optional. Defaults to a local in-process coordinator. + Coordination *CoordinationConfig[M] + + // OnError is called when automemory encounters an error. Errors are best-effort by default: + // the middleware will skip memory injection and allow the agent to continue. + // Optional. + OnError func(ctx context.Context, stage string, err error) +} + +type ReadMode string + +const ( + ReadModeSync ReadMode = "sync" + ReadModeAsync ReadMode = "async" +) + +type ReadConfig struct { + Mode ReadMode + + // Model is used for topic selection. Defaults to Config.Model. + Model model.ToolCallingChatModel + + // Instruction overrides the default auto memory instruction block appended to system prompt. + // Optional. + Instruction *string + + // Index controls how MEMORY.md is loaded into system prompt. + // Optional. + Index *IndexConfig + + // TopicSelection controls the "LLM select topics" path. + // Optional. If nil, default topic selection settings are applied. + // Topic selection becomes active when Read.Model is available. + TopicSelection *TopicSelectionConfig +} + +type IndexConfig struct { + FileName string + MaxLines int + MaxBytes int +} + +type TopicSelectionConfig struct { + // CandidateGlob is matched against the RELATIVE path under MemoryDirectory. + // Example: "**/*.md" + CandidateGlob string + CandidateLimit int + // CandidatePreviewLines are read from each candidate to parse YAML frontmatter. + CandidatePreviewLines int + + TopK int + + MaxLines int + MaxBytes int +} + +type WriteMode string + +const ( + WriteModeDisabled WriteMode = "disabled" + WriteModeAsync WriteMode = "async" + WriteModeSync WriteMode = "sync" +) + +type WriteConfig struct { + Mode WriteMode + + // Model is used for memory extraction. Defaults to Config.Model. + Model model.ToolCallingChatModel + + // MaxTurns caps the extractor's tool-call loop. + MaxTurns int + + SkipIndex bool + + // HandleExtractionIterator, if set, is called with the extractionAgent's event + // iterator returned by Run(). The handler is responsible for draining the + // iterator (calling Next until it returns ok=false) and returning any error + // it wants to surface to the middleware. + // + // If nil, automemory uses the default drain behavior: ignore all events and + // return the first ev.Err encountered (if any). + HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error +} + +type middleware[M adk.MessageType] struct { + adk.TypedBaseChatModelAgentMiddleware[M] + + cfg *Config[M] + + resolvedMemoryDirectory string + + topicSelectionModel model.ToolCallingChatModel + extractionHandler adk.ChatModelAgentMiddleware + topicSelectionTool *schema.ToolInfo + coordination *CoordinationConfig[M] +} + +type selectionFuture struct { + done chan struct{} + mu sync.Mutex + + // Store an immutable snapshot to avoid being mutated via shared pointers. + content string + err error + applied bool +} + +type ctxKeySelectionFuture struct{} + +const ( + memoryExtraKey = "__eino_automemory__" + instructionMarker = "" +) + +type memoryExtra struct { + Type string + Cursor int + UpdatedAt string + Visibility string + SchemaVer int +} + +// New creates an automemory middleware from the provided configuration. +func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedChatModelAgentMiddleware[M], error) { + if config == nil { + return nil, fmt.Errorf("auto memory config: invalid") + } + + cfg := cloneConfig(config) + if cfg.MemoryDirectory == "" || cfg.MemoryBackend == nil { + return nil, fmt.Errorf("auto memory config: invalid") + } + + resolvedMemoryDir, err := ainternal.ResolveMemoryDir(cfg.MemoryDirectory) + if err != nil { + return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err) + } + if cfg.Read == nil { + cfg.Read = &ReadConfig{} + } + applyReadDefaults(cfg) + + m := &middleware[M]{ + TypedBaseChatModelAgentMiddleware: adk.TypedBaseChatModelAgentMiddleware[M]{}, + cfg: cfg, + resolvedMemoryDirectory: resolvedMemoryDir, + coordination: cfg.Coordination, + } + + m.topicSelectionTool = topicSelectionToolInfo() + if cfg.Read.TopicSelection != nil && cfg.Read.Model != nil { + bound, err := cfg.Read.Model.WithTools([]*schema.ToolInfo{m.topicSelectionTool}) + if err != nil { + return nil, fmt.Errorf("auto memory topic selection model init failed: %w", err) + } + m.topicSelectionModel = bound + } + + if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil { + writeFSBackend, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{ + BaseDir: resolvedMemoryDir, + NotFoundAsContent: true, + ErrorPrefix: "fs backend", + }) + if err != nil { + return nil, err + } + fileSystemMiddleware, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{ + Backend: writeFSBackend, + LsToolConfig: &fsmw.ToolConfig{Disable: true}, + GrepToolConfig: &fsmw.ToolConfig{Disable: true}, + }) + if err != nil { + return nil, err + } + m.extractionHandler = fileSystemMiddleware + } + + return m, nil +} + +func (m *middleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { + if runCtx == nil { + return ctx, runCtx, nil + } + nRunCtx := *runCtx + + // Sync distributed write cursor back into message extras so later runs on other + // machines still carry a transcript-local marker. + if nRunCtx.AgentInput != nil && len(nRunCtx.AgentInput.Messages) > 0 && m.coordination != nil && m.coordination.Coordinator != nil { + if sessionID, err := m.resolveSessionID(ctx, &adk.TypedChatModelAgentState[M]{Messages: nRunCtx.AgentInput.Messages}); err == nil && sessionID != "" { + localCursor := getWriteCursorFromMessages(nRunCtx.AgentInput.Messages) + if remoteCursor, ok, err := m.coordination.Coordinator.GetCursor(ctx, sessionID); err == nil && ok && remoteCursor > localCursor { + st := markWriteCursor(&adk.TypedChatModelAgentState[M]{Messages: nRunCtx.AgentInput.Messages}, remoteCursor) + if st != nil { + nRunCtx.AgentInput = &adk.TypedAgentInput[M]{ + Messages: st.Messages, + EnableStreaming: nRunCtx.AgentInput.EnableStreaming, + } + } + } + } + } + + // If automemory was already injected into the instruction or message list, + // skip all memory-loading work for this run and let the agent continue. + if hasInstructionInjected(nRunCtx.Instruction) || (nRunCtx.AgentInput != nil && alreadyInjected(nRunCtx.AgentInput.Messages)) { + return ctx, &nRunCtx, nil + } + + // 1) System prompt: inject auto memory instruction + MEMORY.md content (best-effort). + nRunCtx.Instruction = m.injectIndexIntoInstruction(ctx, nRunCtx.Instruction) + + // 2) Topic memories: sync mode injects before the user's query. + if m.cfg.Read.Mode == ReadModeSync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil { + memMsg, err := m.selectAndBuildTopicMemoryMessage(ctx, nRunCtx.AgentInput) + if err != nil { + m.onErr(ctx, OnErrorStageTopicSelectionSync, err) + } else if memMsg != nil && nRunCtx.AgentInput != nil && len(nRunCtx.AgentInput.Messages) > 0 { + msgs := append([]M{}, nRunCtx.AgentInput.Messages...) + msgs = append(msgs, memMsg) + nRunCtx.AgentInput = &adk.TypedAgentInput[M]{Messages: msgs, EnableStreaming: nRunCtx.AgentInput.EnableStreaming} + } + } + + // 3) Topic memories: async mode starts selection here (cannot use RunLocalValue in BeforeAgent). + if m.cfg.Read.Mode == ReadModeAsync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil && nRunCtx.AgentInput != nil { + if existing, _ := ctx.Value(ctxKeySelectionFuture{}).(*selectionFuture); existing == nil { + fut := &selectionFuture{done: make(chan struct{})} + ctx = context.WithValue(ctx, ctxKeySelectionFuture{}, fut) + + // Snapshot current messages for selection; async path is best-effort. + msgSnapshot := append([]M{}, nRunCtx.AgentInput.Messages...) + go func() { + defer close(fut.done) + memMsg, selErr := m.selectAndBuildTopicMemoryMessage(ctx, &adk.TypedAgentInput[M]{Messages: msgSnapshot}) + fut.mu.Lock() + defer fut.mu.Unlock() + if selErr != nil { + fut.err = selErr + return + } + if !isNilMessage(memMsg) { + fut.content = userMessageTextContent(memMsg) + } + }() + } + } + + return ctx, &nRunCtx, nil +} + +func (m *middleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M], _ *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) { + if state == nil { + return ctx, state, nil + } + // Best-effort protection: if automemory content has been injected before and later + // mutated by other components, restore it using the immutable snapshot stored in the future. + if fut, _ := ctx.Value(ctxKeySelectionFuture{}).(*selectionFuture); fut != nil { + fut.mu.Lock() + expected := fut.content + fut.mu.Unlock() + if strings.TrimSpace(expected) != "" { + state = ensureMemoryMsgUnchanged(state, expected) + } + } + if m.cfg.Read.Mode != ReadModeAsync { + return ctx, state, nil + } + fut, _ := ctx.Value(ctxKeySelectionFuture{}).(*selectionFuture) + if fut == nil { + return ctx, state, nil + } + + select { + case <-fut.done: + default: + return ctx, state, nil + } + + fut.mu.Lock() + if fut.applied { + fut.mu.Unlock() + return ctx, state, nil + } + content := fut.content + err := fut.err + fut.mu.Unlock() + if err != nil { + m.onErr(ctx, OnErrorStageTopicSelectionAsync, err) + } + + var msgs []M + if strings.TrimSpace(content) != "" { + msgs = append(msgs, state.Messages...) + msgs = append(msgs, newMemoryMessage[M](content)) + } else { + msgs = state.Messages + } + + fut.mu.Lock() + fut.applied = true + fut.mu.Unlock() + + return ctx, &adk.TypedChatModelAgentState[M]{Messages: msgs}, nil +} + +func applyReadDefaults[M adk.MessageType](cfg *Config[M]) { + if cfg.Read.Mode == "" { + cfg.Read.Mode = ReadModeSync + } + if cfg.Read.Index == nil { + cfg.Read.Index = &IndexConfig{} + } + if cfg.Read.Index.FileName == "" { + cfg.Read.Index.FileName = memoryIndexFileName + } + if cfg.Read.Index.MaxLines <= 0 { + cfg.Read.Index.MaxLines = defaultIndexMaxLines + } + if cfg.Read.Index.MaxBytes <= 0 { + cfg.Read.Index.MaxBytes = defaultIndexMaxBytes + } + if cfg.Read.Model == nil { + cfg.Read.Model = cfg.Model + } + if cfg.Read.TopicSelection == nil { + cfg.Read.TopicSelection = &TopicSelectionConfig{} + } + if cfg.Read.TopicSelection.TopK <= 0 { + cfg.Read.TopicSelection.TopK = defaultTopicTopK + } + if cfg.Read.TopicSelection.CandidateGlob == "" { + cfg.Read.TopicSelection.CandidateGlob = CandidateGlobPattern + } + if cfg.Read.TopicSelection.CandidateLimit <= 0 { + cfg.Read.TopicSelection.CandidateLimit = defaultCandidateLimit + } + if cfg.Read.TopicSelection.CandidatePreviewLines <= 0 { + cfg.Read.TopicSelection.CandidatePreviewLines = defaultCandidatePreviewLine + } + if cfg.Read.TopicSelection.MaxLines <= 0 { + cfg.Read.TopicSelection.MaxLines = defaultTopicMaxLines + } + if cfg.Read.TopicSelection.MaxBytes <= 0 { + cfg.Read.TopicSelection.MaxBytes = defaultTopicMaxBytes + } + + if cfg.Write == nil { + cfg.Write = &WriteConfig{Mode: WriteModeDisabled} + } + if cfg.Write.Mode == "" { + cfg.Write.Mode = WriteModeDisabled + } + if cfg.Write.Model == nil { + cfg.Write.Model = cfg.Model + } + if cfg.Write.MaxTurns <= 0 { + cfg.Write.MaxTurns = defaultMemoryWriteMaxTurns + } + + if cfg.Coordination == nil { + cfg.Coordination = &CoordinationConfig[M]{} + } + if cfg.Coordination.Coordinator == nil { + cfg.Coordination.Coordinator = NewLocalCoordinator() + } + if cfg.Coordination.LockTTL <= 0 { + cfg.Coordination.LockTTL = 2 * time.Minute + } +} + +func cloneConfig[M adk.MessageType](cfg *Config[M]) *Config[M] { + if cfg == nil { + return nil + } + + cp := *cfg + if cfg.Read != nil { + readCopy := *cfg.Read + cp.Read = &readCopy + if cfg.Read.Instruction != nil { + instructionCopy := *cfg.Read.Instruction + cp.Read.Instruction = &instructionCopy + } + if cfg.Read.Index != nil { + indexCopy := *cfg.Read.Index + cp.Read.Index = &indexCopy + } + if cfg.Read.TopicSelection != nil { + topicSelectionCopy := *cfg.Read.TopicSelection + cp.Read.TopicSelection = &topicSelectionCopy + } + } + if cfg.Write != nil { + writeCopy := *cfg.Write + cp.Write = &writeCopy + } + if cfg.Coordination != nil { + coordinationCopy := *cfg.Coordination + cp.Coordination = &coordinationCopy + } + return &cp +} + +type topicSelectionResp struct { + SelectedMemories []string `json:"selected_memories"` +} + +func (m *middleware[M]) injectIndexIntoInstruction(ctx context.Context, baseInstruction string) string { + memDir := m.resolvedMemoryDirectory + + var memDesc string + if m.cfg.Read.Instruction != nil { + memDesc = *m.cfg.Read.Instruction + } else { + s, err := pyfmt.Fmt(getDefaultMemoryInstruction(), map[string]any{"memory_dir": memDir}) + if err != nil { + m.onErr(ctx, OnErrorStageRenderInstruction, err) + return baseInstruction + } + memDesc = s + } + + indexPath := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + indexContent := "" + totalLines := 0 + + fc, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: indexPath}) + if err == nil && fc != nil { + indexContent = fc.Content + totalLines = strings.Count(indexContent, "\n") + 1 + } else { + // Missing index is not fatal; keep empty. + indexContent = "" + } + + sb := make([]string, 0, 5) + sb = append(sb, memDesc) + sb = append(sb, "## "+m.cfg.Read.Index.FileName) + if strings.TrimSpace(indexContent) == "" { + sb = append(sb, getAppendEmptyIndexTemplate()) + } else { + truncatedMemoryIndex, _, truncated := linesOrSizeTrunc(indexContent, m.cfg.Read.Index.MaxLines, m.cfg.Read.Index.MaxBytes) + sb = append(sb, truncatedMemoryIndex) + if truncated { + notify, err := pyfmt.Fmt(getAppendCurrentIndexTruncNotify(), map[string]any{ + "memory_lines": totalLines, + }) + if err == nil { + sb = append(sb, notify) + } + } + } + + return baseInstruction + "\n" + instructionMarker + "\n" + strings.Join(sb, "\n") +} + +func linesOrSizeTrunc(content string, lines, size int) (newContent string, reason string, truncated bool) { + linesTrunc := func(content string, lines int) { + sp := strings.Split(content, "\n") + if len(sp) > lines { + newContent = strings.Join(sp[:lines], "\n") + reason = fmt.Sprintf("first %d lines", lines) + truncated = true + } else { + newContent = content + } + } + + sizeTrunc := func(content string, size int) { + if len(content) > size { + newContent = content[:size] + reason = fmt.Sprintf("%d byte limit", size) + truncated = true + } else { + newContent = content + } + } + + if lines == 0 && size == 0 { + return content, "", false + } else if lines == 0 { + sizeTrunc(content, size) + } else if size == 0 { + linesTrunc(content, lines) + } else { + linesTrunc(content, lines) + sizeTrunc(newContent, size) + } + return +} + +func (m *middleware[M]) onErr(ctx context.Context, stage string, err error) { + if err == nil { + return + } + if m.cfg != nil && m.cfg.OnError != nil { + m.cfg.OnError(ctx, stage, err) + } +} + +type topicFrontmatter struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Type string `yaml:"type"` +} + +type topicCandidateBundle struct { + AbsPath string + RelPath string + Info FileInfo +} + +func parseFrontmatter(md string) (fm topicFrontmatter, ok bool) { + // Only consider YAML frontmatter at the beginning. + s := strings.TrimLeft(md, "\ufeff \t\r\n") + if !strings.HasPrefix(s, "---\n") && !strings.HasPrefix(s, "---\r\n") { + return topicFrontmatter{}, false + } + // Find the next delimiter. + parts := strings.SplitN(s, "\n---", 2) + if len(parts) != 2 { + return topicFrontmatter{}, false + } + yml := strings.TrimPrefix(parts[0], "---\n") + if err := yaml.Unmarshal([]byte(yml), &fm); err != nil { + return topicFrontmatter{}, false + } + return fm, true +} + +func (m *middleware[M]) selectAndBuildTopicMemoryMessage(ctx context.Context, agentIn *adk.TypedAgentInput[M]) (M, error) { + last, ok := m.lastUserMessage(agentIn) + if !ok { + return nil, nil + } + + relToBundle, available, orderedRel, err := m.listTopicCandidates(ctx) + if err != nil || len(orderedRel) == 0 { + return nil, err + } + + topK := m.topicSelectionTopK() + selected, err := m.selectTopicCandidates(ctx, agentIn, userMessageTextContent(last), available, orderedRel, relToBundle) + if err != nil || len(selected) == 0 { + return nil, err + } + + rendered := m.renderTopicMemories(ctx, selected, relToBundle, topK) + if len(rendered) == 0 { + return nil, nil + } + + return newMemoryMessage[M]("\n" + strings.Join(rendered, "\n\n")), nil +} + +func (m *middleware[M]) lastUserMessage(agentIn *adk.TypedAgentInput[M]) (M, bool) { + if agentIn == nil || len(agentIn.Messages) == 0 { + return nil, false + } + if m.cfg.Read.TopicSelection == nil || m.topicSelectionModel == nil { + return nil, false + } + last := agentIn.Messages[len(agentIn.Messages)-1] + if isNilMessage(last) || !isUserRole(last) { + return nil, false + } + return last, true +} + +func (m *middleware[M]) listTopicCandidates(ctx context.Context) (map[string]topicCandidateBundle, []string, []string, error) { + candidates, err := m.topicSelectionCandidates(ctx) + if err != nil || len(candidates) == 0 { + return nil, nil, nil, err + } + + relToBundle := make(map[string]topicCandidateBundle, len(candidates)) + available := make([]string, 0, len(candidates)) + orderedRel := make([]string, 0, len(candidates)) + + for _, fi := range candidates { + bundle, manifestLine, ok := m.buildTopicCandidateBundle(ctx, fi) + if !ok { + continue + } + relToBundle[bundle.RelPath] = bundle + available = append(available, manifestLine) + orderedRel = append(orderedRel, bundle.RelPath) + } + + return relToBundle, available, orderedRel, nil +} + +func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]FileInfo, error) { + files, err := m.cfg.MemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ + Pattern: m.cfg.Read.TopicSelection.CandidateGlob, + Path: m.cfg.MemoryDirectory, + }) + if err != nil || len(files) == 0 { + return nil, err + } + + indexAbs := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + candidates := make([]FileInfo, 0, len(files)) + for _, fi := range files { + if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) { + continue + } + candidates = append(candidates, fi) + } + if len(candidates) == 0 { + return nil, nil + } + + sort.Slice(candidates, func(i, j int) bool { + return parseRFC3339NanoBestEffort(candidates[i].ModifiedAt).After(parseRFC3339NanoBestEffort(candidates[j].ModifiedAt)) + }) + if len(candidates) > m.cfg.Read.TopicSelection.CandidateLimit { + candidates = candidates[:m.cfg.Read.TopicSelection.CandidateLimit] + } + return candidates, nil +} + +func (m *middleware[M]) buildTopicCandidateBundle(ctx context.Context, fi FileInfo) (topicCandidateBundle, string, bool) { + rel, relErr := filepath.Rel(m.cfg.MemoryDirectory, fi.Path) + if relErr != nil { + rel = filepath.Base(fi.Path) + } + rel = filepath.ToSlash(rel) + + preview, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{ + FilePath: fi.Path, + Limit: m.cfg.Read.TopicSelection.CandidatePreviewLines, + }) + if err != nil || preview == nil { + return topicCandidateBundle{}, "", false + } + + desc := describeTopicCandidate(preview.Content) + manifestLine := fmt.Sprintf("- %s (saved %s): %s", rel, fi.ModifiedAt, desc) + return topicCandidateBundle{AbsPath: fi.Path, RelPath: rel, Info: fi}, manifestLine, true +} + +func describeTopicCandidate(content string) string { + desc := "" + if fm, ok := parseFrontmatter(content); ok { + switch { + case strings.TrimSpace(fm.Description) != "": + desc = strings.TrimSpace(fm.Description) + case strings.TrimSpace(fm.Name) != "": + desc = strings.TrimSpace(fm.Name) + } + if strings.TrimSpace(fm.Type) != "" { + if desc == "" { + desc = "type=" + strings.TrimSpace(fm.Type) + } else { + desc = desc + " (type=" + strings.TrimSpace(fm.Type) + ")" + } + } + } + if desc == "" { + snippet, _, _ := linesOrSizeTrunc(content, 3, 256) + desc = strings.TrimSpace(snippet) + } + return desc +} + +func (m *middleware[M]) topicSelectionTopK() int { + topK := m.cfg.Read.TopicSelection.TopK + if topK <= 0 { + return defaultTopicTopK + } + return topK +} + +func (m *middleware[M]) selectTopicCandidates( + ctx context.Context, + agentIn *adk.TypedAgentInput[M], + userQuery string, + available []string, + orderedRel []string, + relToBundle map[string]topicCandidateBundle, +) ([]string, error) { + topK := m.topicSelectionTopK() + if len(orderedRel) <= topK { + return orderedRel, nil + } + + userMsg, err := pyfmt.Fmt(getTopicSelectionUserPrompt(), map[string]any{ + "user_query": userQuery, + "available_memories": strings.Join(available, "\n"), + "tools": strings.Join(collectToolNames(agentIn.Messages), ", "), + }) + if err != nil { + return nil, err + } + + toolInfo := topicSelectionToolInfo() + resp, err := m.topicSelectionModel.Generate( + ctx, + []*schema.Message{ + schema.SystemMessage(getTopicSelectionSystemPrompt()), + schema.UserMessage(userMsg), + }, + model.WithToolChoice(schema.ToolChoiceForced, toolInfo.Name), + ) + if err != nil { + return nil, err + } + + valid := make(map[string]struct{}, len(relToBundle)) + for k := range relToBundle { + valid[k] = struct{}{} + } + return parseTopicSelectionFromToolCall(resp, valid) +} + +func collectToolNames[M adk.MessageType](msgs []M) []string { + dedupTools := make(map[string]struct{}) + for _, msg := range msgs { + for _, name := range messageToolNames(msg) { + dedupTools[name] = struct{}{} + } + } + tools := make([]string, 0, len(dedupTools)) + for t := range dedupTools { + tools = append(tools, t) + } + sort.Strings(tools) + return tools +} + +func (m *middleware[M]) renderTopicMemories( + ctx context.Context, + selected []string, + relToBundle map[string]topicCandidateBundle, + topK int, +) []string { + capHint := topK + if capHint > len(selected) { + capHint = len(selected) + } + rendered := make([]string, 0, capHint) + for _, rel := range selected { + if len(rendered) >= topK { + break + } + bundle, ok := relToBundle[rel] + if !ok { + continue + } + renderedContent, ok := m.renderTopicMemory(ctx, bundle) + if !ok { + continue + } + rendered = append(rendered, renderedContent) + } + return rendered +} + +func (m *middleware[M]) renderTopicMemory(ctx context.Context, bundle topicCandidateBundle) (string, bool) { + full, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: bundle.AbsPath}) + if err != nil || full == nil { + return "", false + } + + content, truncReason, truncated := linesOrSizeTrunc(full.Content, m.cfg.Read.TopicSelection.MaxLines, m.cfg.Read.TopicSelection.MaxBytes) + if truncated { + truncNotify, err := pyfmt.Fmt(getTopicMemoryTruncNotify(), map[string]any{ + "reason": truncReason, + "abs_path": bundle.AbsPath, + }) + if err == nil { + content += truncNotify + } + } + + return fmt.Sprintf( + "\nContents of %s (saved %s):\n\n%s\n", + bundle.AbsPath, + bundle.Info.ModifiedAt, + content, + ), true +} + +func topicSelectionToolInfo() *schema.ToolInfo { + return &schema.ToolInfo{ + Name: topicSelectionToolName, + Desc: "Select which memory files to surface for the current query. Return selected_memories as RELATIVE paths (relative to the memory directory).", + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "selected_memories": { + Type: schema.Array, + Desc: "Relative paths of selected memory files, e.g. \"debugging.md\" or \"notes/patterns.md\".", + Required: true, + ElemInfo: &schema.ParameterInfo{Type: schema.String}, + }, + }), + } +} + +func parseTopicSelectionFromToolCall(msg *schema.Message, valid map[string]struct{}) ([]string, error) { + if msg == nil || len(msg.ToolCalls) == 0 { + return nil, fmt.Errorf("no tool calls") + } + tc := msg.ToolCalls[0] + if tc.Function.Name != topicSelectionToolName { + return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name) + } + var parsed topicSelectionResp + if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err != nil { + return nil, err + } + out := normalizeSelected(parsed.SelectedMemories) + // Filter to known candidates to avoid hallucinated paths. + filtered := make([]string, 0, len(out)) + for _, p := range out { + if _, ok := valid[p]; ok { + filtered = append(filtered, p) + } + } + return filtered, nil +} + +func normalizeSelected(in []string) []string { + out := make([]string, 0, len(in)) + seen := make(map[string]struct{}, len(in)) + for _, s := range in { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "./") + s = filepath.ToSlash(s) + if s == "" { + continue + } + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +func isNilMessage[M adk.MessageType](msg M) bool { + var zero M + return any(msg) == any(zero) +} + +func isUserRole[M adk.MessageType](msg M) bool { + switch m := any(msg).(type) { + case *schema.Message: + return m != nil && m.Role == schema.User + case *schema.AgenticMessage: + return m != nil && m.Role == schema.AgenticRoleTypeUser + default: + panic("unreachable") + } +} + +func isAssistantRole[M adk.MessageType](msg M) bool { + switch m := any(msg).(type) { + case *schema.Message: + return m != nil && m.Role == schema.Assistant + case *schema.AgenticMessage: + return m != nil && m.Role == schema.AgenticRoleTypeAssistant + default: + panic("unreachable") + } +} + +func userMessageTextContent[M adk.MessageType](msg M) string { + switch m := any(msg).(type) { + case *schema.Message: + if m == nil { + return "" + } + if len(m.UserInputMultiContent) == 0 { + return m.Content + } + parts := make([]string, 0, len(m.UserInputMultiContent)) + for _, part := range m.UserInputMultiContent { + if part.Type == schema.ChatMessagePartTypeText && part.Text != "" { + parts = append(parts, part.Text) + } + } + if len(parts) > 0 { + return strings.Join(parts, "\n") + } + return m.Content + case *schema.AgenticMessage: + if m == nil { + return "" + } + parts := make([]string, 0, len(m.ContentBlocks)) + for _, block := range m.ContentBlocks { + if block != nil && block.UserInputText != nil { + parts = append(parts, block.UserInputText.Text) + } + } + return strings.Join(parts, "\n") + default: + panic("unreachable") + } +} + +func getMsgExtra[M adk.MessageType](msg M) map[string]any { + switch m := any(msg).(type) { + case *schema.Message: + if m == nil { + return nil + } + return m.Extra + case *schema.AgenticMessage: + if m == nil { + return nil + } + return m.Extra + default: + panic("unreachable") + } +} + +func setMsgExtra[M adk.MessageType](msg M, key string, value any) { + switch m := any(msg).(type) { + case *schema.Message: + if m.Extra == nil { + m.Extra = map[string]any{} + } + m.Extra[key] = value + case *schema.AgenticMessage: + if m.Extra == nil { + m.Extra = map[string]any{} + } + m.Extra[key] = value + default: + panic("unreachable") + } +} + +func makeUserMsg[M adk.MessageType](text string) M { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any(schema.UserMessage(text)).(M) + case *schema.AgenticMessage: + return any(schema.UserAgenticMessage(text)).(M) + default: + panic("unreachable") + } +} + +func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall { + switch m := any(msg).(type) { + case *schema.Message: + if m == nil { + return nil + } + return m.ToolCalls + case *schema.AgenticMessage: + if m == nil { + return nil + } + out := make([]schema.ToolCall, 0, len(m.ContentBlocks)) + for _, block := range m.ContentBlocks { + if block == nil || block.FunctionToolCall == nil { + continue + } + out = append(out, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Type: "function", + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + }) + } + return out + default: + panic("unreachable") + } +} + +func messageToolNames[M adk.MessageType](msg M) []string { + switch m := any(msg).(type) { + case *schema.Message: + if m == nil || m.Role != schema.Tool || m.ToolName == "" { + return nil + } + return []string{m.ToolName} + case *schema.AgenticMessage: + if m == nil { + return nil + } + var out []string + for _, block := range m.ContentBlocks { + if block == nil || block.FunctionToolResult == nil || block.FunctionToolResult.Name == "" { + continue + } + out = append(out, block.FunctionToolResult.Name) + } + return out + default: + panic("unreachable") + } +} + +func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message { + out := make([]adk.Message, 0, len(msgs)) + for _, msg := range msgs { + if projected := projectMessageToSchema(msg); projected != nil { + out = append(out, projected) + } + } + return out +} + +func projectMessageToSchema[M adk.MessageType](msg M) adk.Message { + switch m := any(msg).(type) { + case *schema.Message: + return m + case *schema.AgenticMessage: + if m == nil { + return nil + } + text := m.String() + switch m.Role { + case schema.AgenticRoleTypeSystem: + return schema.SystemMessage(text) + case schema.AgenticRoleTypeAssistant: + return schema.AssistantMessage(text, messageToolCalls(msg)) + case schema.AgenticRoleTypeUser: + return schema.UserMessage(text) + default: + return schema.UserMessage(text) + } + default: + panic("unreachable") + } +} + +func alreadyInjected[M adk.MessageType](msgs []M) bool { + for _, m := range msgs { + if isMemoryMessage(m) { + return true + } + } + return false +} + +func isMemoryMessage[M adk.MessageType](m M) bool { + if isNilMessage(m) || !isUserRole(m) { + return false + } + if extra := getMsgExtra(m); extra != nil { + if v, ok := extra[memoryExtraKey]; ok && v != nil { + return true + } + } + // Backward compatible marker (older versions). + return strings.Contains(userMessageTextContent(m), "") +} + +func hasInstructionInjected(instruction string) bool { + return strings.Contains(instruction, instructionMarker) +} + +func newMemoryMessage[M adk.MessageType](content string) M { + msg := makeUserMsg[M](content) + setMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory"}) + return msg +} + +func ensureMemoryMsgUnchanged[M adk.MessageType](state *adk.TypedChatModelAgentState[M], expectedContent string) *adk.TypedChatModelAgentState[M] { + if state == nil || strings.TrimSpace(expectedContent) == "" { + return state + } + changed := false + out := *state + out.Messages = append([]M{}, state.Messages...) + + for i, m := range out.Messages { + if !isMemoryMessage(m) { + continue + } + extra := getMsgExtra(m) + if userMessageTextContent(m) != expectedContent || extra == nil || extra[memoryExtraKey] == nil { + out.Messages[i] = newMemoryMessage[M](expectedContent) + changed = true + } + } + if !changed { + return state + } + return &out +} + +func extractFilePath(args string) (string, bool) { + var m map[string]any + if err := json.Unmarshal([]byte(args), &m); err != nil { + return "", false + } + if v, ok := m["file_path"]; ok { + if s, ok := v.(string); ok && s != "" { + return s, true + } + } + if v, ok := m["filePath"]; ok { // tolerate camelCase + if s, ok := v.(string); ok && s != "" { + return s, true + } + } + return "", false +} + +func isPathWithinMemoryDir(memDir string, filePath string) bool { + if memDir == "" || filePath == "" { + return false + } + md := filepath.Clean(memDir) + fp := filepath.Clean(filePath) + if !filepath.IsAbs(fp) { + fp = filepath.Join(md, fp) + fp = filepath.Clean(fp) + } + if fp == md { + return true + } + sep := string(filepath.Separator) + return strings.HasPrefix(fp, md+sep) +} + +func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (context.Context, error) { + if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Mode == WriteModeDisabled { + return ctx, nil + } + if m.cfg.Write.Model == nil || m.extractionHandler == nil { + return ctx, nil + } + if state == nil || len(state.Messages) == 0 { + return ctx, nil + } + + sessionID, err := m.resolveSessionID(ctx, state) + if err != nil { + m.onErr(ctx, OnErrorStageResolveSessionID, err) + return ctx, nil + } + + cursor := getWriteCursorFromMessages(state.Messages) + if sessionID != "" { + if remoteCursor, ok, err := m.coordination.Coordinator.GetCursor(ctx, sessionID); err == nil && ok && remoteCursor > cursor { + cursor = remoteCursor + state = markWriteCursor(state, cursor) + } + } + if cursor >= len(state.Messages) { + return ctx, nil + } + + // Skip background extraction if the main agent already wrote memory files in this range. + if hasMemoryWritesSince(state.Messages, cursor, m.resolvedMemoryDirectory) { + end := len(state.Messages) + if sessionID != "" { + _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end) + } + state = markWriteCursor(state, end) + return ctx, nil + } + + if countModelVisibleMessages(state.Messages[cursor:]) == 0 { + end := len(state.Messages) + if sessionID != "" { + _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end) + } + state = markWriteCursor(state, end) + return ctx, nil + } + + switch m.cfg.Write.Mode { + case WriteModeDisabled: + // do nothing + return ctx, nil + + case WriteModeSync: + end := len(state.Messages) + if err := m.runMemoryExtractionAgent(ctx, state.Messages, cursor, state.ToolInfos); err != nil { + m.onErr(ctx, OnErrorStageMemoryWriteSync, err) + return ctx, nil + } + if sessionID != "" { + _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end) + } + state = markWriteCursor(state, end) + return ctx, nil + + case WriteModeAsync: + if sessionID == "" { + sessionID = getOrInitWriteSessionID(ctx) + } + snap, err := buildPendingSnapshot(state.Messages, cursor, state.ToolInfos) + if err != nil { + m.onErr(ctx, OnErrorStageSnapshotMarshal, err) + return ctx, nil + } + unlock, ok, err := m.coordination.Coordinator.AcquireLock(ctx, sessionID, m.coordination.LockTTL) + if err != nil { + m.onErr(ctx, OnErrorStageAcquireExtractionLock, err) + return ctx, nil + } + if !ok { + if err := m.coordination.Coordinator.SetPendingSnapshot(ctx, sessionID, snap); err != nil { + m.onErr(ctx, OnErrorStageStashPendingSnapshot, err) + } + return ctx, nil + } + go m.runExtractionDrain(context.Background(), sessionID, unlock, snap) + return ctx, nil + + default: + return ctx, nil + } +} + +func getWriteCursorFromMessages[M adk.MessageType](msgs []M) int { + for i := len(msgs) - 1; i >= 0; i-- { + m := msgs[i] + extra := getMsgExtra(m) + if isNilMessage(m) || extra == nil { + continue + } + v, ok := extra[memoryExtraKey] + if !ok { + continue + } + switch meta := v.(type) { + case *memoryExtra: + if meta != nil && meta.Type == "write_cursor" { + return meta.Cursor + } + case map[string]any: + if typ, _ := meta["type"].(string); typ != "write_cursor" { + continue + } + switch c := meta["cursor"].(type) { + case int: + return c + case int64: + return int(c) + case float64: + return int(c) + } + } + } + return 0 +} + +func markWriteCursor[M adk.MessageType](state *adk.TypedChatModelAgentState[M], cursor int) *adk.TypedChatModelAgentState[M] { + if state == nil || len(state.Messages) == 0 { + return state + } + last := state.Messages[len(state.Messages)-1] + if isNilMessage(last) { + return state + } + + setMsgExtra(last, memoryExtraKey, &memoryExtra{ + Type: "write_cursor", + Cursor: cursor, + UpdatedAt: time.Now().Format(time.RFC3339Nano), + Visibility: "internal", + SchemaVer: 1, + }) + + return state +} + +func countModelVisibleMessages[M adk.MessageType](msgs []M) int { + n := 0 + for _, m := range msgs { + if isNilMessage(m) { + continue + } + if isUserRole(m) || isAssistantRole(m) { + n++ + } + } + return n +} + +func getOrInitWriteSessionID(ctx context.Context) string { + const key = "__automemory_write_session_id__" + if v, ok := adk.GetSessionValue(ctx, key); ok { + if s, ok := v.(string); ok && s != "" { + return s + } + } + // Stable enough for in-process session identity. + s := fmt.Sprintf("%d", time.Now().UnixNano()) + adk.AddSessionValue(ctx, key, s) + return s +} + +func (m *middleware[M]) resolveSessionID(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (string, error) { + if m.coordination != nil && m.coordination.SessionIDFunc != nil { + return m.coordination.SessionIDFunc(ctx, state) + } + return getOrInitWriteSessionID(ctx), nil +} + +func buildPendingSnapshot[M adk.MessageType](messages []M, cursor int, toolInfos []*schema.ToolInfo) (*PendingSnapshot, error) { + raw, err := json.Marshal(messages) + if err != nil { + return nil, err + } + var rawToolInfos json.RawMessage + if toolInfos != nil { + rawToolInfos, err = json.Marshal(toolInfos) + if err != nil { + return nil, err + } + } + return &PendingSnapshot{Cursor: cursor, Messages: raw, ToolInfos: rawToolInfos}, nil +} + +func decodePendingSnapshot[M adk.MessageType](snapshot *PendingSnapshot) ([]M, int, []*schema.ToolInfo, error) { + if snapshot == nil { + return nil, 0, nil, nil + } + var msgs []M + if err := json.Unmarshal(snapshot.Messages, &msgs); err != nil { + return nil, 0, nil, err + } + var toolInfos []*schema.ToolInfo + if len(snapshot.ToolInfos) > 0 { + if err := json.Unmarshal(snapshot.ToolInfos, &toolInfos); err != nil { + return nil, 0, nil, err + } + } + return msgs, snapshot.Cursor, toolInfos, nil +} + +func (m *middleware[M]) runExtractionDrain(ctx context.Context, sessionID string, unlock func(context.Context) error, initial *PendingSnapshot) { + defer func() { + if unlock == nil { + return + } + if err := unlock(ctx); err != nil { + m.onErr(ctx, OnErrorStageReleaseExtractionLock, err) + } + }() + + current := initial + for current != nil { + msgs, cursor, toolInfos, err := decodePendingSnapshot[M](current) + if err != nil { + m.onErr(ctx, OnErrorStageDecodePendingSnapshot, err) + } else if err := m.runMemoryExtractionAgent(ctx, msgs, cursor, toolInfos); err != nil { + m.onErr(ctx, OnErrorStageMemoryWriteAsync, err) + } else { + _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, len(msgs)) + } + + next, loadErr := m.coordination.Coordinator.PopPendingSnapshot(ctx, sessionID) + if loadErr != nil { + m.onErr(ctx, OnErrorStageLoadPendingSnapshot, loadErr) + return + } + current = next + } +} + +func hasMemoryWritesSince[M adk.MessageType](msgs []M, cursor int, memoryDir string) bool { + if cursor < 0 { + cursor = 0 + } + for _, msg := range msgs[cursor:] { + if isNilMessage(msg) || !isAssistantRole(msg) { + continue + } + for _, tc := range messageToolCalls(msg) { + if tc.Function.Name != adkfs.ToolNameWriteFile && tc.Function.Name != adkfs.ToolNameEditFile { + continue + } + if fp, ok := extractFilePath(tc.Function.Arguments); ok && isPathWithinMemoryDir(memoryDir, fp) { + return true + } + } + } + return false +} + +func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int { + if cursor < 0 { + cursor = 0 + } + if cursor >= len(msgs) { + return 0 + } + return countModelVisibleMessages(msgs[cursor:]) +} + +func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.ChatModelAgent, error) { + if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Model == nil { + return nil, fmt.Errorf("auto memory extraction agent init failed: missing write model") + } + if m.extractionHandler == nil { + return nil, fmt.Errorf("auto memory extraction agent init failed: missing extraction handler") + } + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "automemory_extractor", + Description: "Internal auto memory extraction subagent", + Model: m.cfg.Write.Model, + Handlers: []adk.ChatModelAgentMiddleware{ + m.extractionHandler, // fs middleware + &toolInfoOverrideMiddleware{toolInfos: toolInfos}, // tool info override, for prefix cache + }, + ToolsConfig: adk.ToolsConfig{ + ToolsNodeConfig: compose.ToolsNodeConfig{ + UnknownToolsHandler: func(ctx context.Context, name, input string) (string, error) { + return "This tool is not allowed to be called. Please follow user prompt to proceed.", nil + }, + }, + EmitInternalEvents: false, + }, + MaxIterations: m.cfg.Write.MaxTurns, + }) + if err != nil { + return nil, fmt.Errorf("auto memory extraction agent init failed: %w", err) + } + return agent, nil +} + +func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot []M, cursor int, toolInfos []*schema.ToolInfo) error { + if len(snapshot) == 0 || cursor >= len(snapshot) { + return nil + } + manifest, err := m.buildMemoryManifest(ctx) + if err != nil { + return err + } + newMessageCount := countModelVisibleMessagesSince(snapshot, cursor) + userPrompt := buildExtractAutoOnlyPrompt(m.resolvedMemoryDirectory, newMessageCount, manifest, m.cfg.Write.SkipIndex) + msgs := append(projectMessagesToSchema(snapshot), schema.UserMessage(userPrompt)) + extractionAgent, err := m.newExtractionAgent(ctx, toolInfos) + if err != nil { + return err + } + + iter := extractionAgent.Run(ctx, &adk.AgentInput{ + Messages: msgs, + EnableStreaming: true, + }) + + if m.cfg != nil && m.cfg.Write != nil && m.cfg.Write.HandleExtractionIterator != nil { + return m.cfg.Write.HandleExtractionIterator(ctx, iter) + } + + for { + ev, ok := iter.Next() + if !ok { + return nil + } + if ev == nil { + continue + } + if ev.Err != nil { + return ev.Err + } + } +} + +func (m *middleware[M]) buildMemoryManifest(ctx context.Context) (string, error) { + files, err := m.cfg.MemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ + Pattern: CandidateGlobPattern, + Path: m.cfg.MemoryDirectory, + }) + if err != nil { + return "", err + } + indexAbs := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + lines := make([]string, 0, len(files)) + for _, fi := range files { + rel, relErr := filepath.Rel(m.cfg.MemoryDirectory, fi.Path) + if relErr != nil { + rel = filepath.Base(fi.Path) + } + rel = filepath.ToSlash(rel) + if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) { + rel = m.cfg.Read.Index.FileName + } + desc := "" + preview, rerr := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: fi.Path, Limit: defaultCandidatePreviewLine}) + if rerr == nil && preview != nil { + if fm, ok := parseFrontmatter(preview.Content); ok { + desc = strings.TrimSpace(fm.Description) + } + } + if desc != "" { + lines = append(lines, fmt.Sprintf("- %s (saved %s): %s", rel, fi.ModifiedAt, desc)) + } else { + lines = append(lines, fmt.Sprintf("- %s (saved %s)", rel, fi.ModifiedAt)) + } + } + return strings.Join(lines, "\n"), nil +} + +func parseRFC3339NanoBestEffort(s string) time.Time { + if s == "" { + return time.Time{} + } + if t, err := time.Parse(time.RFC3339Nano, s); err == nil { + return t + } + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + return time.Time{} +} + +type toolInfoOverrideMiddleware struct { + adk.BaseChatModelAgentMiddleware + + once sync.Once + toolInfos []*schema.ToolInfo +} + +func (t *toolInfoOverrideMiddleware) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[*schema.Message], _ *adk.TypedModelContext[*schema.Message]) ( + context.Context, *adk.TypedChatModelAgentState[*schema.Message], error) { + + t.once.Do(func() { + toolNameMapping := make(map[string]struct{}, len(t.toolInfos)) + for _, tool := range t.toolInfos { + toolNameMapping[tool.Name] = struct{}{} + } + + overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...) + for _, tool := range state.ToolInfos { + if _, ok := toolNameMapping[tool.Name]; !ok { // add fs tools if not exists + overrideTools = append(overrideTools, tool) + } + } + state.ToolInfos = overrideTools + }) + + return ctx, state, nil +} diff --git a/adk/middlewares/automemory/automemory_test.go b/adk/middlewares/automemory/automemory_test.go new file mode 100644 index 000000000..e128d75ec --- /dev/null +++ b/adk/middlewares/automemory/automemory_test.go @@ -0,0 +1,1005 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type fixedModel struct { + out string +} + +func (m *fixedModel) Generate(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage(m.out, nil), nil +} + +func (m *fixedModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, _ := m.Generate(ctx, input) + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *fixedModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +func TestMiddleware_IndexInjection_Empty(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + // Model nil => topic selection disabled. + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}}, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Contains(t, out.Instruction, "# auto memory") + require.Contains(t, out.Instruction, "## MEMORY.md") + require.Contains(t, out.Instruction, "currently empty") +} + +func TestMiddleware_IndexInjection_ChineseInstruction(t *testing.T) { + require.NoError(t, adk.SetLanguage(adk.LanguageChinese)) + defer func() { + require.NoError(t, adk.SetLanguage(adk.LanguageEnglish)) + }() + + ctx := context.Background() + b := NewInMemoryBackend() + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}}, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Contains(t, out.Instruction, "# 自动记忆") + require.Contains(t, out.Instruction, "你的 MEMORY.md 当前为空") +} + +func TestNew_DoesNotMutateConfig(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + + cfgNilNested := &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, + } + _, err := New(ctx, cfgNilNested) + require.NoError(t, err) + require.Nil(t, cfgNilNested.Read) + require.Nil(t, cfgNilNested.Write) + require.Nil(t, cfgNilNested.Coordination) + + cfgExplicitNested := &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, + Read: &ReadConfig{}, + Write: &WriteConfig{}, + Coordination: &CoordinationConfig[*schema.Message]{}, + } + _, err = New(ctx, cfgExplicitNested) + require.NoError(t, err) + require.Empty(t, cfgExplicitNested.Read.Mode) + require.Nil(t, cfgExplicitNested.Read.Model) + require.Nil(t, cfgExplicitNested.Read.Index) + require.Nil(t, cfgExplicitNested.Read.TopicSelection) + require.Empty(t, cfgExplicitNested.Write.Mode) + require.Nil(t, cfgExplicitNested.Write.Model) + require.Zero(t, cfgExplicitNested.Write.MaxTurns) + require.Nil(t, cfgExplicitNested.Coordination.Coordinator) + require.Zero(t, cfgExplicitNested.Coordination.LockTTL) +} + +func TestMiddleware_TopicSelection_InsertsMemoryMessage(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + + b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md) - notes\n", now) + b.put("/mem/debugging.md", "---\nname: Debugging\ndescription: build and test commands\ntype: project\n---\n\n# Debugging\npnpm test\n", now) + b.put("/mem/other.md", "---\nname: Other\ndescription: unrelated\ntype: misc\n---\n", now.Add(-time.Hour)) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, + }) + require.NoError(t, err) + + in := &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How to run tests?")}} + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: in, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.NotNil(t, out.AgentInput) + require.Len(t, out.AgentInput.Messages, 2) + require.Equal(t, schema.User, out.AgentInput.Messages[0].Role) + require.Contains(t, out.AgentInput.Messages[0].Content, "How to run tests?") + require.Contains(t, out.AgentInput.Messages[1].Content, "") + require.NotNil(t, out.AgentInput.Messages[1].Extra) + require.NotNil(t, out.AgentInput.Messages[1].Extra["__eino_automemory__"]) + require.Contains(t, out.AgentInput.Messages[1].Content, "Contents of /mem/debugging.md") +} + +func TestMiddleware_TopicSelection_AsyncInjectsInBeforeModel(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + + b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md) - notes\n", now) + b.put("/mem/debugging.md", "---\nname: Debugging\ndescription: build and test commands\ntype: project\n---\n\n# Debugging\npnpm test\n", now) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, + Read: &ReadConfig{Mode: ReadModeAsync}, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How to run tests?")}}, + } + ctx2, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Len(t, out.AgentInput.Messages, 1) // async doesn't inject here + + st := &adk.ChatModelAgentState{Messages: []adk.Message{schema.UserMessage("How to run tests?")}} + + require.Eventually(t, func() bool { + _, next, err := mw.BeforeModelRewriteState(ctx2, st, nil) + require.NoError(t, err) + st = next + last := st.Messages[len(st.Messages)-1] + return len(st.Messages) == 2 && last.Extra != nil && last.Extra["__eino_automemory__"] != nil + }, 2*time.Second, 10*time.Millisecond) +} + +type panicModel struct{} + +func (m *panicModel) Generate(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + panic("should not call model") +} + +func (m *panicModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("should not call model") +} + +func (m *panicModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type toolCallSelectionModel struct { + calls int32 +} + +func (m *toolCallSelectionModel) Generate(_ context.Context, _ []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.calls, 1) + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "select-1", + Type: "function", + Function: schema.FunctionCall{ + Name: topicSelectionToolName, + Arguments: `{"selected_memories":["debugging.md","hallucinated.md"]}`, + }, + }, + }), nil +} + +func (m *toolCallSelectionModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *toolCallSelectionModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +type extractionModel struct { + mu sync.Mutex + promptSeen []string + boundToolCalls [][]string + blockFirstRun chan struct{} + firstRunStarted chan struct{} + blockedOnce uint32 // atomic (0/1) + generateCallings int32 +} + +type countingBackend struct { + *InMemoryBackend + writeCalls int32 + mu sync.Mutex + paths []string +} + +func (b *countingBackend) Write(ctx context.Context, req *WriteRequest) error { + atomic.AddInt32(&b.writeCalls, 1) + b.mu.Lock() + b.paths = append(b.paths, req.FilePath) + b.mu.Unlock() + return b.InMemoryBackend.Write(ctx, req) +} + +func (m *extractionModel) Generate(_ context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) { + atomic.AddInt32(&m.generateCallings, 1) + promptIdx := findExtractionPromptIndex(input) + if promptIdx < 0 { + return nil, fmt.Errorf("missing extraction prompt") + } + + m.mu.Lock() + m.promptSeen = append(m.promptSeen, input[promptIdx].Content) + m.mu.Unlock() + + if hasToolMessageAfter(input, promptIdx) { + return schema.AssistantMessage("done", nil), nil + } + + if m.blockFirstRun != nil && atomic.SwapUint32(&m.blockedOnce, 1) == 0 { + if m.firstRunStarted != nil { + close(m.firstRunStarted) + } + <-m.blockFirstRun + } + + payload := lastBusinessUserBeforePrompt(input, promptIdx) + return schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "write-topic", + Type: "function", + Function: schema.FunctionCall{ + Name: "write_file", + Arguments: fmt.Sprintf(`{"file_path":"topic.md","content":%q}`, payload), + }, + }, + { + ID: "write-index", + Type: "function", + Function: schema.FunctionCall{ + Name: "write_file", + Arguments: `{"file_path":"MEMORY.md","content":"- [topic.md](topic.md)\n"}`, + }, + }, + }), nil +} + +func (m *extractionModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) { + msg, err := m.Generate(ctx, input) + if err != nil { + return nil, err + } + return schema.StreamReaderFromArray([]*schema.Message{msg}), nil +} + +func (m *extractionModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + names := make([]string, 0, len(tools)) + for _, ti := range tools { + if ti == nil { + continue + } + names = append(names, ti.Name) + } + m.mu.Lock() + m.boundToolCalls = append(m.boundToolCalls, names) + m.mu.Unlock() + return m, nil +} + +func findExtractionPromptIndex(input []*schema.Message) int { + for i := len(input) - 1; i >= 0; i-- { + if input[i] != nil && input[i].Role == schema.User && strings.Contains(input[i].Content, "memory extraction subagent") { + return i + } + } + return -1 +} + +func hasToolMessageAfter(input []*schema.Message, idx int) bool { + for i := idx + 1; i < len(input); i++ { + if input[i] != nil && input[i].Role == schema.Tool { + switch input[i].ToolName { + case "read_file", "glob", "write_file", "edit_file": + return true + default: + } + } + } + return false +} + +func lastBusinessUserBeforePrompt(input []*schema.Message, promptIdx int) string { + for i := promptIdx - 1; i >= 0; i-- { + if input[i] == nil || input[i].Role != schema.User { + continue + } + if strings.Contains(input[i].Content, "") { + continue + } + return input[i].Content + } + return "unknown" +} + +func TestMiddleware_TopicSelection_SmallCandidateSetBypassesModel(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + + b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md)\n- [patterns.md](patterns.md)\n", now) + b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now) + b.put("/mem/patterns.md", "---\ndescription: patterns\n---\nbody\n", now) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &panicModel{}, + Read: &ReadConfig{ + Mode: ReadModeSync, + TopicSelection: &TopicSelectionConfig{ + TopK: 5, + }, + }, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How to run tests?")}}, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Len(t, out.AgentInput.Messages, 2) + require.Contains(t, out.AgentInput.Messages[1].Content, "debugging.md") + require.Contains(t, out.AgentInput.Messages[1].Content, "patterns.md") +} + +func TestMiddleware_AfterAgent_SyncExtractionWritesMemoryFiles(t *testing.T) { + ctx := context.Background() + b := &countingBackend{InMemoryBackend: NewInMemoryBackend()} + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + var onErrStages []string + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + }, + OnError: func(ctx context.Context, stage string, err error) { + onErrStages = append(onErrStages, stage) + }, + }) + require.NoError(t, err) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember alpha"), + schema.AssistantMessage("ack", nil), + }, + } + + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: state.Messages, + ToolInfos: []*schema.ToolInfo{ + {Name: "tool_b"}, + {Name: "tool_a"}, + }, + }) + require.NoError(t, err) + require.Empty(t, onErrStages) + require.Equal(t, len(state.Messages), getWriteCursorFromMessages(state.Messages)) + require.GreaterOrEqual(t, atomic.LoadInt32(&extModel.generateCallings), int32(1)) + require.GreaterOrEqual(t, atomic.LoadInt32(&b.writeCalls), int32(1)) + b.mu.Lock() + paths := append([]string(nil), b.paths...) + b.mu.Unlock() + require.NotEmpty(t, paths) + require.Contains(t, paths, "/mem/topic.md") + require.Contains(t, paths, "/mem/MEMORY.md") + + mem, err := b.Read(ctx, &ReadRequest{FilePath: "/mem/MEMORY.md"}) + require.NoError(t, err) + require.Contains(t, mem.Content, "topic.md") + + topic, err := b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"}) + require.NoError(t, err) + require.Equal(t, "remember alpha", topic.Content) + + extModel.mu.Lock() + defer extModel.mu.Unlock() + require.NotEmpty(t, extModel.promptSeen) + require.Contains(t, extModel.promptSeen[0], "memory extraction subagent") + require.Contains(t, extModel.promptSeen[0], "Memory directory: /mem") +} + +func TestMiddleware_AfterAgent_SyncExtraction_IteratorHandlerCanDrain(t *testing.T) { + ctx := context.Background() + b := &countingBackend{InMemoryBackend: NewInMemoryBackend()} + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + var seen int32 + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + HandleExtractionIterator: func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error { + for { + ev, ok := iter.Next() + if !ok { + return nil + } + if ev == nil { + continue + } + atomic.AddInt32(&seen, 1) + if ev.Err != nil { + return ev.Err + } + } + }, + }, + }) + require.NoError(t, err) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember handler"), + schema.AssistantMessage("ack", nil), + }, + } + + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: state.Messages, + ToolInfos: []*schema.ToolInfo{ + {Name: "tool_1"}, + }, + }) + require.NoError(t, err) + require.Greater(t, atomic.LoadInt32(&seen), int32(0)) + + // Still writes memory files as usual (handler only changes event draining). + _, err = b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"}) + require.NoError(t, err) +} + +func TestMiddleware_AfterAgent_SkipsExtractionWhenMainAgentAlreadyWroteMemory(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + }, + }) + require.NoError(t, err) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember beta"), + schema.AssistantMessage("", []schema.ToolCall{ + { + ID: "call-1", + Type: "function", + Function: schema.FunctionCall{ + Name: "write_file", + Arguments: `{"file_path":"/mem/topic.md","content":"written by main agent"}`, + }, + }, + }), + schema.ToolMessage("ok", "call-1", schema.WithToolName("write_file")), + }, + } + + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{Messages: state.Messages}) + require.NoError(t, err) + require.Equal(t, len(state.Messages), getWriteCursorFromMessages(state.Messages)) + require.EqualValues(t, 0, atomic.LoadInt32(&extModel.generateCallings)) + + _, err = b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"}) + require.Error(t, err) +} + +func TestMiddleware_AfterAgent_AsyncExtractionKeepsLatestPendingSnapshot(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + blockCh := make(chan struct{}) + startedCh := make(chan struct{}) + extModel := &extractionModel{ + blockFirstRun: blockCh, + firstRunStarted: startedCh, + } + coord := &CoordinationConfig[*schema.Message]{ + SessionIDFunc: func(ctx context.Context, state *adk.ChatModelAgentState) (string, error) { + return "session-1", nil + }, + Coordinator: NewLocalCoordinator(), + LockTTL: time.Minute, + } + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeAsync, + Model: extModel, + }, + Coordination: coord, + }) + require.NoError(t, err) + + state1 := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember one"), + schema.AssistantMessage("ack1", nil), + }, + } + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: state1.Messages, + ToolInfos: []*schema.ToolInfo{ + {Name: "tool_one"}, + }, + }) + require.NoError(t, err) + + <-startedCh + + state2 := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember one"), + schema.AssistantMessage("ack1", nil), + schema.UserMessage("remember two"), + schema.AssistantMessage("ack2", nil), + }, + } + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: state2.Messages, + ToolInfos: []*schema.ToolInfo{ + {Name: "tool_one"}, + {Name: "tool_two"}, + }, + }) + require.NoError(t, err) + + close(blockCh) + + require.Eventually(t, func() bool { + topic, readErr := b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"}) + if readErr != nil || topic == nil || topic.Content != "remember two" { + return false + } + cursor, ok, cursorErr := coord.Coordinator.GetCursor(ctx, "session-1") + if cursorErr != nil || !ok { + return false + } + return cursor == len(state2.Messages) + }, 2*time.Second, 10*time.Millisecond) +} + +func TestMiddleware_BeforeAgent_InstructionIdempotent_NoTopicMemory(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "line1\nline2\n", now) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + // No topic selection model. + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}}, + } + + _, out1, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Contains(t, out1.Instruction, instructionMarker) + + // Call again with the already-injected instruction; should not duplicate. + _, out2, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: out1.Instruction, + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi again")}}, + }) + require.NoError(t, err) + require.Equal(t, 1, strings.Count(out2.Instruction, instructionMarker)) +} + +func TestMiddleware_BeforeAgent_SkipsWhenMessagesAlreadyContainMemory(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + }) + require.NoError(t, err) + + memMsg := newMemoryMessage[*schema.Message]("\npreloaded") + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi"), memMsg}}, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Equal(t, "base", out.Instruction) + require.Len(t, out.AgentInput.Messages, 2) +} + +func TestMiddleware_BeforeAgent_DistributedCursorSyncIntoMessageExtra(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + coord := &CoordinationConfig[*schema.Message]{ + SessionIDFunc: func(ctx context.Context, state *adk.ChatModelAgentState) (string, error) { + return "sess-cursor", nil + }, + Coordinator: NewLocalCoordinator(), + LockTTL: time.Minute, + } + require.NoError(t, coord.Coordinator.SetCursor(ctx, "sess-cursor", 5)) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Coordination: coord, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{ + schema.UserMessage("hi"), + schema.AssistantMessage("ack", nil), + }}, + } + + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + last := out.AgentInput.Messages[len(out.AgentInput.Messages)-1] + require.NotNil(t, last.Extra) + meta, ok := last.Extra[memoryExtraKey].(*memoryExtra) + require.True(t, ok) + require.Equal(t, "write_cursor", meta.Type) + require.EqualValues(t, 5, meta.Cursor) +} + +func TestMiddleware_TopicSelection_ToolCallParsingAndFiltering(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md)\n", now) + b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now) + b.put("/mem/other.md", "---\ndescription: other\n---\nbody\n", now.Add(-time.Hour)) + + selModel := &toolCallSelectionModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: selModel, + Read: &ReadConfig{ + Mode: ReadModeSync, + TopicSelection: &TopicSelectionConfig{ + TopK: 1, + }, + }, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How to debug?")}}, + } + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Len(t, out.AgentInput.Messages, 2) + mem := out.AgentInput.Messages[1] + require.Contains(t, mem.Content, "Contents of /mem/debugging.md") + require.NotContains(t, mem.Content, "hallucinated.md") + require.EqualValues(t, 1, atomic.LoadInt32(&selModel.calls)) +} + +func TestMiddleware_TopicSelection_AsyncProtectsMemoryMessageFromMutation(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md)\n", now) + b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, + Read: &ReadConfig{Mode: ReadModeAsync}, + }) + require.NoError(t, err) + + ctx2, _, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}}, + }) + require.NoError(t, err) + + st := &adk.ChatModelAgentState{Messages: []adk.Message{schema.UserMessage("hi")}} + + var expected string + require.Eventually(t, func() bool { + _, next, callErr := mw.BeforeModelRewriteState(ctx2, st, nil) + require.NoError(t, callErr) + st = next + if len(st.Messages) < 2 { + return false + } + expected = st.Messages[len(st.Messages)-1].Content + return strings.Contains(expected, "") + }, 2*time.Second, 10*time.Millisecond) + + // Mutate the memory message content. + st.Messages[len(st.Messages)-1].Content = "tampered" + _, next, err := mw.BeforeModelRewriteState(ctx2, st, nil) + require.NoError(t, err) + require.Equal(t, expected, next.Messages[len(next.Messages)-1].Content) + require.NotNil(t, next.Messages[len(next.Messages)-1].Extra[memoryExtraKey]) +} + +func TestMiddleware_AfterAgent_SyncExtraction_SkipIndexPrompt(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + SkipIndex: true, + }, + }) + require.NoError(t, err) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember gamma"), + schema.AssistantMessage("ack", nil), + }, + } + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{Messages: state.Messages}) + require.NoError(t, err) + + extModel.mu.Lock() + defer extModel.mu.Unlock() + require.NotEmpty(t, extModel.promptSeen) + require.NotContains(t, extModel.promptSeen[0], "Step 2") +} + +func TestMiddleware_AfterAgent_SyncExtraction_ChinesePrompt(t *testing.T) { + require.NoError(t, adk.SetLanguage(adk.LanguageChinese)) + defer func() { + require.NoError(t, adk.SetLanguage(adk.LanguageEnglish)) + }() + + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + }, + }) + require.NoError(t, err) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember chinese"), + schema.AssistantMessage("ack", nil), + }, + } + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{Messages: state.Messages}) + require.NoError(t, err) + + extModel.mu.Lock() + defer extModel.mu.Unlock() + require.NotEmpty(t, extModel.promptSeen) + require.Contains(t, extModel.promptSeen[0], "你现在扮演 memory extraction subagent") + require.Contains(t, extModel.promptSeen[0], "记忆目录:/mem") +} + +func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + oldwd, err := os.Getwd() + require.NoError(t, err) + require.NoError(t, os.Chdir(tmp)) + defer func() { + _ = os.Chdir(oldwd) + }() + + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + expectedDir, err := filepath.Abs(".") + require.NoError(t, err) + + extModel := &extractionModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: ".", + MemoryBackend: NewLocalBackend(), + Write: &WriteConfig{ + Mode: WriteModeSync, + Model: extModel, + }, + }) + require.NoError(t, err) + + state := &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: []adk.Message{ + schema.UserMessage("remember relative"), + schema.AssistantMessage("ack", nil), + }, + } + _, err = mw.AfterAgent(ctx, state) + require.NoError(t, err) + + extModel.mu.Lock() + require.NotEmpty(t, extModel.promptSeen) + require.Contains(t, extModel.promptSeen[0], "Memory directory: "+expectedDir) + extModel.mu.Unlock() + + raw, err := os.ReadFile(filepath.Join(expectedDir, "topic.md")) + require.NoError(t, err) + require.Equal(t, "remember relative", string(raw)) +} + +func TestFSBackend_ReadMissingFileReturnsContentInsteadOfError(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + + fs, err := newFSBackend(NewLocalBackend(), tmp) + require.NoError(t, err) + + content, err := fs.Read(ctx, &ReadRequest{FilePath: "missing.md"}) + require.NoError(t, err) + require.NotNil(t, content) + require.Contains(t, content.Content, "File not found:") + require.Contains(t, content.Content, filepath.Join(tmp, "missing.md")) +} + +func TestMiddleware_AfterAgent_AsyncSetsPendingSnapshotWhenLockHeld(t *testing.T) { + ctx := context.Background() + b := NewInMemoryBackend() + now := time.Now() + b.put("/mem/MEMORY.md", "", now) + + extModel := &extractionModel{} + coord := &CoordinationConfig[*schema.Message]{ + SessionIDFunc: func(ctx context.Context, state *adk.ChatModelAgentState) (string, error) { + return "sess-pending", nil + }, + Coordinator: NewLocalCoordinator(), + LockTTL: time.Minute, + } + // Hold the lock. + unlock, ok, err := coord.Coordinator.AcquireLock(ctx, "sess-pending", time.Minute) + require.NoError(t, err) + require.True(t, ok) + + mwI, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: b, + Write: &WriteConfig{ + Mode: WriteModeAsync, + Model: extModel, + }, + Coordination: coord, + }) + require.NoError(t, err) + mw := mwI.(*middleware[*schema.Message]) + + state := &adk.ChatModelAgentState{ + Messages: []adk.Message{ + schema.UserMessage("remember pending"), + schema.AssistantMessage("ack", nil), + }, + } + _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{ + Messages: state.Messages, + ToolInfos: []*schema.ToolInfo{ + {Name: "pending_tool"}, + }, + }) + require.NoError(t, err) + + pending, err := coord.Coordinator.PopPendingSnapshot(ctx, "sess-pending") + require.NoError(t, err) + require.NotNil(t, pending) + + // Release and drain manually to complete write synchronously in test. + require.NoError(t, unlock(ctx)) + unlock2, ok, err := coord.Coordinator.AcquireLock(ctx, "sess-pending", time.Minute) + require.NoError(t, err) + require.True(t, ok) + mw.runExtractionDrain(ctx, "sess-pending", unlock2, pending) + + topic, err := b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"}) + require.NoError(t, err) + require.Equal(t, "remember pending", topic.Content) +} diff --git a/adk/middlewares/automemory/backend.go b/adk/middlewares/automemory/backend.go new file mode 100644 index 000000000..bf58e4e8f --- /dev/null +++ b/adk/middlewares/automemory/backend.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "github.com/cloudwego/eino/adk/filesystem" + ainternal "github.com/cloudwego/eino/adk/middlewares/automemory/internal" +) + +// Backend is the only filesystem storage abstraction users need to implement +// for automemory and dream. +// +// It intentionally exposes only the capabilities required by memory loading and +// consolidation: Read, GlobInfo, Write, and Edit. +// +// LocalBackend and InMemoryBackend both implement this interface. +type Backend = ainternal.Backend + +type ReadRequest = filesystem.ReadRequest +type FileContent = filesystem.FileContent +type GlobInfoRequest = filesystem.GlobInfoRequest +type FileInfo = filesystem.FileInfo +type WriteRequest = filesystem.WriteRequest +type EditRequest = filesystem.EditRequest + +func newFSBackend(backend Backend, baseDir string) (ainternal.Backend, error) { + return ainternal.NewFSBackend(backend, ainternal.FSBackendConfig{ + BaseDir: baseDir, + NotFoundAsContent: true, + ErrorPrefix: "fs backend", + }) +} diff --git a/adk/middlewares/automemory/consts.go b/adk/middlewares/automemory/consts.go new file mode 100644 index 000000000..dbdc27efe --- /dev/null +++ b/adk/middlewares/automemory/consts.go @@ -0,0 +1,55 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +const ( + // CandidateGlobPattern matches topic files under the memory directory. + CandidateGlobPattern = "**/*.md" + + memoryIndexFileName = "MEMORY.md" + + defaultIndexMaxLines = 200 + defaultIndexMaxBytes = 4 * 1024 + + defaultCandidateLimit = 200 + defaultCandidatePreviewLine = 30 + + defaultTopicTopK = 5 + defaultTopicMaxLines = 200 + defaultTopicMaxBytes = 4 * 1024 + + defaultMemoryWriteMaxTurns = 5 + + topicSelectionToolName = "select_memories" +) + +// OnError stage constants. These values are stable identifiers used to report +// best-effort failures through Config.OnError. +const ( + OnErrorStageTopicSelectionSync = "topic_selection_sync" + OnErrorStageTopicSelectionAsync = "topic_selection_async" + OnErrorStageRenderInstruction = "render_instruction" + OnErrorStageResolveSessionID = "resolve_session_id" + OnErrorStageMemoryWriteSync = "memory_write_sync" + OnErrorStageSnapshotMarshal = "snapshot_marshal" + OnErrorStageAcquireExtractionLock = "acquire_extraction_lock" + OnErrorStageStashPendingSnapshot = "stash_pending_snapshot" + OnErrorStageReleaseExtractionLock = "release_extraction_lock" + OnErrorStageDecodePendingSnapshot = "decode_pending_snapshot" + OnErrorStageMemoryWriteAsync = "memory_write_async" + OnErrorStageLoadPendingSnapshot = "load_pending_snapshot" +) diff --git a/adk/middlewares/automemory/coordinator.go b/adk/middlewares/automemory/coordinator.go new file mode 100644 index 000000000..c784c50c4 --- /dev/null +++ b/adk/middlewares/automemory/coordinator.go @@ -0,0 +1,162 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/cloudwego/eino/adk" +) + +type SessionIDFunc[M adk.MessageType] func(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (string, error) + +// Coordinator abstracts distributed coordination for async memory extraction. +// A Redis-backed implementation can map these methods to SETNX + TTL and plain KV get/set. +type Coordinator interface { + // AcquireLock tries to acquire a lock for a given session. When ok==true, + // it returns an unlock function that must be called exactly once. + AcquireLock(ctx context.Context, sessionID string, ttl time.Duration) (unlock func(context.Context) error, ok bool, err error) + + // PopPendingSnapshot returns and deletes the pending snapshot for a session. + // If there is no pending snapshot, it returns (nil, nil). + PopPendingSnapshot(ctx context.Context, sessionID string) (*PendingSnapshot, error) + SetPendingSnapshot(ctx context.Context, sessionID string, snapshot *PendingSnapshot) error + + GetCursor(ctx context.Context, sessionID string) (cursor int, ok bool, err error) + SetCursor(ctx context.Context, sessionID string, cursor int) error +} + +type PendingSnapshot struct { + Cursor int `json:"cursor"` + Messages json.RawMessage `json:"messages"` + ToolInfos json.RawMessage `json:"tool_infos,omitempty"` +} + +type CoordinationConfig[M adk.MessageType] struct { + SessionIDFunc SessionIDFunc[M] + Coordinator Coordinator + LockTTL time.Duration +} + +// LocalCoordinator is the default in-process coordinator used in tests and single-instance deployments. +// For distributed deployments, provide a Coordinator backed by Redis or another shared KV. +type LocalCoordinator struct { + mu sync.Mutex + locks map[string]localLock + pending map[string]*PendingSnapshot + cursor map[string]int +} + +type localLock struct { + token string + expiry time.Time +} + +// NewLocalCoordinator returns the default in-process Coordinator implementation. +func NewLocalCoordinator() *LocalCoordinator { + return &LocalCoordinator{ + locks: map[string]localLock{}, + pending: map[string]*PendingSnapshot{}, + cursor: map[string]int{}, + } +} + +func (c *LocalCoordinator) AcquireLock(_ context.Context, sessionID string, ttl time.Duration) (func(context.Context) error, bool, error) { + c.mu.Lock() + defer c.mu.Unlock() + now := time.Now() + if l, ok := c.locks[sessionID]; ok && now.Before(l.expiry) { + return nil, false, nil + } + token := randToken() + c.locks[sessionID] = localLock{token: token, expiry: now.Add(ttl)} + return func(_ context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + l, ok := c.locks[sessionID] + if !ok { + return nil + } + if l.token != token { + return fmt.Errorf("lock token mismatch") + } + delete(c.locks, sessionID) + return nil + }, true, nil +} + +func (c *LocalCoordinator) PopPendingSnapshot(_ context.Context, sessionID string) (*PendingSnapshot, error) { + c.mu.Lock() + defer c.mu.Unlock() + s, ok := c.pending[sessionID] + if !ok || s == nil { + return nil, nil + } + cp := *s + if s.Messages != nil { + cp.Messages = append([]byte(nil), s.Messages...) + } + if s.ToolInfos != nil { + cp.ToolInfos = append([]byte(nil), s.ToolInfos...) + } + delete(c.pending, sessionID) + return &cp, nil +} + +func (c *LocalCoordinator) SetPendingSnapshot(_ context.Context, sessionID string, snapshot *PendingSnapshot) error { + c.mu.Lock() + defer c.mu.Unlock() + if snapshot == nil { + delete(c.pending, sessionID) + return nil + } + cp := *snapshot + if snapshot.Messages != nil { + cp.Messages = append([]byte(nil), snapshot.Messages...) + } + if snapshot.ToolInfos != nil { + cp.ToolInfos = append([]byte(nil), snapshot.ToolInfos...) + } + c.pending[sessionID] = &cp + return nil +} + +func (c *LocalCoordinator) GetCursor(_ context.Context, sessionID string) (int, bool, error) { + c.mu.Lock() + defer c.mu.Unlock() + v, ok := c.cursor[sessionID] + return v, ok, nil +} + +func (c *LocalCoordinator) SetCursor(_ context.Context, sessionID string, cursor int) error { + c.mu.Lock() + defer c.mu.Unlock() + c.cursor[sessionID] = cursor + return nil +} + +func randToken() string { + var b [8]byte + _, _ = rand.Read(b[:]) + return hex.EncodeToString(b[:]) +} diff --git a/adk/middlewares/automemory/dream/config.go b/adk/middlewares/automemory/dream/config.go new file mode 100644 index 000000000..193270363 --- /dev/null +++ b/adk/middlewares/automemory/dream/config.go @@ -0,0 +1,181 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package dream provides scheduled consolidation middleware built on top of +// automemory-managed session files. +package dream + +import ( + "context" + "fmt" + "time" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/middlewares/automemory" + "github.com/cloudwego/eino/components/model" +) + +const ( + defaultSessionKey = "__eino_automemory_dream_session_id__" + defaultMinInterval = 24 * time.Hour + defaultMinTouchedSession = 5 + defaultScanInterval = 10 * time.Minute + defaultLockTTL = time.Hour +) + +// OnError handles non-fatal dream errors. +// Optional. Nil means ignore the error. +type OnError func(ctx context.Context, stage string, err error) + +// HandleIterator handles the dream sub-agent event stream. +// Optional. Nil means dream drains the iterator itself. +type HandleIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error + +// Config configures auto dream for both `New(...)` and `Run(...)`. +type Config[M adk.MessageType] struct { + // MemoryDirectory is the memory root directory. + // Required. Relative paths are resolved during init. + MemoryDirectory string + + // MemoryBackend reads and updates memory files. + // Required. + MemoryBackend automemory.Backend + + // Model is the tool-calling model used by the internal dream agent. + // Required. + Model model.ToolCallingChatModel + + // SessionIDFunc resolves the current session ID. + // Optional. Default: a generated session-scoped ID. + SessionIDFunc automemory.SessionIDFunc[M] + + // OnError handles non-fatal runtime errors. + // Optional. Default: nil. + OnError OnError + + // SessionStore enables session timeline lookup through `grep_session_history`. + // Optional. Default: nil. + // + // When nil, dream consolidates using only memory files plus scheduler-provided + // touch signals; no session-history search tool is exposed to the model. + // + // When set, dream exposes `grep_session_history` for the sessions included in + // the current run scope: + // - middleware-triggered runs search the touched sessions selected by the scheduler + // - manual `Run(...)` searches the provided/current session only + SessionStore *SessionStoreProvider[M] + + // Schedule controls middleware-triggered runs only. + // Optional. `Run(...)` ignores it. + Schedule *ScheduleConfig + + // HandleIterator overrides iterator consumption. + // Optional. Default: nil. + HandleIterator HandleIterator +} + +// ScheduleConfig controls middleware-triggered runs. +type ScheduleConfig struct { + // MinInterval is the minimum interval between successful runs. + // Optional. Default: 24h. + MinInterval time.Duration + + // MinTouchedSession is the minimum touched-session count before a run. + // Optional. Default: 5. + MinTouchedSession int + + // ScanInterval is the retry delay when the session threshold is not met. + // Optional. Default: 10m. + ScanInterval time.Duration + + // LockTTL is the lease for the per-memory-directory run lock. + // Optional. Default: 1h. + LockTTL time.Duration + + // Store persists touched sessions, schedule state, and run locks. + // Optional. Default: in-process `LocalStore`. + Store Store + + // RunInline runs triggered dreams in the `AfterAgent` call path. + // Optional. Default: false. + RunInline bool +} + +func applyCoreDefaults[M adk.MessageType](cfg *Config[M]) error { + if cfg == nil { + return fmt.Errorf("auto dream config: nil") + } + if cfg.MemoryDirectory == "" || cfg.MemoryBackend == nil || cfg.Model == nil { + return fmt.Errorf("auto dream config: invalid") + } + if cfg.SessionIDFunc == nil { + cfg.SessionIDFunc = defaultSessionIDFunc[M] + } + return nil +} + +func cloneConfig[M adk.MessageType](cfg *Config[M]) *Config[M] { + if cfg == nil { + return nil + } + + cp := *cfg + if cfg.SessionStore != nil { + sessionStoreCopy := *cfg.SessionStore + cp.SessionStore = &sessionStoreCopy + } + if cfg.Schedule != nil { + scheduleCopy := *cfg.Schedule + cp.Schedule = &scheduleCopy + } + return &cp +} + +func applyScheduleDefaults[M adk.MessageType](cfg *Config[M]) error { + if err := applyCoreDefaults(cfg); err != nil { + return err + } + if cfg.Schedule == nil { + cfg.Schedule = &ScheduleConfig{} + } + if cfg.Schedule.MinInterval <= 0 { + cfg.Schedule.MinInterval = defaultMinInterval + } + if cfg.Schedule.MinTouchedSession <= 0 { + cfg.Schedule.MinTouchedSession = defaultMinTouchedSession + } + if cfg.Schedule.ScanInterval <= 0 { + cfg.Schedule.ScanInterval = defaultScanInterval + } + if cfg.Schedule.LockTTL <= 0 { + cfg.Schedule.LockTTL = defaultLockTTL + } + if cfg.Schedule.Store == nil { + cfg.Schedule.Store = NewLocalStore() + } + return nil +} + +func defaultSessionIDFunc[M adk.MessageType](ctx context.Context, _ *adk.TypedChatModelAgentState[M]) (string, error) { + if v, ok := adk.GetSessionValue(ctx, defaultSessionKey); ok { + if s, ok := v.(string); ok && s != "" { + return s, nil + } + } + s := fmt.Sprintf("dream-%d", time.Now().UnixNano()) + adk.AddSessionValue(ctx, defaultSessionKey, s) + return s, nil +} diff --git a/adk/middlewares/automemory/dream/dream.go b/adk/middlewares/automemory/dream/dream.go new file mode 100644 index 000000000..e05645a72 --- /dev/null +++ b/adk/middlewares/automemory/dream/dream.go @@ -0,0 +1,258 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/cloudwego/eino/adk" + ainternal "github.com/cloudwego/eino/adk/middlewares/automemory/internal" + fsmw "github.com/cloudwego/eino/adk/middlewares/filesystem" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +const ( + stageResolveSessionID = "resolve_session_id" + stageRecordTouch = "record_touch" + stageRunDream = "run_dream" +) + +type middleware[M adk.MessageType] struct { + adk.TypedBaseChatModelAgentMiddleware[M] + + cfg *Config[M] + resolvedMemoryDir string + fsHandler adk.ChatModelAgentMiddleware + sessionSearchTool tool.BaseTool + now func() time.Time +} + +// New creates middleware that triggers dream automatically after agent runs. +func New[M adk.MessageType](ctx context.Context, cfg *Config[M]) (adk.TypedChatModelAgentMiddleware[M], error) { + cfg = cloneConfig(cfg) + if err := applyScheduleDefaults(cfg); err != nil { + return nil, err + } + return newMiddleware(ctx, cfg) +} + +// Run executes a dream immediately, without schedule gating or locking. +func Run[M adk.MessageType](ctx context.Context, cfg *Config[M], req *RunRequest) error { + cfg = cloneConfig(cfg) + if err := applyCoreDefaults(cfg); err != nil { + return err + } + m, err := newMiddleware(ctx, cfg) + if err != nil { + return err + } + if req == nil { + req = &RunRequest{} + } + sessionID := strings.TrimSpace(req.SessionID) + if sessionID == "" { + sessionID, err = cfg.SessionIDFunc(ctx, nil) + if err != nil { + m.onErr(ctx, stageResolveSessionID, err) + return err + } + } + return m.runDream(ctx, sessionID, nil) +} + +type RunRequest struct { + // SessionID identifies the current session. + // Optional. When empty, `SessionIDFunc` is used. + SessionID string +} + +func newMiddleware[M adk.MessageType](ctx context.Context, cfg *Config[M]) (*middleware[M], error) { + resolvedMemoryDir, err := ainternal.ResolveMemoryDir(cfg.MemoryDirectory) + if err != nil { + return nil, fmt.Errorf("auto dream config: resolve memory dir: %w", err) + } + writeFSBackend, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{ + BaseDir: resolvedMemoryDir, + AllowLs: true, + NotFoundAsContent: true, + ErrorPrefix: "dream fs backend", + }) + if err != nil { + return nil, err + } + fsHandler, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{ + Backend: writeFSBackend, + GrepToolConfig: &fsmw.ToolConfig{Disable: true}, + }) + if err != nil { + return nil, err + } + var sessionSearchTool tool.BaseTool + if cfg.SessionStore != nil { + sessionSearchTool, err = newSessionHistoryGrepTool[M](cfg.SessionStore) + if err != nil { + return nil, err + } + } + m := &middleware[M]{ + TypedBaseChatModelAgentMiddleware: adk.TypedBaseChatModelAgentMiddleware[M]{}, + cfg: cfg, + resolvedMemoryDir: resolvedMemoryDir, + fsHandler: fsHandler, + sessionSearchTool: sessionSearchTool, + now: time.Now, + } + return m, nil +} + +func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (context.Context, error) { + if m == nil || m.cfg == nil || m.cfg.Schedule == nil { + return ctx, nil + } + sessionID, err := m.cfg.SessionIDFunc(ctx, state) + if err != nil { + m.onErr(ctx, stageResolveSessionID, err) + return ctx, nil + } + now := m.now() + if err := m.cfg.Schedule.Store.RecordSessionTouch(ctx, m.resolvedMemoryDir, sessionID, now); err != nil { + m.onErr(ctx, stageRecordTouch, err) + return ctx, nil + } + if err := m.maybeTrigger(ctx, sessionID, true); err != nil { + m.onErr(ctx, stageRunDream, err) + } + return ctx, nil +} + +func (m *middleware[M]) maybeTrigger(ctx context.Context, currentSessionID string, excludeCurrent bool) error { + st, err := m.cfg.Schedule.Store.GetScheduleState(ctx, m.resolvedMemoryDir) + if err != nil { + return err + } + if st == nil { + st = &ScheduleState{} + } + now := m.now() + if st.NextCheckAt.After(now) { + return nil + } + since := st.LastConsolidatedAt + if !since.IsZero() && now.Sub(since) < m.cfg.Schedule.MinInterval { + st.NextCheckAt = st.LastConsolidatedAt.Add(m.cfg.Schedule.MinInterval) + return m.cfg.Schedule.Store.SetScheduleState(ctx, m.resolvedMemoryDir, st) + } + touchedSessions, err := m.cfg.Schedule.Store.ListSessionsTouchedSince(ctx, m.resolvedMemoryDir, since) + if err != nil { + return err + } + filtered := touchedSessions[:0] + for _, sessionID := range touchedSessions { + if excludeCurrent && currentSessionID != "" && sessionID == currentSessionID { + continue + } + filtered = append(filtered, sessionID) + } + if len(filtered) < m.cfg.Schedule.MinTouchedSession { + st.NextCheckAt = now.Add(m.cfg.Schedule.ScanInterval) + return m.cfg.Schedule.Store.SetScheduleState(ctx, m.resolvedMemoryDir, st) + } + unlock, ok, err := m.cfg.Schedule.Store.AcquireRunLock(ctx, m.resolvedMemoryDir, m.cfg.Schedule.LockTTL) + if err != nil || !ok { + return err + } + runFn := func() { + defer func() { _ = unlock(context.Background()) }() + if err := m.runDream(context.Background(), currentSessionID, filtered); err != nil { + m.onErr(context.Background(), stageRunDream, err) + st.NextCheckAt = m.now().Add(m.cfg.Schedule.ScanInterval) + _ = m.cfg.Schedule.Store.SetScheduleState(context.Background(), m.resolvedMemoryDir, st) + return + } + st.LastConsolidatedAt = m.now() + st.NextCheckAt = st.LastConsolidatedAt.Add(m.cfg.Schedule.MinInterval) + _ = m.cfg.Schedule.Store.SetScheduleState(context.Background(), m.resolvedMemoryDir, st) + } + if m.cfg.Schedule.RunInline { + runFn() + return nil + } + go runFn() + return nil +} + +func (m *middleware[M]) runDream(ctx context.Context, sessionID string, touchedSessions []string) error { + agent, err := m.newDreamAgent(ctx) + if err != nil { + return err + } + prompt := buildConsolidationPrompt(m.resolvedMemoryDir, touchedSessions, m.cfg.SessionStore != nil) + searchSessionIDs := touchedSessions + if len(searchSessionIDs) == 0 && sessionID != "" { + searchSessionIDs = []string{sessionID} + } + runCtx := withDreamRunMeta(ctx, &dreamRunMeta{ + MemoryDirectory: m.resolvedMemoryDir, + SessionID: sessionID, + SearchSessionIDs: append([]string(nil), searchSessionIDs...), + }) + iter := agent.Run(runCtx, &adk.AgentInput{Messages: []adk.Message{schema.UserMessage(prompt)}}) + if m.cfg.HandleIterator != nil { + return m.cfg.HandleIterator(runCtx, iter) + } + for { + ev, ok := iter.Next() + if !ok { + break + } + if ev.Err != nil { + return ev.Err + } + } + return nil +} + +func (m *middleware[M]) newDreamAgent(ctx context.Context) (*adk.ChatModelAgent, error) { + tools := make([]tool.BaseTool, 0, 1) + if m.sessionSearchTool != nil { + tools = append(tools, m.sessionSearchTool) + } + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "automemory_dream", + Description: "Internal auto dream consolidation agent", + Model: m.cfg.Model, + Handlers: []adk.ChatModelAgentMiddleware{m.fsHandler}, + ToolsConfig: adk.ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: tools}}, + MaxIterations: 12, + }) + if err != nil { + return nil, fmt.Errorf("auto dream create agent: %w", err) + } + return agent, nil +} + +func (m *middleware[M]) onErr(ctx context.Context, stage string, err error) { + if err == nil || m == nil || m.cfg == nil || m.cfg.OnError == nil { + return + } + m.cfg.OnError(ctx, stage, err) +} diff --git a/adk/middlewares/automemory/dream/dream_test.go b/adk/middlewares/automemory/dream/dream_test.go new file mode 100644 index 000000000..b3aae2393 --- /dev/null +++ b/adk/middlewares/automemory/dream/dream_test.go @@ -0,0 +1,434 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/middlewares/automemory" + adksession "github.com/cloudwego/eino/adk/session" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" +) + +type dreamModel struct { + mu sync.Mutex + prompts []string + toolNames [][]string + calls int32 +} + +func (m *dreamModel) BindTools(tools []*schema.ToolInfo) []string { + names := make([]string, 0, len(tools)) + for _, ti := range tools { + if ti != nil { + names = append(names, ti.Name) + } + } + return names +} + +func (m *dreamModel) Generate(_ context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + callCount := atomic.AddInt32(&m.calls, 1) + toolList := model.GetCommonOptions(nil, opts...).Tools + m.mu.Lock() + m.toolNames = append(m.toolNames, m.BindTools(toolList)) + for _, msg := range input { + if msg.Role == schema.User { + m.prompts = append(m.prompts, messageText(msg)) + } + } + m.mu.Unlock() + content := "" + for _, msg := range input { + if msg.Role == schema.User { + content = messageText(msg) + } + } + if callCount > 1 { + return schema.AssistantMessage("dream complete", nil), nil + } + calls := []schema.ToolCall{ + {ID: "1", Function: schema.FunctionCall{Name: "read_file", Arguments: `{"file_path":"MEMORY.md"}`}}, + {ID: "2", Function: schema.FunctionCall{Name: "write_file", Arguments: `{"file_path":"dream.md","content":"consolidated"}`}}, + {ID: "3", Function: schema.FunctionCall{Name: "write_file", Arguments: `{"file_path":"MEMORY.md","content":"- [Dream](dream.md) - consolidated"}`}}, + } + if strings.Contains(content, "Optional session search") { + calls = append([]schema.ToolCall{{ID: "0", Function: schema.FunctionCall{Name: "grep_session_history", Arguments: `{"query":"build failure"}`}}}, calls...) + } + return schema.AssistantMessage("dream", calls), nil +} + +func (m *dreamModel) Stream(context.Context, []*schema.Message, ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +func (m *dreamModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + m.mu.Lock() + m.toolNames = append(m.toolNames, m.BindTools(tools)) + m.mu.Unlock() + return m, nil +} + +func messageText(msg *schema.Message) string { + if msg == nil { + return "" + } + return msg.Content +} + +type mainAgentModel struct { + reply string +} + +func (m *mainAgentModel) Generate(context.Context, []*schema.Message, ...model.Option) (*schema.Message, error) { + return schema.AssistantMessage(m.reply, nil), nil +} + +func (m *mainAgentModel) Stream(context.Context, []*schema.Message, ...model.Option) (*schema.StreamReader[*schema.Message], error) { + panic("not implemented") +} + +func (m *mainAgentModel) WithTools([]*schema.ToolInfo) (model.ToolCallingChatModel, error) { + return m, nil +} + +func drainIterator(t *testing.T, iter *adk.AsyncIterator[*adk.AgentEvent]) []*adk.AgentEvent { + t.Helper() + var out []*adk.AgentEvent + for { + ev, ok := iter.Next() + if !ok { + return out + } + out = append(out, ev) + if ev != nil && ev.Err != nil { + return out + } + } +} + +type countingSessionStore struct { + adk.SessionEventStore[*schema.Message] + loadCalls int32 +} + +func (s *countingSessionStore) LoadEvents(ctx context.Context, req *adk.LoadSessionEventsRequest) (*adk.LoadSessionEventsResult[*schema.Message], error) { + atomic.AddInt32(&s.loadCalls, 1) + return s.SessionEventStore.LoadEvents(ctx, req) +} + +type nilStateStore struct { + Store +} + +func (s *nilStateStore) GetScheduleState(context.Context, string) (*ScheduleState, error) { + return nil, nil +} + +func TestBuildConsolidationPrompt_OmitsSessionSearchSectionWhenProviderMissing(t *testing.T) { + prompt := buildConsolidationPrompt("/mem", []string{"a", "b"}, false) + require.NotContains(t, prompt, "Optional session search") + require.Contains(t, prompt, "Sessions since last consolidation (2)") +} + +func TestBuildConsolidationPrompt_Chinese(t *testing.T) { + require.NoError(t, adk.SetLanguage(adk.LanguageChinese)) + defer func() { + require.NoError(t, adk.SetLanguage(adk.LanguageEnglish)) + }() + + prompt := buildConsolidationPrompt("/mem", []string{"a", "b"}, true) + require.Contains(t, prompt, "## 可选的 session 搜索") + require.Contains(t, prompt, "它只会搜索本次 dream 运行范围内包含的 session 历史") + require.Contains(t, prompt, "自上次 consolidation 以来触达过的 sessions(2)") +} + +func TestNew_DoesNotMutateConfig(t *testing.T) { + ctx := context.Background() + cfg := &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: automemory.NewInMemoryBackend(), + Model: &dreamModel{}, + SessionStore: &SessionStoreProvider[*schema.Message]{ + SessionStore: adksession.NewInMemoryStore[*schema.Message](nil), + Serializer: &schema.HumanReadableSerializer{}, + }, + Schedule: &ScheduleConfig{}, + } + + _, err := New(ctx, cfg) + require.NoError(t, err) + require.Nil(t, cfg.SessionIDFunc) + require.Zero(t, cfg.Schedule.MinInterval) + require.Zero(t, cfg.Schedule.MinTouchedSession) + require.Zero(t, cfg.Schedule.ScanInterval) + require.Zero(t, cfg.Schedule.LockTTL) + require.Nil(t, cfg.Schedule.Store) +} + +func TestMiddleware_AfterAgent_RunInlineWithSessionStore(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("- [Existing](existing.md) - old"), 0o644)) + store := NewLocalStore() + model := &dreamModel{} + eventStore := &countingSessionStore{SessionEventStore: adksession.NewInMemoryStore[*schema.Message](nil)} + serializer := &schema.HumanReadableSerializer{} + payload, err := serializer.Marshal(&adk.SessionEvent[*schema.Message]{ + EventID: "e1", + Kind: adk.SessionEventMessage, + Message: schema.AssistantMessage("build failure: missing dependency", nil), + }) + require.NoError(t, err) + _, err = eventStore.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ + SessionID: "session-a", + Events: []*adk.SessionEvent[*schema.Message]{{ + EventID: "e1", + Kind: adk.SessionEventMessage, + Message: schema.AssistantMessage("build failure: missing dependency", nil), + }}, + }) + require.NoError(t, err) + _ = payload + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: model, + SessionStore: &SessionStoreProvider[*schema.Message]{ + SessionStore: eventStore, + Serializer: serializer, + }, + Schedule: &ScheduleConfig{ + RunInline: true, + Store: store, + MinInterval: time.Hour, + MinTouchedSession: 1, + ScanInterval: time.Minute, + }, + }) + require.NoError(t, err) + impl, ok := mw.(*middleware[*schema.Message]) + require.True(t, ok) + now := time.Now() + impl.now = func() time.Time { return now } + require.NoError(t, store.SetScheduleState(ctx, tmp, &ScheduleState{LastConsolidatedAt: now.Add(-2 * time.Hour), NextCheckAt: now})) + require.NoError(t, store.RecordSessionTouch(ctx, tmp, "session-a", now.Add(-30*time.Minute))) + + _, err = impl.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{}) + require.NoError(t, err) + + raw, err := os.ReadFile(filepath.Join(tmp, "dream.md")) + require.NoError(t, err) + require.Equal(t, "consolidated", string(raw)) + require.GreaterOrEqual(t, atomic.LoadInt32(&eventStore.loadCalls), int32(1)) + model.mu.Lock() + defer model.mu.Unlock() + require.NotEmpty(t, model.prompts) + require.Contains(t, model.prompts[0], "Optional session search") +} + +func TestMiddleware_AfterAgent_FirstEligibleTouchCanTriggerImmediately(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + store := NewLocalStore() + model := &dreamModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: model, + Schedule: &ScheduleConfig{ + RunInline: true, + Store: store, + MinInterval: time.Hour, + MinTouchedSession: 1, + ScanInterval: time.Minute, + }, + }) + require.NoError(t, err) + impl, ok := mw.(*middleware[*schema.Message]) + require.True(t, ok) + now := time.Now() + impl.now = func() time.Time { return now } + require.NoError(t, store.RecordSessionTouch(ctx, tmp, "older-session", now.Add(-2*time.Minute))) + + _, err = impl.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{}) + require.NoError(t, err) + + raw, err := os.ReadFile(filepath.Join(tmp, "dream.md")) + require.NoError(t, err) + require.Equal(t, "consolidated", string(raw)) +} + +func TestMiddleware_AfterAgent_NilScheduleStateDoesNotPanic(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + + baseStore := NewLocalStore() + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: &dreamModel{}, + Schedule: &ScheduleConfig{ + RunInline: true, + Store: &nilStateStore{Store: baseStore}, + MinInterval: time.Hour, + MinTouchedSession: 1, + ScanInterval: time.Minute, + }, + }) + require.NoError(t, err) + + _, err = mw.(*middleware[*schema.Message]).AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{}) + require.NoError(t, err) +} + +func TestRun_ManualDreamWithoutSchedule(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + model := &dreamModel{} + + err := Run(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: model, + }, &RunRequest{ + SessionID: "manual-session", + }) + require.NoError(t, err) + + raw, err := os.ReadFile(filepath.Join(tmp, "dream.md")) + require.NoError(t, err) + require.Equal(t, "consolidated", string(raw)) + model.mu.Lock() + defer model.mu.Unlock() + require.NotEmpty(t, model.prompts) + require.NotContains(t, model.prompts[0], "Sessions since last consolidation") + require.NotContains(t, model.prompts[0], "Optional session search") +} + +func TestIntegration_UserPerspective_AgentMiddlewareAutoDream(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("- [Existing](existing.md) - old\n"), 0o644)) + + store := NewLocalStore() + dreamModel := &dreamModel{} + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: dreamModel, + Schedule: &ScheduleConfig{ + RunInline: true, + Store: store, + MinInterval: time.Hour, + MinTouchedSession: 1, + ScanInterval: time.Minute, + }, + }) + require.NoError(t, err) + impl, ok := mw.(*middleware[*schema.Message]) + require.True(t, ok) + now := time.Now() + impl.now = func() time.Time { return now } + require.NoError(t, store.RecordSessionTouch(ctx, tmp, "older-session", now.Add(-3*time.Minute))) + + agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "main_agent", + Description: "main agent for dream integration test", + Model: &mainAgentModel{reply: "main answer"}, + Handlers: []adk.ChatModelAgentMiddleware{mw}, + }) + require.NoError(t, err) + + events := drainIterator(t, agent.Run(ctx, &adk.AgentInput{ + Messages: []adk.Message{schema.UserMessage("please help")}, + })) + require.NotEmpty(t, events) + last := events[len(events)-1] + require.NotNil(t, last) + require.Nil(t, last.Err) + require.NotNil(t, last.Output) + require.Equal(t, "main answer", last.Output.MessageOutput.Message.Content) + + raw, err := os.ReadFile(filepath.Join(tmp, "dream.md")) + require.NoError(t, err) + require.Equal(t, "consolidated", string(raw)) + index, err := os.ReadFile(filepath.Join(tmp, "MEMORY.md")) + require.NoError(t, err) + require.Contains(t, string(index), "dream.md") +} + +func TestIntegration_UserPerspective_RunReturnsCallbackErrors(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + + expected := fmt.Errorf("resolve session failed") + var onErrStages []string + err := Run(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: &dreamModel{}, + SessionIDFunc: func(context.Context, *adk.TypedChatModelAgentState[*schema.Message]) (string, error) { + return "", expected + }, + OnError: func(_ context.Context, stage string, err error) { + onErrStages = append(onErrStages, stage+":"+err.Error()) + }, + }, nil) + require.ErrorIs(t, err, expected) + require.Equal(t, []string{stageResolveSessionID + ":" + expected.Error()}, onErrStages) +} + +func TestIntegration_UserPerspective_RunFallsBackToSessionIDFuncWithoutState(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte(""), 0o644)) + + model := &dreamModel{} + var resolvedState *adk.TypedChatModelAgentState[*schema.Message] + err := Run(ctx, &Config[*schema.Message]{ + MemoryDirectory: tmp, + MemoryBackend: automemory.NewLocalBackend(), + Model: model, + SessionIDFunc: func(_ context.Context, state *adk.TypedChatModelAgentState[*schema.Message]) (string, error) { + resolvedState = state + return "fallback-session", nil + }, + }, nil) + require.NoError(t, err) + require.Nil(t, resolvedState) + + raw, err := os.ReadFile(filepath.Join(tmp, "dream.md")) + require.NoError(t, err) + require.Equal(t, "consolidated", string(raw)) +} diff --git a/adk/middlewares/automemory/dream/prompt.go b/adk/middlewares/automemory/dream/prompt.go new file mode 100644 index 000000000..58f357bbc --- /dev/null +++ b/adk/middlewares/automemory/dream/prompt.go @@ -0,0 +1,137 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "fmt" + "strings" + + "github.com/cloudwego/eino/adk/internal" +) + +func buildConsolidationPrompt(memoryRoot string, touchedSessions []string, includeSessionSearch bool) string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: buildConsolidationPromptEnglish(memoryRoot, touchedSessions, includeSessionSearch), + Chinese: buildConsolidationPromptChinese(memoryRoot, touchedSessions, includeSessionSearch), + }) +} + +func buildConsolidationPromptEnglish(memoryRoot string, touchedSessions []string, includeSessionSearch bool) string { + extra := "" + if len(touchedSessions) > 0 { + extra = fmt.Sprintf("\n\nSessions since last consolidation (%d):\n%s", len(touchedSessions), bulletList(touchedSessions)) + } + sessionSearchSection := "" + if includeSessionSearch { + sessionSearchSection = ` + +## Optional session search + +- Use grep_session_history with narrow terms when you already suspect something matters +- It searches only the session histories included in this dream run +- Do not exhaustively scan session history; use it only to confirm details` + } + return fmt.Sprintf(`# Dream: Memory Consolidation + +You are performing a dream: a reflective pass over persistent memory files. Synthesize what was learned recently into durable, well-organized memory so future sessions can orient quickly. + +Memory directory: %s + +## Phase 1 - Orient +- Use ls/glob to inspect the memory directory +- Read MEMORY.md first to understand the current index +- Skim existing topic files before creating new ones so you improve or merge instead of duplicating%s + +## Phase 2 - Gather signal +- Focus on durable information that has emerged across recent sessions +- Prefer updating an existing topic file over creating a near-duplicate +- Convert relative time references into absolute dates when they matter +- Remove or correct stale facts at the source + +## Phase 3 - Consolidate +- Keep each memory file focused on one topic +- Use read_file before write_file/edit_file for every file you plan to touch +- Write only inside the memory directory +- Do not investigate the codebase outside memory files and the current session history during this run + +## Phase 4 - Prune and index +- Keep MEMORY.md concise; it is an index, not the full memory body +- Ensure new or updated topic files are reflected in MEMORY.md +- Remove stale or superseded pointers from MEMORY.md + +Return a brief summary of what you consolidated, updated, or pruned. If nothing changed, say so.%s`, memoryRoot, sessionSearchSection, extra) +} + +func buildConsolidationPromptChinese(memoryRoot string, touchedSessions []string, includeSessionSearch bool) string { + extra := "" + if len(touchedSessions) > 0 { + extra = fmt.Sprintf("\n\n自上次 consolidation 以来触达过的 sessions(%d):\n%s", len(touchedSessions), bulletList(touchedSessions)) + } + sessionSearchSection := "" + if includeSessionSearch { + sessionSearchSection = ` + +## 可选的 session 搜索 + +- 当你已经怀疑某条信息重要时,再用 grep_session_history 做精确搜索 +- 它只会搜索本次 dream 运行范围内包含的 session 历史 +- 不要穷举扫描 session 历史,只在需要核实细节时使用` + } + return fmt.Sprintf(`# Dream:记忆整理 + +你正在执行一次 dream:对持久化记忆文件做反思式整理。请把最近学到的内容沉淀成稳定、清晰且结构化的长期记忆,帮助未来会话快速建立上下文。 + +记忆目录:%s + +## 阶段 1 - 建立整体认识 +- 使用 ls/glob 查看记忆目录 +- 先阅读 MEMORY.md,理解当前索引结构 +- 在创建新主题文件前,先浏览现有主题文件,优先改进或合并,而不是重复创建%s + +## 阶段 2 - 收集有效信号 +- 关注在最近多个 session 中沉淀下来的长期有效信息 +- 优先更新已有主题文件,而不是创建内容接近的重复文件 +- 当相对时间表述会影响理解时,将其转换为绝对日期 +- 在源头处删除或修正过时事实 + +## 阶段 3 - 整理与归并 +- 让每个记忆文件只聚焦一个主题 +- 对每个计划修改的文件,都先 read_file,再 write_file/edit_file +- 只在记忆目录内写入 +- 本次运行中,不要调查记忆文件与当前 session 历史之外的代码库内容 + +## 阶段 4 - 修剪与更新索引 +- 保持 MEMORY.md 简洁;它是索引,不是完整记忆正文 +- 确保新增或更新过的主题文件都同步反映到 MEMORY.md +- 从 MEMORY.md 中移除陈旧或已被替代的索引项 + +请简要总结你本次 consolidation、更新或修剪了什么;如果没有任何变更,也请明确说明。%s`, memoryRoot, sessionSearchSection, extra) +} + +func bulletList(items []string) string { + if len(items) == 0 { + return "" + } + var b strings.Builder + b.WriteString("- ") + b.WriteString(items[0]) + for i := 1; i < len(items); i++ { + b.WriteString("\n- ") + b.WriteString(items[i]) + } + return b.String() +} diff --git a/adk/middlewares/automemory/dream/session.go b/adk/middlewares/automemory/dream/session.go new file mode 100644 index 000000000..465b0e7a9 --- /dev/null +++ b/adk/middlewares/automemory/dream/session.go @@ -0,0 +1,193 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "context" + "fmt" + "strings" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/adk/internal" + "github.com/cloudwego/eino/components/tool" + toolutils "github.com/cloudwego/eino/components/tool/utils" + "github.com/cloudwego/eino/schema" +) + +// SessionStoreProvider supplies the session timeline dependencies required by +// dream's optional session-history search tool. +// +// When configured, dream exposes a narrow grep-like tool that scans persisted +// message events for the current session only. +type SessionStoreProvider[M adk.MessageType] struct { + // SessionStore loads persisted session events for the current session. + // Required when SessionStoreProvider is configured. + SessionStore adk.SessionEventStore[M] + // Serializer decodes SessionEvent payload bytes returned by SessionStore. + // It must match the serializer used when the events were persisted. + Serializer schema.Serializer +} + +type grepSessionHistoryInput struct { + Query string `json:"query" jsonschema:"required,description=the narrow term to search in current session history"` + Limit int `json:"limit,omitempty" jsonschema:"description=maximum number of matching lines to return"` +} + +type dreamRunMeta struct { + MemoryDirectory string + SessionID string + SearchSessionIDs []string +} + +type dreamRunMetaKey struct{} + +func withDreamRunMeta(ctx context.Context, meta *dreamRunMeta) context.Context { + return context.WithValue(ctx, dreamRunMetaKey{}, meta) +} + +func getDreamRunMeta(ctx context.Context) *dreamRunMeta { + if v := ctx.Value(dreamRunMetaKey{}); v != nil { + if meta, ok := v.(*dreamRunMeta); ok { + return meta + } + } + return nil +} + +func newSessionHistoryGrepTool[M adk.MessageType](provider *SessionStoreProvider[M]) (tool.BaseTool, error) { + if provider == nil || provider.SessionStore == nil || provider.Serializer == nil { + return nil, nil + } + + t, err := toolutils.InferTool("grep_session_history", internal.SelectPrompt(internal.I18nPrompts{ + English: "Search the session histories included in the current dream run with a narrow query and return matching lines.", + Chinese: "在当前 dream 运行范围内的会话历史中按精确关键词搜索,并返回匹配行。", + }), func(ctx context.Context, input grepSessionHistoryInput) (string, error) { + meta := getDreamRunMeta(ctx) + if meta == nil { + return "", fmt.Errorf("grep_session_history: missing dream run metadata") + } + sessionIDs := resolveSearchSessionIDs(meta) + if len(sessionIDs) == 0 { + return "", fmt.Errorf("grep_session_history: no searchable sessions in current dream run") + } + query := strings.TrimSpace(input.Query) + if query == "" { + return "", fmt.Errorf("grep_session_history: empty query") + } + limit := input.Limit + if limit <= 0 { + limit = 50 + } + + pageSize := limit + if pageSize < 100 { + pageSize = 100 + } + + var ( + after string + found []string + ) + includeSessionPrefix := len(sessionIDs) > 1 + for _, sessionID := range sessionIDs { + after = "" + for len(found) < limit { + result, err := provider.SessionStore.LoadEvents(ctx, &adk.LoadSessionEventsRequest{ + SessionID: sessionID, + After: after, + Limit: pageSize, + Reverse: true, + Kinds: []adk.SessionEventKind{adk.SessionEventMessage}, + IncludeSessionTail: false, + }) + if err != nil { + return "", err + } + if result == nil || len(result.Events) == 0 { + break + } + for _, ev := range result.Events { + found = appendMatchingSessionHistoryLines(found, sessionID, sessionEventMessageString(ev.Message), query, limit, includeSessionPrefix) + if len(found) >= limit { + break + } + } + if result.Next == "" { + break + } + after = result.Next + } + if len(found) >= limit { + break + } + } + + return strings.Join(found, "\n"), nil + }) + if err != nil { + return nil, err + } + + return t, nil +} + +func resolveSearchSessionIDs(meta *dreamRunMeta) []string { + if meta == nil { + return nil + } + if len(meta.SearchSessionIDs) > 0 { + return meta.SearchSessionIDs + } + if meta.SessionID != "" { + return []string{meta.SessionID} + } + return nil +} + +func appendMatchingSessionHistoryLines(dst []string, sessionID, message, query string, limit int, includeSessionPrefix bool) []string { + needle := strings.ToLower(query) + for _, line := range strings.Split(message, "\n") { + if strings.Contains(strings.ToLower(line), needle) { + if includeSessionPrefix { + line = fmt.Sprintf("[%s] %s", sessionID, line) + } + dst = append(dst, line) + if len(dst) >= limit { + return dst + } + } + } + return dst +} + +func sessionEventMessageString[M adk.MessageType](msg M) string { + switch m := any(msg).(type) { + case *schema.Message: + if m == nil { + return "" + } + return m.String() + case *schema.AgenticMessage: + if m == nil { + return "" + } + return m.String() + default: + return "" + } +} diff --git a/adk/middlewares/automemory/dream/session_test.go b/adk/middlewares/automemory/dream/session_test.go new file mode 100644 index 000000000..69df4fc7c --- /dev/null +++ b/adk/middlewares/automemory/dream/session_test.go @@ -0,0 +1,125 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudwego/eino/adk" + adksession "github.com/cloudwego/eino/adk/session" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" +) + +func TestNewSessionHistoryGrepTool(t *testing.T) { + ctx := context.Background() + store := adksession.NewInMemoryStore[*schema.Message](nil) + serializer := &schema.HumanReadableSerializer{} + sessionID := "session-1" + tail := "" + + appendEvent := func(eventID string, msg *schema.Message) { + res, err := store.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ + SessionID: sessionID, + ExpectedSessionTailEventID: tail, + Events: []*adk.SessionEvent[*schema.Message]{{ + EventID: eventID, + Kind: adk.SessionEventMessage, + Message: msg, + }}, + }) + require.NoError(t, err) + tail = res.SessionTailEventID + } + + appendEvent("e1", schema.UserMessage("hello there")) + appendEvent("e2", schema.AssistantMessage("build failure: missing dependency", nil)) + appendEvent("e3", schema.ToolMessage("Build Failure: retry later", "call-1")) + + bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ + SessionStore: store, + Serializer: serializer, + }) + require.NoError(t, err) + + result, err := bt.(tool.InvokableTool).InvokableRun( + withDreamRunMeta(ctx, &dreamRunMeta{SessionID: sessionID, SearchSessionIDs: []string{sessionID}}), + `{"query":"build failure","limit":2}`, + ) + require.NoError(t, err) + require.Equal(t, "tool: Build Failure: retry later\nassistant: build failure: missing dependency", result) +} + +func TestNewSessionHistoryGrepTool_SearchesRunScopedSessions(t *testing.T) { + ctx := context.Background() + store := adksession.NewInMemoryStore[*schema.Message](nil) + serializer := &schema.HumanReadableSerializer{} + + appendEvent := func(sessionID, eventID string, msg *schema.Message) { + _, err := store.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ + SessionID: sessionID, + Events: []*adk.SessionEvent[*schema.Message]{{ + EventID: eventID, + Kind: adk.SessionEventMessage, + Message: msg, + }}, + }) + require.NoError(t, err) + } + + appendEvent("session-a", "a1", schema.AssistantMessage("build failure: missing dependency", nil)) + appendEvent("session-b", "b1", schema.ToolMessage("build failure: retry later", "call-1")) + appendEvent("session-c", "c1", schema.AssistantMessage("build failure: should not be searched", nil)) + + bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ + SessionStore: store, + Serializer: serializer, + }) + require.NoError(t, err) + + result, err := bt.(tool.InvokableTool).InvokableRun( + withDreamRunMeta(ctx, &dreamRunMeta{ + SessionID: "session-c", + SearchSessionIDs: []string{"session-a", "session-b"}, + }), + `{"query":"build failure","limit":5}`, + ) + require.NoError(t, err) + require.Contains(t, result, "[session-a] assistant: build failure: missing dependency") + require.Contains(t, result, "[session-b] tool: build failure: retry later") + require.NotContains(t, result, "should not be searched") +} + +func TestNewSessionHistoryGrepTool_InfoUsesChineseDescription(t *testing.T) { + require.NoError(t, adk.SetLanguage(adk.LanguageChinese)) + defer func() { + require.NoError(t, adk.SetLanguage(adk.LanguageEnglish)) + }() + + bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ + SessionStore: adksession.NewInMemoryStore[*schema.Message](nil), + Serializer: &schema.HumanReadableSerializer{}, + }) + require.NoError(t, err) + + info, err := bt.Info(context.Background()) + require.NoError(t, err) + require.Contains(t, info.Desc, "在当前 dream 运行范围内的会话历史中按精确关键词搜索") +} diff --git a/adk/middlewares/automemory/dream/store.go b/adk/middlewares/automemory/dream/store.go new file mode 100644 index 000000000..c5f80d64f --- /dev/null +++ b/adk/middlewares/automemory/dream/store.go @@ -0,0 +1,135 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dream + +import ( + "context" + "sync" + "time" +) + +// ScheduleState stores per-memory-directory scheduling state. +type ScheduleState struct { + // LastConsolidatedAt is the completion time of the last successful run. + LastConsolidatedAt time.Time + + // NextCheckAt is the next time the middleware should re-check this directory. + NextCheckAt time.Time +} + +// Store persists middleware scheduling state for one resolved `MemoryDirectory`, +// including touched sessions, backoff state, and the run lock. +type Store interface { + // RecordSessionTouch records that a session produced new signal. + RecordSessionTouch(ctx context.Context, memoryDir, sessionID string, at time.Time) error + + // ListSessionsTouchedSince returns distinct sessions touched after `since`. + ListSessionsTouchedSince(ctx context.Context, memoryDir string, since time.Time) ([]string, error) + + // GetScheduleState loads the scheduling state for one memory directory. + GetScheduleState(ctx context.Context, memoryDir string) (*ScheduleState, error) + + // SetScheduleState persists the scheduling state. + // Passing nil should clear it when supported. + SetScheduleState(ctx context.Context, memoryDir string, state *ScheduleState) error + + // AcquireRunLock tries to acquire the per-memory-directory run lock. + // It returns `ok=false` when another process already holds the lock. + AcquireRunLock(ctx context.Context, memoryDir string, ttl time.Duration) (unlock func(context.Context) error, ok bool, err error) +} + +type localStore struct { + mu sync.Mutex + touches map[string]map[string]time.Time + states map[string]ScheduleState + locks map[string]time.Time +} + +// NewLocalStore returns an in-process `Store`. +// It is suitable for tests and single-process use only. +func NewLocalStore() Store { + return &localStore{ + touches: make(map[string]map[string]time.Time), + states: make(map[string]ScheduleState), + locks: make(map[string]time.Time), + } +} + +func (s *localStore) RecordSessionTouch(_ context.Context, memoryDir, sessionID string, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.touches[memoryDir] == nil { + s.touches[memoryDir] = make(map[string]time.Time) + } + s.touches[memoryDir][sessionID] = at + st := s.states[memoryDir] + if st.NextCheckAt.IsZero() { + st.NextCheckAt = at + s.states[memoryDir] = st + } + return nil +} + +func (s *localStore) ListSessionsTouchedSince(_ context.Context, memoryDir string, since time.Time) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + items := s.touches[memoryDir] + if len(items) == 0 { + return nil, nil + } + out := make([]string, 0, len(items)) + for sessionID, touchedAt := range items { + if touchedAt.After(since) { + out = append(out, sessionID) + } + } + return out, nil +} + +func (s *localStore) GetScheduleState(_ context.Context, memoryDir string) (*ScheduleState, error) { + s.mu.Lock() + defer s.mu.Unlock() + st := s.states[memoryDir] + cp := st + return &cp, nil +} + +func (s *localStore) SetScheduleState(_ context.Context, memoryDir string, state *ScheduleState) error { + s.mu.Lock() + defer s.mu.Unlock() + if state == nil { + delete(s.states, memoryDir) + return nil + } + s.states[memoryDir] = *state + return nil +} + +func (s *localStore) AcquireRunLock(_ context.Context, memoryDir string, ttl time.Duration) (func(context.Context) error, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + if until, ok := s.locks[memoryDir]; ok && until.After(time.Now()) { + return nil, false, nil + } + s.locks[memoryDir] = time.Now().Add(ttl) + return func(context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.locks, memoryDir) + return nil + }, true, nil +} diff --git a/adk/middlewares/automemory/inmemory_backend.go b/adk/middlewares/automemory/inmemory_backend.go new file mode 100644 index 000000000..9b775b139 --- /dev/null +++ b/adk/middlewares/automemory/inmemory_backend.go @@ -0,0 +1,210 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "context" + "fmt" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/bmatcuk/doublestar/v4" +) + +type memFile struct { + content string + modifiedAt time.Time +} + +// InMemoryBackend is a simple in-memory Backend implementation intended for tests +// and demos. Paths are treated as filesystem-like, and should be absolute. +type InMemoryBackend struct { + mu sync.RWMutex + files map[string]*memFile +} + +// NewInMemoryBackend returns an empty in-memory Backend implementation. +func NewInMemoryBackend() *InMemoryBackend { + return &InMemoryBackend{ + files: make(map[string]*memFile), + } +} + +func (b *InMemoryBackend) put(path string, content string, modifiedAt time.Time) { + b.mu.Lock() + defer b.mu.Unlock() + b.files[filepath.Clean(path)] = &memFile{content: content, modifiedAt: modifiedAt} +} + +func (b *InMemoryBackend) Write(_ context.Context, req *WriteRequest) error { + if req == nil || req.FilePath == "" { + return fmt.Errorf("write: invalid request") + } + // Default to full replace. + b.put(req.FilePath, req.Content, time.Now()) + return nil +} + +func (b *InMemoryBackend) Edit(_ context.Context, req *EditRequest) error { + if req == nil || req.FilePath == "" { + return fmt.Errorf("edit: invalid request") + } + b.mu.Lock() + defer b.mu.Unlock() + + path := filepath.Clean(req.FilePath) + f, ok := b.files[path] + if !ok { + return fmt.Errorf("file not found: %s", path) + } + + if req.OldString == "" { + return fmt.Errorf("edit: old string must be non-empty") + } + if req.OldString == req.NewString { + return fmt.Errorf("edit: new string must differ from old string") + } + + out := f.content + if req.ReplaceAll { + out = strings.ReplaceAll(out, req.OldString, req.NewString) + } else { + if strings.Count(out, req.OldString) != 1 { + return fmt.Errorf("edit: old string must appear exactly once when ReplaceAll is false") + } + out = strings.Replace(out, req.OldString, req.NewString, 1) + } + f.content = out + f.modifiedAt = time.Now() + return nil +} + +func (b *InMemoryBackend) Read(_ context.Context, req *ReadRequest) (*FileContent, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + if req == nil || req.FilePath == "" { + return nil, fmt.Errorf("read: invalid request") + } + path := filepath.Clean(req.FilePath) + f, ok := b.files[path] + if !ok { + return nil, fmt.Errorf("file not found: %s", path) + } + + offset := req.Offset - 1 + if offset < 0 { + offset = 0 + } + limit := req.Limit + + content := f.content + if offset == 0 && limit <= 0 { + return &FileContent{Content: content}, nil + } + + start := 0 + for i := 0; i < offset; i++ { + idx := strings.IndexByte(content[start:], '\n') + if idx == -1 { + return &FileContent{Content: ""}, nil + } + start += idx + 1 + } + + if limit <= 0 { + return &FileContent{Content: content[start:]}, nil + } + + end := start + for i := 0; i < limit; i++ { + idx := strings.IndexByte(content[end:], '\n') + if idx == -1 { + return &FileContent{Content: content[start:]}, nil + } + end += idx + 1 + } + + // Trim trailing newline. + return &FileContent{Content: strings.TrimSuffix(content[start:end], "\n")}, nil +} + +func (b *InMemoryBackend) GlobInfo(_ context.Context, req *GlobInfoRequest) ([]FileInfo, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + if req == nil || req.Pattern == "" { + return nil, fmt.Errorf("glob: invalid request") + } + base := filepath.Clean(req.Path) + if base == "." { + base = "" + } + + type item struct { + fi FileInfo + t time.Time + } + var out []item + + for p, f := range b.files { + if base != "" { + // Require p under base. + if p != base && !strings.HasPrefix(p, base+string(filepath.Separator)) { + continue + } + } + + rel := p + if base != "" { + rel = strings.TrimPrefix(p, base+string(filepath.Separator)) + if rel == p { + rel = strings.TrimPrefix(p, base) + rel = strings.TrimPrefix(rel, string(filepath.Separator)) + } + } + rel = filepath.ToSlash(rel) + + ok, err := doublestar.Match(req.Pattern, rel) + if err != nil { + return nil, err + } + if !ok { + continue + } + + out = append(out, item{ + fi: FileInfo{ + Path: p, + IsDir: false, + Size: int64(len(f.content)), + ModifiedAt: f.modifiedAt.Format(time.RFC3339Nano), + }, + t: f.modifiedAt, + }) + } + + sort.Slice(out, func(i, j int) bool { return out[i].t.After(out[j].t) }) + ret := make([]FileInfo, 0, len(out)) + for _, it := range out { + ret = append(ret, it.fi) + } + return ret, nil +} diff --git a/adk/middlewares/automemory/internal/backend.go b/adk/middlewares/automemory/internal/backend.go new file mode 100644 index 000000000..cd44c29cd --- /dev/null +++ b/adk/middlewares/automemory/internal/backend.go @@ -0,0 +1,235 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package internal contains bounded filesystem adapters used by automemory +// middleware implementations. +package internal + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + adkfs "github.com/cloudwego/eino/adk/filesystem" +) + +type Backend interface { + Read(ctx context.Context, req *adkfs.ReadRequest) (*adkfs.FileContent, error) + GlobInfo(ctx context.Context, req *adkfs.GlobInfoRequest) ([]adkfs.FileInfo, error) + Write(ctx context.Context, req *adkfs.WriteRequest) error + Edit(ctx context.Context, req *adkfs.EditRequest) error +} + +type FSBackendConfig struct { + BaseDir string + AllowLs bool + AllowGrep bool + NotFoundAsContent bool + ErrorPrefix string +} + +type FSBackend struct { + backend Backend + baseClean string + allowLs bool + allowGrep bool + notFoundAsContent bool + errorPrefix string +} + +// ResolveMemoryDir returns the cleaned absolute path for a memory directory. +func ResolveMemoryDir(dir string) (string, error) { + abs, err := filepath.Abs(dir) + if err != nil { + return "", err + } + return filepath.Clean(abs), nil +} + +// NewFSBackend wraps a Backend with path-bounding and optional tool behaviors. +func NewFSBackend(backend Backend, cfg FSBackendConfig) (*FSBackend, error) { + if backend == nil { + return nil, fmt.Errorf("%s: nil backend", prefixOrDefault(cfg.ErrorPrefix)) + } + if cfg.BaseDir == "" { + return nil, fmt.Errorf("%s: empty base dir", prefixOrDefault(cfg.ErrorPrefix)) + } + baseClean, err := ResolveMemoryDir(cfg.BaseDir) + if err != nil { + return nil, fmt.Errorf("%s: resolve base dir: %w", prefixOrDefault(cfg.ErrorPrefix), err) + } + return &FSBackend{ + backend: backend, + baseClean: baseClean, + allowLs: cfg.AllowLs, + allowGrep: cfg.AllowGrep, + notFoundAsContent: cfg.NotFoundAsContent, + errorPrefix: prefixOrDefault(cfg.ErrorPrefix), + }, nil +} + +func prefixOrDefault(prefix string) string { + if prefix == "" { + return "fs backend" + } + return prefix +} + +func isFileNotFoundErr(err error) bool { + if err == nil { + return false + } + if os.IsNotExist(err) { + return true + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "file not found") || strings.Contains(msg, "no such file or directory") +} + +func (f *FSBackend) resolveFilePath(p string) (string, error) { + if p == "" { + return "", fmt.Errorf("%s: empty path", f.errorPrefix) + } + if !filepath.IsAbs(p) { + p = filepath.Join(f.baseClean, p) + } + p = filepath.Clean(p) + if p != f.baseClean && !strings.HasPrefix(p, f.baseClean+string(filepath.Separator)) { + return "", fmt.Errorf("%s: path out of bounds: %s", f.errorPrefix, p) + } + return p, nil +} + +func (f *FSBackend) resolveDirPath(p string) (string, error) { + if p == "" { + return f.baseClean, nil + } + if !filepath.IsAbs(p) { + p = filepath.Join(f.baseClean, p) + } + p = filepath.Clean(p) + if p != f.baseClean && !strings.HasPrefix(p, f.baseClean+string(filepath.Separator)) { + return "", fmt.Errorf("%s: dir out of bounds: %s", f.errorPrefix, p) + } + return p, nil +} + +func (f *FSBackend) Read(ctx context.Context, req *adkfs.ReadRequest) (*adkfs.FileContent, error) { + if req == nil { + return nil, fmt.Errorf("read: invalid request") + } + fp, err := f.resolveFilePath(req.FilePath) + if err != nil { + return nil, err + } + n := *req + n.FilePath = fp + content, err := f.backend.Read(ctx, &n) + if err != nil { + if f.notFoundAsContent && isFileNotFoundErr(err) { + return &adkfs.FileContent{Content: fmt.Sprintf("File not found: %s", fp)}, nil + } + return nil, err + } + return (*adkfs.FileContent)(content), nil +} + +func (f *FSBackend) Write(ctx context.Context, req *adkfs.WriteRequest) error { + if req == nil { + return fmt.Errorf("write: invalid request") + } + fp, err := f.resolveFilePath(req.FilePath) + if err != nil { + return err + } + n := *req + n.FilePath = fp + return f.backend.Write(ctx, &n) +} + +func (f *FSBackend) Edit(ctx context.Context, req *adkfs.EditRequest) error { + if req == nil { + return fmt.Errorf("edit: invalid request") + } + fp, err := f.resolveFilePath(req.FilePath) + if err != nil { + return err + } + n := *req + n.FilePath = fp + return f.backend.Edit(ctx, &n) +} + +func (f *FSBackend) GlobInfo(ctx context.Context, req *adkfs.GlobInfoRequest) ([]adkfs.FileInfo, error) { + if req == nil || req.Pattern == "" { + return nil, fmt.Errorf("glob: invalid request") + } + pathAbs, err := f.resolveDirPath(req.Path) + if err != nil { + return nil, err + } + pattern := req.Pattern + if filepath.IsAbs(pattern) { + cp := filepath.Clean(pattern) + if cp == pathAbs { + pattern = "." + } else if strings.HasPrefix(cp, pathAbs+string(filepath.Separator)) { + rel, rerr := filepath.Rel(pathAbs, cp) + if rerr != nil { + return nil, rerr + } + pattern = filepath.ToSlash(rel) + } else if strings.HasPrefix(cp, f.baseClean+string(filepath.Separator)) { + rel, rerr := filepath.Rel(f.baseClean, cp) + if rerr != nil { + return nil, rerr + } + pattern = filepath.ToSlash(rel) + pathAbs = f.baseClean + } else { + return nil, fmt.Errorf("%s: glob pattern out of bounds: %s", f.errorPrefix, cp) + } + } else { + pattern = filepath.ToSlash(pattern) + } + n := *req + n.Path = pathAbs + n.Pattern = pattern + return f.backend.GlobInfo(ctx, &n) +} + +func (f *FSBackend) LsInfo(ctx context.Context, req *adkfs.LsInfoRequest) ([]adkfs.FileInfo, error) { + if !f.allowLs { + return nil, fmt.Errorf("ls: disabled") + } + if req == nil { + return nil, fmt.Errorf("ls: invalid request") + } + base, err := f.resolveDirPath(req.Path) + if err != nil { + return nil, err + } + return f.GlobInfo(ctx, &adkfs.GlobInfoRequest{Path: base, Pattern: "*"}) +} + +func (f *FSBackend) GrepRaw(context.Context, *adkfs.GrepRequest) ([]adkfs.GrepMatch, error) { + if !f.allowGrep { + return nil, fmt.Errorf("grep: disabled") + } + return nil, fmt.Errorf("grep: not implemented") +} diff --git a/adk/middlewares/automemory/local_backend.go b/adk/middlewares/automemory/local_backend.go new file mode 100644 index 000000000..d437caaef --- /dev/null +++ b/adk/middlewares/automemory/local_backend.go @@ -0,0 +1,192 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/bmatcuk/doublestar/v4" +) + +// LocalBackend implements Backend on the local OS filesystem. +// It is intentionally minimal (Read + GlobInfo) to match the "方案一" storage abstraction. +type LocalBackend struct{} + +// NewLocalBackend returns a filesystem-backed Backend implementation. +func NewLocalBackend() *LocalBackend { + return &LocalBackend{} +} + +func (b *LocalBackend) Read(_ context.Context, req *ReadRequest) (*FileContent, error) { + if req == nil || req.FilePath == "" { + return nil, fmt.Errorf("read: invalid request") + } + + raw, err := os.ReadFile(req.FilePath) + if err != nil { + return nil, err + } + + content := string(raw) + offset := req.Offset - 1 + if offset < 0 { + offset = 0 + } + limit := req.Limit + + if offset == 0 && limit <= 0 { + return &FileContent{Content: content}, nil + } + + start := 0 + for i := 0; i < offset; i++ { + idx := strings.IndexByte(content[start:], '\n') + if idx == -1 { + return &FileContent{Content: ""}, nil + } + start += idx + 1 + } + + if limit <= 0 { + return &FileContent{Content: content[start:]}, nil + } + + end := start + for i := 0; i < limit; i++ { + idx := strings.IndexByte(content[end:], '\n') + if idx == -1 { + return &FileContent{Content: content[start:]}, nil + } + end += idx + 1 + } + + return &FileContent{Content: strings.TrimSuffix(content[start:end], "\n")}, nil +} + +func (b *LocalBackend) GlobInfo(_ context.Context, req *GlobInfoRequest) ([]FileInfo, error) { + if req == nil || req.Pattern == "" || req.Path == "" { + return nil, fmt.Errorf("glob: invalid request") + } + + root := filepath.Clean(req.Path) + var matches []FileInfo + type item struct { + fi FileInfo + t time.Time + } + var tmp []item + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + + rel, err := filepath.Rel(root, path) + if err != nil { + return err + } + rel = filepath.ToSlash(rel) + + ok, err := doublestar.Match(req.Pattern, rel) + if err != nil { + return err + } + if !ok { + return nil + } + + st, err := os.Stat(path) + if err != nil { + return err + } + + tmp = append(tmp, item{ + fi: FileInfo{ + Path: path, + IsDir: false, + Size: st.Size(), + ModifiedAt: st.ModTime().Format(time.RFC3339Nano), + }, + t: st.ModTime(), + }) + return nil + }) + if err != nil { + return nil, err + } + + sort.Slice(tmp, func(i, j int) bool { return tmp[i].t.After(tmp[j].t) }) + matches = make([]FileInfo, 0, len(tmp)) + for _, it := range tmp { + matches = append(matches, it.fi) + } + return matches, nil +} + +func (b *LocalBackend) Write(_ context.Context, req *WriteRequest) error { + if req == nil || req.FilePath == "" { + return fmt.Errorf("write: invalid request") + } + path := filepath.Clean(req.FilePath) + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, []byte(req.Content), 0o644); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func (b *LocalBackend) Edit(ctx context.Context, req *EditRequest) error { + if req == nil || req.FilePath == "" { + return fmt.Errorf("edit: invalid request") + } + fc, err := b.Read(ctx, &ReadRequest{FilePath: req.FilePath}) + if err != nil { + return err + } + if req.OldString == "" { + return fmt.Errorf("edit: old string must be non-empty") + } + if req.OldString == req.NewString { + return fmt.Errorf("edit: new string must differ from old string") + } + + out := fc.Content + if req.ReplaceAll { + out = strings.ReplaceAll(out, req.OldString, req.NewString) + } else { + if strings.Count(out, req.OldString) != 1 { + return fmt.Errorf("edit: old string must appear exactly once when ReplaceAll is false") + } + out = strings.Replace(out, req.OldString, req.NewString, 1) + } + return b.Write(ctx, &WriteRequest{FilePath: req.FilePath, Content: out}) +} diff --git a/adk/middlewares/automemory/prompt.go b/adk/middlewares/automemory/prompt.go new file mode 100644 index 000000000..2f5d9130c --- /dev/null +++ b/adk/middlewares/automemory/prompt.go @@ -0,0 +1,316 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package automemory + +import ( + "fmt" + "strings" + + "github.com/cloudwego/eino/adk/internal" +) + +const ( + defaultMemoryInstruction = `# auto memory + +You have a persistent auto memory directory at "{memory_dir}". Its contents persist across conversations. + +As you work, consult your memory files to build on previous experience. + +## How to save memories: +- Organize memory semantically by topic, not chronologically +- Use the Write and Edit tools to update your memory files +- 'MEMORY.md' is always loaded into your conversation context — content is truncated after 200 lines or 4KB, so keep it concise +- Create separate topic files (e.g., 'debugging.md'', 'patterns.md'') for detailed notes and link to them from MEMORY.md +- Update or remove memories that turn out to be wrong or outdated +- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one. + +## What to save: +- Stable patterns and conventions confirmed across multiple interactions +- Key architectural decisions, important file paths, and project structure +- User preferences for workflow, tools, and communication style +- Solutions to recurring problems and debugging insights + +## What NOT to save: +- Session-specific context (current task details, in-progress work, temporary state) +- Information that might be incomplete — verify against project docs before writing +- Anything that duplicates or contradicts existing AGENTS.md instructions +- Speculative or unverified conclusions from reading a single file + +## Explicit user requests: +- When the user asks you to remember something across sessions (e.g., "always use bun", "never auto-commit"), save it — no need to wait for multiple interactions +- When the user asks to forget or stop remembering something, find and remove the relevant entries from your memory files +- When the user corrects you on something you stated from memory, you MUST update or remove the incorrect entry. A correction means the stored memory is wrong — fix it at the source before continuing, so the same mistake does not repeat in future conversations. + +## Searching past context +- Search topic files in your memory directory: Grep with pattern="" path="{memory_dir}" glob="*.md" +- Use narrow search terms (error messages, file paths, function names) rather than broad keywords. + +` + + defaultAppendCurrentIndexTruncNotify = `WARNING: MEMORY.md was truncated (lines: {memory_lines}, limit: 200; byte limit: 4096). Move detailed content into separate topic files and keep MEMORY.md as a concise index.` + + defaultAppendEmptyIndexTemplate = `Your MEMORY.md is currently empty. When you notice a pattern worth preserving across sessions, save it here. Anything in MEMORY.md will be included in your system prompt next time.` + + defaultTopicSelectionSystemPrompt = `You are selecting memories that will be useful to the agent as it processes a user's query. You will be given the user's query and a list of available memory files with their filenames and descriptions. + +Return a list of RELATIVE FILE PATHS (relative to the memory directory) for the memories that will clearly be useful to the agent as it processes the user's query (up to 5). Only include memories that you are certain will be helpful based on their name/description/type. +- If you are unsure if a memory will be useful in processing the user's query, then do not include it in your list. Be selective and discerning. +- If there are no memories in the list that would clearly be useful, feel free to return an empty list. +- If a list of recently-used tools is provided, do not select memories that are usage reference or API documentation for those tools (the agent is already exercising them). DO still select memories containing warnings, gotchas, or known issues about those tools — active use is exactly when those matter.` + + defaultTopicSelectionUserPrompt = `Query: {user_query} + +Available memories: +{available_memories} + +Recently used tools: +{tools}` + + defaultTopicMemoryTruncNotify = ` +> This memory file was truncated ({reason}). Use the Read tool to view the complete file at: {abs_path}` + + defaultMemoryInstructionChinese = `# 自动记忆 + +你有一个持久化的自动记忆目录 "{memory_dir}"。其中的内容会在不同会话之间保留。 + +在工作过程中,请查阅这些记忆文件,以便基于过去的经验继续推进。 + +## 如何保存记忆: +- 按主题组织记忆,而不是按时间顺序堆叠 +- 使用 Write 和 Edit 工具更新你的记忆文件 +- 'MEMORY.md' 会始终被加载到对话上下文中,其内容在超过 200 行或 4KB 时会被截断,因此请保持简洁 +- 将详细内容写入单独的主题文件(例如 'debugging.md'、'patterns.md'),并在 MEMORY.md 中链接它们 +- 当某条记忆被证明错误或过时时,请更新或删除它 +- 不要写入重复记忆。创建新记忆前,先检查是否已有可更新的现有文件 + +## 应该保存什么: +- 已在多次交互中得到确认的稳定模式和约定 +- 关键架构决策、重要文件路径和项目结构 +- 用户在工作流、工具使用和沟通方式上的偏好 +- 可复用的问题解决经验与调试结论 + +## 不应保存什么: +- 仅属于当前会话的上下文(当前任务细节、进行中的工作、临时状态) +- 可能不完整的信息,在写入前应先根据项目文档核实 +- 与现有 AGENTS.md 指令重复或冲突的内容 +- 仅基于阅读单个文件得到的猜测性或未经验证的结论 + +## 用户的明确要求: +- 当用户明确要求你跨会话记住某件事时(例如“始终使用 bun”“不要自动提交”),应立即保存,无需等待多轮交互确认 +- 当用户要求你遗忘某件事或停止记忆时,找到对应条目并从记忆文件中删除 +- 当用户指出你基于记忆给出的内容有误时,你必须更新或删除错误条目。纠正意味着原有记忆已经错误,必须先从源头修正,避免今后重复犯错 + +## 如何检索历史上下文 +- 在记忆目录中搜索主题文件:使用 Grep,pattern="<搜索词>" path="{memory_dir}" glob="*.md" +- 尽量使用更窄的检索词,例如报错信息、文件路径、函数名,而不是宽泛关键词 + +` + + defaultAppendCurrentIndexTruncNotifyChinese = `警告:MEMORY.md 已被截断(总行数:{memory_lines},限制:200 行;字节限制:4096)。请将详细内容迁移到独立的主题文件中,并让 MEMORY.md 只保留简洁索引。` + + defaultAppendEmptyIndexTemplateChinese = `你的 MEMORY.md 当前为空。当你发现值得跨会话保留的模式时,请把它写在这里。下一次对话中,MEMORY.md 的内容会被自动加入 system prompt。` + + defaultTopicSelectionSystemPromptChinese = `你需要从记忆列表中选择对当前用户问题真正有帮助的记忆。你会拿到用户问题,以及一组可用记忆文件的文件名和描述。 + +请返回一个 RELATIVE FILE PATHS 列表(相对于 memory directory),列出那些在处理当前用户问题时显然有帮助的记忆文件(最多 5 个)。只有在你能够基于名称、描述或类型确认其确实有帮助时才选择。 +- 如果你不能确定某条记忆是否有帮助,就不要选它。请保持克制和甄别。 +- 如果列表中没有任何明显有帮助的记忆,可以返回空列表。 +- 如果提供了最近使用过的工具列表,不要选择那些仅包含这些工具使用说明或 API 文档的记忆(agent 已经在使用它们)。但如果记忆中包含这些工具的警告、坑点或已知问题,仍然应该选择,因为这些内容在实际调用时尤其重要。` + + defaultTopicSelectionUserPromptChinese = `问题:{user_query} + +可用记忆: +{available_memories} + +最近使用的工具: +{tools}` + + defaultTopicMemoryTruncNotifyChinese = ` +> 该记忆文件已被截断({reason})。请使用 Read 工具查看完整文件:{abs_path}` +) + +func buildExtractAutoOnlyPrompt(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: buildExtractAutoOnlyPromptEnglish(memoryDir, newMessageCount, existingMemories, skipIndex), + Chinese: buildExtractAutoOnlyPromptChinese(memoryDir, newMessageCount, existingMemories, skipIndex), + }) +} + +func joinLines(lines []string) string { + if len(lines) == 0 { + return "" + } + var b strings.Builder + b.WriteString(lines[0]) + for i := 1; i < len(lines); i++ { + b.WriteString("\n") + b.WriteString(lines[i]) + } + return b.String() +} + +func getDefaultMemoryInstruction() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultMemoryInstruction, + Chinese: defaultMemoryInstructionChinese, + }) +} + +func getAppendCurrentIndexTruncNotify() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultAppendCurrentIndexTruncNotify, + Chinese: defaultAppendCurrentIndexTruncNotifyChinese, + }) +} + +func getAppendEmptyIndexTemplate() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultAppendEmptyIndexTemplate, + Chinese: defaultAppendEmptyIndexTemplateChinese, + }) +} + +func getTopicSelectionSystemPrompt() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultTopicSelectionSystemPrompt, + Chinese: defaultTopicSelectionSystemPromptChinese, + }) +} + +func getTopicSelectionUserPrompt() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultTopicSelectionUserPrompt, + Chinese: defaultTopicSelectionUserPromptChinese, + }) +} + +func getTopicMemoryTruncNotify() string { + return internal.SelectPrompt(internal.I18nPrompts{ + English: defaultTopicMemoryTruncNotify, + Chinese: defaultTopicMemoryTruncNotifyChinese, + }) +} + +func buildExtractAutoOnlyPromptEnglish(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string { + manifest := "" + if existingMemories != "" { + manifest = fmt.Sprintf("\n\n## Existing memory files\n\n%s\n\nCheck this list before writing — update an existing file rather than creating a duplicate.", existingMemories) + } + + howToSave := []string{ + "## How to save memories", + "", + "Saving a memory is a two-step process:", + "", + "Step 1 — write the memory to its own file.", + "Step 2 — add a pointer to that file in MEMORY.md. MEMORY.md is an index, not the memory body.", + "", + "- Keep MEMORY.md concise because it is loaded into system prompt context.", + "- Organize memory semantically by topic, not chronologically.", + "- Update or remove memories that turn out to be wrong or outdated.", + "- Do not write duplicate memories.", + } + if skipIndex { + howToSave = []string{ + "## How to save memories", + "", + "Write each memory to its own file. Do not create duplicate files.", + } + } + + parts := []string{ + fmt.Sprintf("You are now acting as the memory extraction subagent. Analyze only the most recent ~%d messages above and use them to update persistent memory.", newMessageCount), + "", + fmt.Sprintf("Memory directory: %s", memoryDir), + "", + "Available tools: read_file, glob, write_file, edit_file. Only paths inside the memory directory are allowed. All other tools are denied.", + "", + "You have a limited turn budget. read_file should happen first for every file you may update, then write_file/edit_file should happen after that. Do not interleave read and write across many turns.", + "", + fmt.Sprintf("You MUST only use content from the last ~%d messages to update memories. Do not investigate code or verify against source files further.", newMessageCount) + manifest, + "", + "If the user explicitly asks you to remember something, save it immediately. If they ask you to forget something, find and remove the relevant memory.", + "", + "## What to save", + "- Stable patterns and conventions confirmed across multiple interactions", + "- Important file paths, architectural decisions, and user preferences", + "- Recurring debugging insights and known gotchas", + "", + "## What NOT to save", + "- Session-specific temporary state or current task details", + "- Secrets, credentials, or personal data", + "- Speculative or unverified conclusions", + "", + } + parts = append(parts, howToSave...) + return joinLines(parts) +} + +func buildExtractAutoOnlyPromptChinese(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string { + manifest := "" + if existingMemories != "" { + manifest = fmt.Sprintf("\n\n## 现有记忆文件\n\n%s\n\n写入前请先检查这份列表,优先更新已有文件,而不是创建重复记忆。", existingMemories) + } + + howToSave := []string{ + "## 如何保存记忆", + "", + "保存记忆分为两步:", + "", + "第 1 步:将记忆写入独立文件。", + "第 2 步:在 MEMORY.md 中添加指向该文件的索引。MEMORY.md 只是索引,不应存放记忆正文。", + "", + "- 保持 MEMORY.md 简洁,因为它会被加载进 system prompt。", + "- 按主题组织记忆,而不是按时间顺序堆叠。", + "- 当记忆被证明错误或过时时,要及时更新或删除。", + "- 不要写入重复记忆。", + } + if skipIndex { + howToSave = []string{ + "## 如何保存记忆", + "", + "将每条记忆写入各自独立的文件中,不要创建重复文件。", + } + } + + parts := []string{ + fmt.Sprintf("你现在扮演 memory extraction subagent。只分析上方最近约 %d 条消息,并用它们来更新持久化记忆。", newMessageCount), + "", + fmt.Sprintf("记忆目录:%s", memoryDir), + "", + "可用工具:read_file、glob、write_file、edit_file。只允许访问记忆目录内的路径,其他工具均禁止使用。", + "", + "你的轮次预算有限。对于每个可能更新的文件,应先 read_file,再进行 write_file/edit_file;不要在多轮里交叉读写大量文件。", + "", + fmt.Sprintf("你必须只使用最近约 %d 条消息中的内容来更新记忆。不要继续调查代码,也不要再去源码中额外验证。", newMessageCount) + manifest, + "", + "如果用户明确要求你记住某件事,请立即保存;如果用户要求遗忘某件事,请找到对应记忆并删除。", + "", + "## 应该保存什么", + "- 已在多次交互中得到确认的稳定模式和约定", + "- 重要文件路径、架构决策和用户偏好", + "- 可复用的调试经验与已知坑点", + "", + "## 不应保存什么", + "- 仅属于当前会话的临时状态或当前任务细节", + "- 密钥、凭据或个人数据", + "- 猜测性或未经验证的结论", + "", + } + parts = append(parts, howToSave...) + return joinLines(parts) +} diff --git a/adk/middlewares/dynamictool/toolsearch/toolsearch.go b/adk/middlewares/dynamictool/toolsearch/toolsearch.go index 3b10e95b5..66566867c 100644 --- a/adk/middlewares/dynamictool/toolsearch/toolsearch.go +++ b/adk/middlewares/dynamictool/toolsearch/toolsearch.go @@ -134,7 +134,7 @@ type typedMiddleware[M adk.MessageType] struct { sr string } -func (m *typedMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { +func (m *typedMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { if runCtx == nil { return ctx, runCtx, nil } diff --git a/adk/middlewares/filesystem/filesystem.go b/adk/middlewares/filesystem/filesystem.go index 619b46ca2..cb72c9e74 100644 --- a/adk/middlewares/filesystem/filesystem.go +++ b/adk/middlewares/filesystem/filesystem.go @@ -434,7 +434,7 @@ type typedFilesystemMiddleware[M adk.MessageType] struct { additionalTools []tool.BaseTool } -func (m *typedFilesystemMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { +func (m *typedFilesystemMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { if runCtx == nil { return ctx, runCtx, nil } diff --git a/adk/middlewares/filesystem/filesystem_test.go b/adk/middlewares/filesystem/filesystem_test.go index 11c5b07ea..b816997fe 100644 --- a/adk/middlewares/filesystem/filesystem_test.go +++ b/adk/middlewares/filesystem/filesystem_test.go @@ -961,7 +961,7 @@ func TestFilesystemMiddleware_BeforeAgent(t *testing.T) { m, err := New(ctx, &MiddlewareConfig{Backend: backend}) assert.NoError(t, err) - runCtx := &adk.ChatModelAgentContext{ + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ Instruction: "Original instruction", Tools: nil, } diff --git a/adk/middlewares/plantask/plantask.go b/adk/middlewares/plantask/plantask.go index fb201bddb..69ac74e20 100644 --- a/adk/middlewares/plantask/plantask.go +++ b/adk/middlewares/plantask/plantask.go @@ -63,7 +63,7 @@ type typedMiddleware[M adk.MessageType] struct { baseDir string } -func (m *typedMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { +func (m *typedMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { if runCtx == nil { return ctx, runCtx, nil } diff --git a/adk/middlewares/plantask/plantask_test.go b/adk/middlewares/plantask/plantask_test.go index 2354e79fd..6a76a70ce 100644 --- a/adk/middlewares/plantask/plantask_test.go +++ b/adk/middlewares/plantask/plantask_test.go @@ -62,7 +62,7 @@ func TestMiddlewareBeforeAgent(t *testing.T) { assert.NoError(t, err) assert.Nil(t, runCtx) - runCtx = &adk.ChatModelAgentContext{ + runCtx = &adk.ChatModelAgentContext[*schema.Message]{ Tools: []tool.BaseTool{}, } ctx, newRunCtx, err := mw.BeforeAgent(ctx, runCtx) diff --git a/adk/middlewares/skill/skill.go b/adk/middlewares/skill/skill.go index 8f8b2cad3..940d71f84 100644 --- a/adk/middlewares/skill/skill.go +++ b/adk/middlewares/skill/skill.go @@ -272,7 +272,7 @@ type typedSkillHandler[M adk.MessageType] struct { tool *typedSkillTool[M] } -func (h *typedSkillHandler[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { +func (h *typedSkillHandler[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { runCtx.Instruction = runCtx.Instruction + "\n" + h.instruction runCtx.Tools = append(runCtx.Tools, h.tool) return ctx, runCtx, nil diff --git a/adk/middlewares/skill/skill_test.go b/adk/middlewares/skill/skill_test.go index 3cc536abd..5c0596bab 100644 --- a/adk/middlewares/skill/skill_test.go +++ b/adk/middlewares/skill/skill_test.go @@ -456,7 +456,7 @@ func TestBeforeAgent(t *testing.T) { handler, err := NewMiddleware(ctx, &Config{Backend: backend}) require.NoError(t, err) - runCtx := &adk.ChatModelAgentContext{ + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ Instruction: "base instruction", Tools: []tool.BaseTool{}, } diff --git a/adk/prebuilt/deep/deep_test.go b/adk/prebuilt/deep/deep_test.go index 2e9802a35..5f71cee6d 100644 --- a/adk/prebuilt/deep/deep_test.go +++ b/adk/prebuilt/deep/deep_test.go @@ -192,7 +192,7 @@ func TestDeepAgentFilesystemExecuteDefaults(t *testing.T) { assert.NoError(t, err) assert.Len(t, handlers, 1) - _, runCtx, err := handlers[0].BeforeAgent(ctx, &adk.ChatModelAgentContext{}) + _, runCtx, err := handlers[0].BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{}) assert.NoError(t, err) assert.NotNil(t, runCtx) assert.Len(t, runCtx.Tools, tt.wantToolLen) @@ -244,7 +244,7 @@ func TestDeepAgentManualFilesystemMiddlewarePath(t *testing.T) { }) assert.NoError(t, err) - _, runCtx, err := fsMW.BeforeAgent(ctx, &adk.ChatModelAgentContext{}) + _, runCtx, err := fsMW.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{}) assert.NoError(t, err) assert.Len(t, runCtx.Tools, 1) info, err := runCtx.Tools[0].Info(ctx) diff --git a/adk/prebuilt/deep/types.go b/adk/prebuilt/deep/types.go index 781418bf3..7be75c34b 100644 --- a/adk/prebuilt/deep/types.go +++ b/adk/prebuilt/deep/types.go @@ -55,7 +55,7 @@ type typedAppendPromptTool[M adk.MessageType] struct { prompt string } -func (w *typedAppendPromptTool[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) { +func (w *typedAppendPromptTool[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext[M]) (context.Context, *adk.ChatModelAgentContext[M], error) { nRunCtx := *runCtx nRunCtx.Instruction += w.prompt if w.t != nil { diff --git a/adk/session_extra_test.go b/adk/session_extra_test.go index eb127c76e..e12823ba2 100644 --- a/adk/session_extra_test.go +++ b/adk/session_extra_test.go @@ -475,172 +475,6 @@ func agenticToolResultMessage(callID, name, text string) *schema.AgenticMessage } } -func TestStreamPersistence_AgenticToolResultChunksConcat(t *testing.T) { - ctx := context.Background() - store := newSessionHelperStore() - sid := "agentic-tool-stream-session" - - agent := &agenticSessionStreamingAgent{ - chunks: []*schema.AgenticMessage{ - agenticToolResultMessage("call_1", "execute", "first\n"), - agenticToolResultMessage("call_1", "execute", "second\n"), - }, - turnEnd: &TurnEndState[*schema.AgenticMessage]{ - Messages: []*schema.AgenticMessage{ - schema.UserAgenticMessage("q"), - agenticToolResultMessage("call_1", "execute", "first\nsecond\n"), - }, - }, - } - - runner := NewTypedRunner(TypedRunnerConfig[*schema.AgenticMessage]{ - Agent: agent, - EnableStreaming: true, - SessionID: sid, - SessionStore: store, - SessionConfig: &SessionConfig{EventFlushBatchSize: 1}, - }) - - iter := runner.Run(ctx, []*schema.AgenticMessage{schema.UserAgenticMessage("q")}) - for { - ev, ok := iter.Next() - if !ok { - break - } - require.NoError(t, ev.Err) - if ev.Output != nil && ev.Output.MessageOutput != nil && - ev.Output.MessageOutput.IsStreaming && ev.Output.MessageOutput.MessageStream != nil { - for { - _, err := ev.Output.MessageOutput.MessageStream.Recv() - if err == io.EOF { - break - } - require.NoError(t, err) - } - } - } - - var stored *SessionEvent[*schema.AgenticMessage] - store.mu.Lock() - snapshot := append([]SessionEventPayload{}, store.events...) - store.mu.Unlock() - for _, ep := range snapshot { - se, err := decodeSessionEvent[*schema.AgenticMessage](ep.Data) - require.NoError(t, err) - if se.Kind == SessionEventMessage && se.Message != nil && - len(se.Message.ContentBlocks) == 1 && - se.Message.ContentBlocks[0].Type == schema.ContentBlockTypeFunctionToolResult { - stored = se - break - } - } - - require.NotNil(t, stored) - require.NotNil(t, stored.Message) - require.Len(t, stored.Message.ContentBlocks, 1) - ftr := stored.Message.ContentBlocks[0].FunctionToolResult - require.NotNil(t, ftr) - assert.Equal(t, "call_1", ftr.CallID) - assert.Equal(t, "execute", ftr.Name) - require.Len(t, ftr.Content, 1) - assert.Equal(t, "first\nsecond\n", ftr.Content[0].Text.Text) - assert.Nil(t, stored.Message.ContentBlocks[0].StreamingMeta) -} - -func TestStreamPersistence_AgenticToolResultChunksWithStreamingMeta(t *testing.T) { - ctx := context.Background() - store := newSessionHelperStore() - sid := "agentic-tool-stream-meta-session" - - first := agenticToolResultMessage("call_1", "execute", "first\n") - second := agenticToolResultMessage("call_1", "execute", "second\n") - first.ContentBlocks[0].StreamingMeta = &schema.StreamingMeta{Index: 0} - second.ContentBlocks[0].StreamingMeta = &schema.StreamingMeta{Index: 0} - - agent := &agenticSessionStreamingAgent{ - chunks: []*schema.AgenticMessage{first, second}, - turnEnd: &TurnEndState[*schema.AgenticMessage]{ - Messages: []*schema.AgenticMessage{ - schema.UserAgenticMessage("q"), - agenticToolResultMessage("call_1", "execute", "first\nsecond\n"), - }, - }, - } - - runner := NewTypedRunner(TypedRunnerConfig[*schema.AgenticMessage]{ - Agent: agent, - EnableStreaming: true, - SessionID: sid, - SessionStore: store, - SessionConfig: &SessionConfig{EventFlushBatchSize: 1}, - }) - - iter := runner.Run(ctx, []*schema.AgenticMessage{schema.UserAgenticMessage("q")}) - for { - ev, ok := iter.Next() - if !ok { - break - } - require.NoError(t, ev.Err) - if ev.Output != nil && ev.Output.MessageOutput != nil && - ev.Output.MessageOutput.IsStreaming && ev.Output.MessageOutput.MessageStream != nil { - for { - _, err := ev.Output.MessageOutput.MessageStream.Recv() - if err == io.EOF { - break - } - require.NoError(t, err) - } - } - } - - var stored *schema.AgenticMessage - store.mu.Lock() - snapshot := append([]SessionEventPayload{}, store.events...) - store.mu.Unlock() - for _, ep := range snapshot { - se, err := decodeSessionEvent[*schema.AgenticMessage](ep.Data) - require.NoError(t, err) - if se.Kind == SessionEventMessage && se.Message != nil && - len(se.Message.ContentBlocks) == 1 && - se.Message.ContentBlocks[0].Type == schema.ContentBlockTypeFunctionToolResult { - stored = se.Message - break - } - } - - require.NotNil(t, stored) - require.Len(t, stored.ContentBlocks, 1) - block := stored.ContentBlocks[0] - assert.Nil(t, block.StreamingMeta) - require.NotNil(t, block.FunctionToolResult) - assert.Equal(t, "call_1", block.FunctionToolResult.CallID) - assert.Equal(t, "execute", block.FunctionToolResult.Name) - require.Len(t, block.FunctionToolResult.Content, 1) - assert.Equal(t, "first\nsecond\n", block.FunctionToolResult.Content[0].Text.Text) -} - -func agenticToolResultMessage(callID, name, text string) *schema.AgenticMessage { - return &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - { - Type: schema.ContentBlockTypeFunctionToolResult, - FunctionToolResult: &schema.FunctionToolResult{ - CallID: callID, - Name: name, - Content: []*schema.FunctionToolResultContentBlock{ - { - Type: schema.FunctionToolResultContentBlockTypeText, - Text: &schema.UserInputText{Text: text}, - }, - }, - }, - }, - }, - } -} - // TestStreamPersistence_GetMessageError_NotEnqueued verifies that a stream // materialization error sets persistErr (failing the turn commit) and does NOT // enqueue a corrupt SessionEvent. diff --git a/permission_middleware_comprehensive_review.md b/permission_middleware_comprehensive_review.md deleted file mode 100644 index 84b35683a..000000000 --- a/permission_middleware_comprehensive_review.md +++ /dev/null @@ -1,105 +0,0 @@ -# Comprehensive Review Summary: Permission Middleware - -## Overview - -- **Iterations**: Stage 1: 1, Stage 2: 1, Stage 3: 1 -- **Scope**: `adk/middlewares/permission`, permission decision observation, and message-path timeline propagation -- **Files modified**: 4 -- **Lines changed**: +212 / -9 before this report -- **Final verification**: `go test ./...` passed - -## Stage 1: Design Review - -### Findings Resolved - -| # | Dimension | Severity | Finding | Fix Applied | Files | -|---|-----------|----------|---------|-------------|-------| -| 1 | API Safety | P1 | Targeted resume approval executed the current invocation arguments instead of the arguments shown in the persisted `AskState`. | Targeted resumes now require `AskState` and approve the saved interrupted arguments by default. | `adk/middlewares/permission/permission.go` | -| 2 | Observability | P1 | `ToolSpanMeta.EvaluatedPermission` was exposed but permission decisions were never recorded by the middleware. | Permission decisions are now stored for allow, deny, ask, approve, reject, and respond paths; tool span start/end events carry the resolved permission decision. | `adk/middlewares/permission/permission.go`, `adk/wrappers.go` | -| 3 | API Expressiveness | P2 | `UpdatedInput string` could not intentionally replace arguments with an empty string. | Added `HasUpdatedInput` flags while preserving existing non-empty `UpdatedInput` behavior for compatibility. | `adk/middlewares/permission/permission.go` | -| 4 | Timeline Propagation | P1 | The `*schema.Message` ReAct exec context did not copy session/timeline flags, suppressing tool-use timeline observations. | Propagated `sessionEvents`, `timelineEvents`, and `internalTimelineEvents` into the message-path exec context. | `adk/chatmodel.go` | - -### Final Scorecard - -| Dimension | Rating | Notes | -|-----------|--------|-------| -| Concept Coherence | 5/5 | Permission checking, resume resolution, and observation are now aligned. | -| API Usability | 4/5 | `HasUpdatedInput` makes empty replacement explicit while remaining backward compatible. | -| Minimum API Surface | 4/5 | One explicit flag was added to each input-update API; no new exported helper was introduced. | -| Backward Compatibility | 5/5 | Existing non-empty `UpdatedInput` behavior remains unchanged. | -| Module Separation | 4/5 | Middleware records decisions; event sender remains responsible for observation emission. | -| Readability | 4/5 | Resume binding is explicit and fail-fast on missing `AskState`. | - -## Stage 2: Attack Review - -### Bugs Fixed - -| # | Severity | Bug | Fix | Test | -|---|----------|-----|-----|------| -| 1 | P1 | An approved permission ask could execute mutated arguments supplied at resume time. | Resume approve uses `AskState.Info.Arguments` unless `HasUpdatedInput` or non-empty `UpdatedInput` explicitly overrides it. | `TestWrapInvokableToolCall_ResumeApproveUsesSavedInterruptedArguments` | -| 2 | P1 | Permission decisions were not observable in tool-use timeline events. | Decisions are recorded before tool-use observation; message-path timeline flags are propagated. | `TestPermissionDecisionAppearsInToolUseTimeline` | -| 3 | P2 | Empty argument replacement was impossible through `UpdatedInput`. | Added explicit `HasUpdatedInput` flags. | `TestWrapInvokableToolCall_AllowWithExplicitEmptyUpdatedInput`, `TestWrapInvokableToolCall_ResumeApproveWithExplicitEmptyUpdatedInput` | - -### Attack Test Results - -- **Total focused regression tests**: 4 -- **Result**: all passing -- **Additional package coverage**: full `./adk/middlewares/permission` package passing - -## Stage 3: Test Audit - -### Improvements Applied - -| # | Category | Change | LOC Impact | -|---|----------|--------|------------| -| 1 | Coverage Gap | Added saved-argument resume binding regression. | +37 LOC | -| 2 | Coverage Gap | Added explicit empty input replacement coverage for allow and resume approve paths. | +45 LOC | -| 3 | Observability Gap | Added end-to-end Runner timeline coverage for `evaluated_permission`. | +70 LOC | -| 4 | Test Utility | Added a small in-package session store and capture tool for permission middleware tests. | +24 LOC | - -### Audit Verdict - -- No duplicate permission tests were introduced. -- Assertions check endpoint arguments and observable timeline fields, not only non-nil outcomes. -- The new session store helper is local to the test and keeps the timeline regression self-contained. - -## Verification Log - -| Command | Result | -|---------|--------| -| `go test ./adk/middlewares/permission -run 'TestWrapInvokableToolCall_(ResumeApproveUsesSavedInterruptedArguments|ResumeApproveWithExplicitEmptyUpdatedInput|AllowWithExplicitEmptyUpdatedInput)|TestPermissionDecisionAppearsInToolUseTimeline' -count=1 -v` | Pass | -| `go test ./adk/middlewares/permission -run 'TestWrapInvokableToolCall_(ResumeApproveUsesSavedInterruptedArguments|ResumeApproveWithExplicitEmptyUpdatedInput|AllowWithExplicitEmptyUpdatedInput|Respond)|TestPermissionGate_(AskThenResumeApprovedWithUpdatedInput|AskThenResumeDenied|AskThenResumeRespond|ResumeRejectDoesNotExecute|InvalidResumeAction|InvalidGateDecision)|TestPermissionDecisionAppearsInToolUseTimeline' -count=1 -v` | Pass, second-pass attack review. | -| `go test ./adk/middlewares/permission -count=1` | Pass | -| `go test ./adk/middlewares/permission -coverprofile=/tmp/eino_permission_cover.out && go tool cover -func=/tmp/eino_permission_cover.out` | Pass, total 85.6%. | -| `go test ./adk -run 'TestWithTimelineEvents_LiveExposure|TestToolPermissionDecisionScopedByToolUseID|TestChatModelAgentRun/.*Tool|TestChatModelAgent_Middleware|TestChatModelAgentToolCallMiddleware' -count=1` | Pass | -| `go test ./...` | Pass | -| `git diff --check` | Pass | -| VS Code diagnostics on edited files | No errors; only pre-existing informational `infertypeargs` hints in untouched wrapper locations. | - -## Second-Pass Review - -| Stage | Result | Notes | -|-------|--------|-------| -| Design Review | Pass | Re-reviewed all 12 dimensions after fixes. No blocker found. The only residual design trade-off is that permission-aware `tool_use` observations are emitted after permission evaluation rather than strictly at raw tool start. | -| Attack Review | Pass | Re-ran adversarial coverage for saved-argument binding, explicit empty updates, invalid resume/gate actions, reject/respond paths, and evaluated permission timeline exposure. | -| Test Audit | Pass | Package coverage is 85.6%; no high-priority duplicate, weak assertion, or coverage-only tests found. | - -## Cumulative File Change List - -| File | Stage(s) | Summary | -|------|----------|---------| -| `adk/middlewares/permission/permission.go` | 1, 2 | Binds targeted resume to persisted ask arguments, records permission decisions, and adds explicit empty-update flags. | -| `adk/middlewares/permission/permission_test.go` | 2, 3 | Adds regressions for saved arguments, explicit empty updates, and evaluated permission timeline exposure. | -| `adk/wrappers.go` | 1, 2 | Emits tool-use observations after wrapped endpoint evaluation so decision metadata is available. | -| `adk/chatmodel.go` | 1, 2 | Propagates session and timeline flags through the message ReAct exec context. | -| `permission_middleware_comprehensive_review.md` | 4 | Records the comprehensive review process, fixes, and verification. | - -## Remaining Items - -| # | Priority | Item | Recommendation | -|---|----------|------|----------------| -| 1 | Low | The default event sender still reports the original input for string-based tool calls, while enhanced paths can report the updated `ToolArgument`. | Consider a future observation payload field for both original and effective tool input if this distinction becomes important. | - -## Verdict - -**APPROVE** after fixes. The confirmed blockers are resolved, regressions cover the failure modes, and the full repository test suite passes. diff --git a/resume_wait_timeout_comprehensive_review.md b/resume_wait_timeout_comprehensive_review.md deleted file mode 100644 index 9e8d05ce9..000000000 --- a/resume_wait_timeout_comprehensive_review.md +++ /dev/null @@ -1,152 +0,0 @@ -# Comprehensive Review: ResumeWaitTimeout (uncommitted changes) - -## Pre-Flight -- Files in scope: `adk/turn_loop.go` (+~250 LOC), `adk/turn_loop_test.go` (+~600 LOC) -- Baseline: `go build ./...` OK; `go test ./adk/ -run TestTurnLoop` OK; new tests pass under `-race`. -- Feature: a new `ResumeWaitTimeout` config that bounds how long a managed business - interrupt (`TurnLoopInterruptWaitsForExplicitResume`) waits for `Resume(...)`. - On expiry the loop persists the runner checkpoint and exits with `*InterruptError`. - Also adds: pre-load `Resume()` buffering, `InterruptContexts` carried in the - checkpoint, and a restored-session watcher. - ---- - -## Stage 1: Design Review - -### Iteration 1 — Scorecard - -| # | Dimension | Rating | Notes | -|---|-----------|--------|-------| -| 1 | Concept coherence | ⭐⭐⭐⭐⭐ | `ResumeWaitTimeout` reads naturally beside `InterruptMode`; "bounded wait → persist + InterruptError" is a clean concept. | -| 2 | API usability | ⭐⭐⭐⭐⭐ | Single `time.Duration` field, zero = unbounded (matches Go idiom). Doc comment states the no-op-unless-managed precondition. | -| 3 | Minimum API surface | ⭐⭐⭐⭐⭐ | Only one new public field. All other machinery is unexported. | -| 4 | Backward compatibility | ⭐⭐⭐⭐ | Zero value preserves old unbounded behavior. New `InterruptContexts` checkpoint field decodes to nil on old data. See F1 (gob risk). | -| 5 | Module separation | ⭐⭐⭐⭐⭐ | All within turn_loop.go; no layer leakage. | -| 6 | Cohesion vs tension | ⭐⭐⭐⭐ | Watcher↔cleanup↔takePendingResume coordination via `timerCancel`/`timedOut` is inherently distributed but well-commented. See F2. | -| 7 | Elegance vs complexity | ⭐⭐⭐⭐ | The pre-load Resume adoption defer + 3-case switch is the most accidental-feeling complexity. Justified but dense. See F3. | -| 8 | Naming | ⭐⭐⭐⭐⭐ | `interruptCtxSnapshot` deliberately distinct from `interrupted` and `l.interruptContexts`; `timerCancel`, `timedOut`, `closeTimerCancelLocked` all clear. | -| 9 | Readability | ⭐⭐⭐⭐ | Watcher double-check race handling is subtle but heavily commented. | -| 10 | Duplication | ⭐⭐⭐ | The arm-watcher block is duplicated verbatim between Phase 2 (run) and `armRestoredManagedWatcherIfNeeded`. See F4. | -| 11 | Public API docs | ⭐⭐⭐⭐⭐ | `ResumeWaitTimeout` doc covers expiry behavior, push-no-reset, zero-default, precondition. | -| 12 | Internal comments | ⭐⭐⭐⭐⭐ | Exceptionally thorough on the concurrency-sensitive paths. | - -### Findings - -- **F1 (gob durability of `InterruptContexts`)** — `nice-to-have/doc`: `turnLoopCheckpoint.InterruptContexts []*InterruptCtx` is gob-encoded. `InterruptCtx.Info` is `any`. If a real interrupt carries a non-gob-registered concrete type in `Info`, `saveTurnLoopCheckpoint` fails → surfaces as `CheckpointErr`. The runner checkpoint (`resumeBytes`) already encodes the same interrupt info, so this is partially redundant. Verdict pending. -- **F2 (watcher commitStop ordering)** — verify in Stage 2 (attack): watcher releases `resumeMu` then calls `commitStop()`; a `Resume()` racing in between. Move to attack tests rather than design fix. -- **F3 (pre-load adoption switch)** — `nice-to-have`: dense but each branch is commented and tested. Counter-argue likely "won't fix". -- **F4 (duplicated arm-watcher block)** — candidate fix: extract a small `armResumeWaitWatcherLocked` helper used by both the Phase 2 site and `armRestoredManagedWatcherIfNeeded`. - -### 1.2 Validate & Counter-Argue - -- **F1**: Real but low-severity. The pre-existing `cancel.go` path (`InterruptError` already carries `[]*InterruptCtx`) and the runner checkpoint already rely on the same `Info any` being serializable in practice, so this introduces no *new* class of failure beyond what resumable interrupts already require. Adding the contexts to the TurnLoop checkpoint is what lets a restored session re-synthesize the error (Test #13). **Verdict: Won't Fix** (consistent with existing serialization assumptions); no code change. -- **F2**: Not a design issue — defer to Stage 2 attack tests. **Verdict: Defer to Stage 2.** -- **F3**: Extracting would scatter the tightly-coupled branch logic across functions and hurt readability; it is exercised by Tests #9/#10/#12. **Verdict: Won't Fix.** -- **F4**: Genuine duplication of a 6-line block with identical guard semantics. Extracting a `*Locked` helper removes the duplication without changing behavior and centralizes the arming invariant. **Verdict: Fix.** - -### 1.3 Fix — F4 - -Extracted `armResumeWaitWatcherLocked(pr) (shouldArm bool)` (caller holds `resumeMu`), used by both the Phase 2 interrupt site and `armRestoredManagedWatcherIfNeeded`. - -### 1.5 Loop decision: all dimensions >= 4/5, single fix applied. Proceed to Stage 2. - ---- - -## Stage 2: Attack Review - -### Iteration 1 — attack tests (`adk/turn_loop_attack_test.go`) - -| # | Severity | Probe | Test | Result | -|---|----------|-------|------|--------| -| 1 | green | Resume vs timeout watcher race (50 iters) | `TestAttack_ResumeRacesTimeoutWatcher` | Always nil OR *InterruptError | -| 2 | green | Resume after timeout committed Stop | `TestAttack_ResumeAfterTimeoutFired` | Returns `ErrTurnLoopStopped` | -| 3 | green | Watcher goroutine leak (Stop/Resume/timeout) | `TestAttack_NoWatcherGoroutineLeak` | No leak | -| 4 | green | Concurrent pre-load Resume (16 goroutines) | `TestAttack_ConcurrentPreLoadResume` | Exactly 1 accepted, 15 in-progress | -| 5 | green | Pre-load Resume vs checkpoint accepted resume | `TestAttack_PreLoadResumeLosesToAcceptedCheckpointResume` | Checkpoint wins | -| 6 | green | Timeout still persists checkpoint | `TestAttack_TimeoutWithNilInterruptContexts` | Checkpoint attempted, no err | -| 7 | green | Stop vs timeout race (40 iters) | `TestAttack_StopAndTimeoutRace` | Deterministic, checkpoint persisted | -| 8 | green | Context cancel during resume wait | `TestAttack_ContextCancelDuringWait` | Prompt exit | - -All probes PASS under `-race`. Zero confirmed bugs. No production fixes required. - -F2 (watcher sets timedOut, releases lock, then Resume races before commitStop) -is resolved by the existing design: cleanup gates the synthesized error on -`pending.timedOut && !pending.resumeSubmitted`, so a Resume that wins sets -resumeSubmitted and suppresses the timeout error. Verified by probe #1. - -### 2.6 Loop decision: zero confirmed bugs. Proceed to Stage 3. - - ---- - -## Stage 3: Test Audit - -### Findings (PR tests in turn_loop_test.go) - -| Priority | Issue | Verdict | Action | -|----------|-------|---------|--------| -| High | Test #11 `NonManagedRestore_PreRunPushStillLegacy` re-invokes an existing test under a new name (no new coverage, misleading name) | Fix | Deleted | -| Medium | drain-then-stop `OnAgentEvents` + fresh `PrepareAgent` duplicated across #8/#9/#10/#12 | Fix | Extracted `freshStopPrepareAgent()` and `drainAndStop` helpers | -| Low | #4 uses `assert.GreaterOrEqual(elapsed, timeout/2)` loose lower bound | Won't Fix | Intentional timing tolerance under -race | -| Low | #6 vs #8 both assert "parked → Resume releases" | Won't Fix | Distinct paths (live vs restore); intentional pair | - -### Coverage (new production code, via TestTurnLoop + attack tests) - -| Function | Coverage | -|----------|----------| -| armResumeWaitWatcherLocked | 100% | -| armRestoredManagedWatcherIfNeeded | 100% | -| closeTimerCancelLocked | 100% | -| cleanup | 100% | -| tryLoadCheckpoint | 93.4% | -| Resume | 90.5% | -| watchResumeWait | 85.7% (remaining = nondeterministic post-lock race re-check) | - -Diff coverage exceeds the 85% target on meaningful paths. Full `go test ./adk/` -passes (33.9s); TurnLoop subset passes under `-race`. - -### 3.5 Loop decision: no High findings remain. Proceed to Stage 4. - ---- - -## Stage 4: Final Summary - -### Overview -- Iterations: Stage 1: 1, Stage 2: 1, Stage 3: 1 (no safety valves triggered) -- Production files modified: 1 (`adk/turn_loop.go`) -- Test files modified: 1 (`adk/turn_loop_test.go`) -- Net change vs baseline of the PR: +1104 / -9 - -### Stage 1 (Design) — change applied -| # | Dimension | Finding | Fix | File | -|---|-----------|---------|-----|------| -| F4 | Duplication | Arm-watcher 6-line block duplicated between Phase 2 and restored-watcher | Extracted `armResumeWaitWatcherLocked(pr) bool` helper; both sites call it | `adk/turn_loop.go` | - -F1/F2/F3 examined and resolved as Won't-Fix / Defer with recorded rationale. - -### Stage 2 (Attack) — bugs found -Zero confirmed bugs. 9 adversarial tests written; all pass under `-race`. -Per user decision, all 9 were merged into `adk/turn_loop_test.go` as durable -regression tests (concurrency hardening section). - -### Stage 3 (Test Audit) — changes applied -| # | Category | Change | LOC | -|---|----------|--------|-----| -| 1 | Semantic value | Deleted Test #11 (re-invoked an existing test under a new name) | -6 | -| 2 | Boilerplate | Extracted `freshStopPrepareAgent()` + `drainAndStop`; applied to #8/#9/#10/#12 | net negative inline | - -### Cumulative file change list -| File | Stage(s) | Summary | -|------|----------|---------| -| `adk/turn_loop.go` | 1 | Added `armResumeWaitWatcherLocked` helper; Phase 2 + restored-watcher now share it. No behavior change. | -| `adk/turn_loop_test.go` | 3 | Deleted noise test; extracted 2 test helpers; merged 9 race-hardening attack tests. | - -### Verification (final) -- `go build ./...` OK -- `gofmt -l` clean on both files -- `go test ./adk/` full suite OK (33.8s) -- TurnLoop + attack subset OK under `-race` -- New-code coverage: all new functions 85–100% - -### Remaining items -None. No safety valves triggered. From dc72a3ea6dc6eb36a592ec4c8e234463efe77f70 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Fri, 12 Jun 2026 14:05:14 +0800 Subject: [PATCH 4/8] feat: adapt agentic message (#1071) --- adk/middlewares/automemory/automemory.go | 205 ++++++++++-------- adk/middlewares/automemory/automemory_test.go | 28 +-- adk/middlewares/automemory/dream/config.go | 8 +- adk/middlewares/automemory/dream/dream.go | 24 +- 4 files changed, 156 insertions(+), 109 deletions(-) diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go index ec3aeb6a2..72d1fb4d4 100644 --- a/adk/middlewares/automemory/automemory.go +++ b/adk/middlewares/automemory/automemory.go @@ -49,17 +49,17 @@ type Config[M adk.MessageType] struct { MemoryBackend Backend - // Model is the default tool-calling model used by topic selection and memory extraction. + // Model is the default model used by topic selection and memory extraction. // Per-read/per-write overrides can be configured in Read.Model / Write.Model. - Model model.ToolCallingChatModel + Model model.BaseModel[M] // Read controls how memories are loaded and injected. // Optional. Defaults to Sync load with topic selection enabled (if Model is set). - Read *ReadConfig + Read *ReadConfig[M] // Write controls post-run memory extraction and persistence. // Optional. Default: disabled. - Write *WriteConfig + Write *WriteConfig[M] // Coordination controls session identity and distributed async extraction coordination. // Optional. Defaults to a local in-process coordinator. @@ -78,11 +78,11 @@ const ( ReadModeAsync ReadMode = "async" ) -type ReadConfig struct { +type ReadConfig[M adk.MessageType] struct { Mode ReadMode // Model is used for topic selection. Defaults to Config.Model. - Model model.ToolCallingChatModel + Model model.BaseModel[M] // Instruction overrides the default auto memory instruction block appended to system prompt. // Optional. @@ -126,11 +126,11 @@ const ( WriteModeSync WriteMode = "sync" ) -type WriteConfig struct { +type WriteConfig[M adk.MessageType] struct { Mode WriteMode // Model is used for memory extraction. Defaults to Config.Model. - Model model.ToolCallingChatModel + Model model.BaseModel[M] // MaxTurns caps the extractor's tool-call loop. MaxTurns int @@ -144,7 +144,7 @@ type WriteConfig struct { // // If nil, automemory uses the default drain behavior: ignore all events and // return the first ev.Err encountered (if any). - HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error + HandleExtractionIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.TypedAgentEvent[M]]) error } type middleware[M adk.MessageType] struct { @@ -154,8 +154,8 @@ type middleware[M adk.MessageType] struct { resolvedMemoryDirectory string - topicSelectionModel model.ToolCallingChatModel - extractionHandler adk.ChatModelAgentMiddleware + topicSelectionModel model.BaseModel[M] + extractionHandler adk.TypedChatModelAgentMiddleware[M] topicSelectionTool *schema.ToolInfo coordination *CoordinationConfig[M] } @@ -201,7 +201,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err) } if cfg.Read == nil { - cfg.Read = &ReadConfig{} + cfg.Read = &ReadConfig[M]{} } applyReadDefaults(cfg) @@ -214,11 +214,10 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh m.topicSelectionTool = topicSelectionToolInfo() if cfg.Read.TopicSelection != nil && cfg.Read.Model != nil { - bound, err := cfg.Read.Model.WithTools([]*schema.ToolInfo{m.topicSelectionTool}) - if err != nil { - return nil, fmt.Errorf("auto memory topic selection model init failed: %w", err) + m.topicSelectionModel = &modelWithTools[M]{ + base: cfg.Read.Model, + tools: []*schema.ToolInfo{m.topicSelectionTool}, } - m.topicSelectionModel = bound } if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil { @@ -230,7 +229,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh if err != nil { return nil, err } - fileSystemMiddleware, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{ + fileSystemMiddleware, err := fsmw.NewTyped[M](ctx, &fsmw.MiddlewareConfig{ Backend: writeFSBackend, LsToolConfig: &fsmw.ToolConfig{Disable: true}, GrepToolConfig: &fsmw.ToolConfig{Disable: true}, @@ -412,7 +411,7 @@ func applyReadDefaults[M adk.MessageType](cfg *Config[M]) { } if cfg.Write == nil { - cfg.Write = &WriteConfig{Mode: WriteModeDisabled} + cfg.Write = &WriteConfig[M]{Mode: WriteModeDisabled} } if cfg.Write.Mode == "" { cfg.Write.Mode = WriteModeDisabled @@ -764,11 +763,11 @@ func (m *middleware[M]) selectTopicCandidates( toolInfo := topicSelectionToolInfo() resp, err := m.topicSelectionModel.Generate( ctx, - []*schema.Message{ - schema.SystemMessage(getTopicSelectionSystemPrompt()), - schema.UserMessage(userMsg), + []M{ + makeSystemMsg[M](getTopicSelectionSystemPrompt()), + makeUserMsg[M](userMsg), }, - model.WithToolChoice(schema.ToolChoiceForced, toolInfo.Name), + makeToolChoiceForced[M](toolInfo.Name), ) if err != nil { return nil, err @@ -864,11 +863,12 @@ func topicSelectionToolInfo() *schema.ToolInfo { } } -func parseTopicSelectionFromToolCall(msg *schema.Message, valid map[string]struct{}) ([]string, error) { - if msg == nil || len(msg.ToolCalls) == 0 { +func parseTopicSelectionFromToolCall[M adk.MessageType](msg M, valid map[string]struct{}) ([]string, error) { + toolCalls := messageToolCalls(msg) + if len(toolCalls) == 0 { return nil, fmt.Errorf("no tool calls") } - tc := msg.ToolCalls[0] + tc := toolCalls[0] if tc.Function.Name != topicSelectionToolName { return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name) } @@ -1002,6 +1002,31 @@ func setMsgExtra[M adk.MessageType](msg M, key string, value any) { } } +func copyMsgExtra[M adk.MessageType](dst, src M) { + srcExtra := getMsgExtra(src) + if len(srcExtra) == 0 { + return + } + switch d := any(dst).(type) { + case *schema.Message: + if d.Extra == nil { + d.Extra = make(map[string]any, len(srcExtra)) + } + for k, v := range srcExtra { + d.Extra[k] = v + } + case *schema.AgenticMessage: + if d.Extra == nil { + d.Extra = make(map[string]any, len(srcExtra)) + } + for k, v := range srcExtra { + d.Extra[k] = v + } + default: + panic("unreachable") + } +} + func makeUserMsg[M adk.MessageType](text string) M { var zero M switch any(zero).(type) { @@ -1014,6 +1039,35 @@ func makeUserMsg[M adk.MessageType](text string) M { } } +func makeSystemMsg[M adk.MessageType](text string) M { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any(schema.SystemMessage(text)).(M) + case *schema.AgenticMessage: + return any(schema.SystemAgenticMessage(text)).(M) + default: + panic("unreachable") + } +} + +func makeToolChoiceForced[M adk.MessageType](name string) model.Option { + var zero M + switch any(zero).(type) { + case *schema.Message: + return model.WithToolChoice(schema.ToolChoiceForced, name) + case *schema.AgenticMessage: + return model.WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: schema.ToolChoiceForced, + Forced: &schema.AgenticForcedToolChoice{ + Tools: []*schema.AllowedTool{{FunctionName: name}}, + }, + }) + default: + panic("unreachable") + } +} + func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall { switch m := any(msg).(type) { case *schema.Message: @@ -1069,40 +1123,6 @@ func messageToolNames[M adk.MessageType](msg M) []string { } } -func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message { - out := make([]adk.Message, 0, len(msgs)) - for _, msg := range msgs { - if projected := projectMessageToSchema(msg); projected != nil { - out = append(out, projected) - } - } - return out -} - -func projectMessageToSchema[M adk.MessageType](msg M) adk.Message { - switch m := any(msg).(type) { - case *schema.Message: - return m - case *schema.AgenticMessage: - if m == nil { - return nil - } - text := m.String() - switch m.Role { - case schema.AgenticRoleTypeSystem: - return schema.SystemMessage(text) - case schema.AgenticRoleTypeAssistant: - return schema.AssistantMessage(text, messageToolCalls(msg)) - case schema.AgenticRoleTypeUser: - return schema.UserMessage(text) - default: - return schema.UserMessage(text) - } - default: - panic("unreachable") - } -} - func alreadyInjected[M adk.MessageType](msgs []M) bool { for _, m := range msgs { if isMemoryMessage(m) { @@ -1464,7 +1484,7 @@ func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int return countModelVisibleMessages(msgs[cursor:]) } -func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.ChatModelAgent, error) { +func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.TypedChatModelAgent[M], error) { if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Model == nil { return nil, fmt.Errorf("auto memory extraction agent init failed: missing write model") } @@ -1472,13 +1492,12 @@ func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*sch return nil, fmt.Errorf("auto memory extraction agent init failed: missing extraction handler") } - agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ - Name: "automemory_extractor", - Description: "Internal auto memory extraction subagent", - Model: m.cfg.Write.Model, - Handlers: []adk.ChatModelAgentMiddleware{ + agent, err := adk.NewTypedChatModelAgent[M](ctx, &adk.TypedChatModelAgentConfig[M]{ + Name: "automemory_extractor", + Model: m.cfg.Write.Model, + Handlers: []adk.TypedChatModelAgentMiddleware[M]{ m.extractionHandler, // fs middleware - &toolInfoOverrideMiddleware{toolInfos: toolInfos}, // tool info override, for prefix cache + &toolInfoOverrideMiddleware[M]{toolInfos: toolInfos}, // tool info override, for prefix cache }, ToolsConfig: adk.ToolsConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{ @@ -1506,13 +1525,13 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [ } newMessageCount := countModelVisibleMessagesSince(snapshot, cursor) userPrompt := buildExtractAutoOnlyPrompt(m.resolvedMemoryDirectory, newMessageCount, manifest, m.cfg.Write.SkipIndex) - msgs := append(projectMessagesToSchema(snapshot), schema.UserMessage(userPrompt)) + msgs := append(append([]M{}, snapshot...), makeUserMsg[M](userPrompt)) extractionAgent, err := m.newExtractionAgent(ctx, toolInfos) if err != nil { return err } - iter := extractionAgent.Run(ctx, &adk.AgentInput{ + iter := extractionAgent.Run(ctx, &adk.TypedAgentInput[M]{ Messages: msgs, EnableStreaming: true, }) @@ -1583,30 +1602,46 @@ func parseRFC3339NanoBestEffort(s string) time.Time { return time.Time{} } -type toolInfoOverrideMiddleware struct { - adk.BaseChatModelAgentMiddleware +type toolInfoOverrideMiddleware[M adk.MessageType] struct { + adk.TypedBaseChatModelAgentMiddleware[M] - once sync.Once toolInfos []*schema.ToolInfo } -func (t *toolInfoOverrideMiddleware) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[*schema.Message], _ *adk.TypedModelContext[*schema.Message]) ( - context.Context, *adk.TypedChatModelAgentState[*schema.Message], error) { +func (t *toolInfoOverrideMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M], _ *adk.TypedModelContext[M]) ( + context.Context, *adk.TypedChatModelAgentState[M], error) { - t.once.Do(func() { - toolNameMapping := make(map[string]struct{}, len(t.toolInfos)) - for _, tool := range t.toolInfos { - toolNameMapping[tool.Name] = struct{}{} - } + toolNameMapping := make(map[string]struct{}, len(t.toolInfos)) + for _, tool := range t.toolInfos { + toolNameMapping[tool.Name] = struct{}{} + } - overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...) - for _, tool := range state.ToolInfos { - if _, ok := toolNameMapping[tool.Name]; !ok { // add fs tools if not exists - overrideTools = append(overrideTools, tool) - } + overrideTools := append([]*schema.ToolInfo{}, t.toolInfos...) + for _, tool := range state.ToolInfos { + if _, ok := toolNameMapping[tool.Name]; !ok { + overrideTools = append(overrideTools, tool) } - state.ToolInfos = overrideTools - }) + } + state.ToolInfos = overrideTools return ctx, state, nil } + +type modelWithTools[M adk.MessageType] struct { + base model.BaseModel[M] + tools []*schema.ToolInfo +} + +func (m *modelWithTools[M]) Generate(ctx context.Context, input []M, opts ...model.Option) (M, error) { + newOpts := make([]model.Option, len(opts)+1) + copy(newOpts, opts) + newOpts[len(opts)] = model.WithTools(m.tools) + return m.base.Generate(ctx, input, newOpts...) +} + +func (m *modelWithTools[M]) Stream(ctx context.Context, input []M, opts ...model.Option) (*schema.StreamReader[M], error) { + newOpts := make([]model.Option, len(opts)+1) + copy(newOpts, opts) + newOpts[len(opts)] = model.WithTools(m.tools) + return m.base.Stream(ctx, input, newOpts...) +} diff --git a/adk/middlewares/automemory/automemory_test.go b/adk/middlewares/automemory/automemory_test.go index e128d75ec..1cbc5dd71 100644 --- a/adk/middlewares/automemory/automemory_test.go +++ b/adk/middlewares/automemory/automemory_test.go @@ -119,8 +119,8 @@ func TestNew_DoesNotMutateConfig(t *testing.T) { MemoryDirectory: "/mem", MemoryBackend: b, Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, - Read: &ReadConfig{}, - Write: &WriteConfig{}, + Read: &ReadConfig[*schema.Message]{}, + Write: &WriteConfig[*schema.Message]{}, Coordination: &CoordinationConfig[*schema.Message]{}, } _, err = New(ctx, cfgExplicitNested) @@ -182,7 +182,7 @@ func TestMiddleware_TopicSelection_AsyncInjectsInBeforeModel(t *testing.T) { MemoryDirectory: "/mem", MemoryBackend: b, Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, - Read: &ReadConfig{Mode: ReadModeAsync}, + Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync}, }) require.NoError(t, err) @@ -387,7 +387,7 @@ func TestMiddleware_TopicSelection_SmallCandidateSetBypassesModel(t *testing.T) MemoryDirectory: "/mem", MemoryBackend: b, Model: &panicModel{}, - Read: &ReadConfig{ + Read: &ReadConfig[*schema.Message]{ Mode: ReadModeSync, TopicSelection: &TopicSelectionConfig{ TopK: 5, @@ -419,7 +419,7 @@ func TestMiddleware_AfterAgent_SyncExtractionWritesMemoryFiles(t *testing.T) { mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, }, @@ -481,7 +481,7 @@ func TestMiddleware_AfterAgent_SyncExtraction_IteratorHandlerCanDrain(t *testing mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, HandleExtractionIterator: func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error { @@ -534,7 +534,7 @@ func TestMiddleware_AfterAgent_SkipsExtractionWhenMainAgentAlreadyWroteMemory(t mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, }, @@ -590,7 +590,7 @@ func TestMiddleware_AfterAgent_AsyncExtractionKeepsLatestPendingSnapshot(t *test mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeAsync, Model: extModel, }, @@ -749,7 +749,7 @@ func TestMiddleware_TopicSelection_ToolCallParsingAndFiltering(t *testing.T) { MemoryDirectory: "/mem", MemoryBackend: b, Model: selModel, - Read: &ReadConfig{ + Read: &ReadConfig[*schema.Message]{ Mode: ReadModeSync, TopicSelection: &TopicSelectionConfig{ TopK: 1, @@ -782,7 +782,7 @@ func TestMiddleware_TopicSelection_AsyncProtectsMemoryMessageFromMutation(t *tes MemoryDirectory: "/mem", MemoryBackend: b, Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`}, - Read: &ReadConfig{Mode: ReadModeAsync}, + Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync}, }) require.NoError(t, err) @@ -824,7 +824,7 @@ func TestMiddleware_AfterAgent_SyncExtraction_SkipIndexPrompt(t *testing.T) { mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, SkipIndex: true, @@ -862,7 +862,7 @@ func TestMiddleware_AfterAgent_SyncExtraction_ChinesePrompt(t *testing.T) { mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, }, @@ -903,7 +903,7 @@ func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: ".", MemoryBackend: NewLocalBackend(), - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeSync, Model: extModel, }, @@ -965,7 +965,7 @@ func TestMiddleware_AfterAgent_AsyncSetsPendingSnapshotWhenLockHeld(t *testing.T mwI, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: "/mem", MemoryBackend: b, - Write: &WriteConfig{ + Write: &WriteConfig[*schema.Message]{ Mode: WriteModeAsync, Model: extModel, }, diff --git a/adk/middlewares/automemory/dream/config.go b/adk/middlewares/automemory/dream/config.go index 193270363..a74b1c6c0 100644 --- a/adk/middlewares/automemory/dream/config.go +++ b/adk/middlewares/automemory/dream/config.go @@ -42,7 +42,7 @@ type OnError func(ctx context.Context, stage string, err error) // HandleIterator handles the dream sub-agent event stream. // Optional. Nil means dream drains the iterator itself. -type HandleIterator func(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) error +type HandleIterator[M adk.MessageType] func(ctx context.Context, iter *adk.AsyncIterator[*adk.TypedAgentEvent[M]]) error // Config configures auto dream for both `New(...)` and `Run(...)`. type Config[M adk.MessageType] struct { @@ -54,9 +54,9 @@ type Config[M adk.MessageType] struct { // Required. MemoryBackend automemory.Backend - // Model is the tool-calling model used by the internal dream agent. + // Model is the model used by the internal dream agent. // Required. - Model model.ToolCallingChatModel + Model model.BaseModel[M] // SessionIDFunc resolves the current session ID. // Optional. Default: a generated session-scoped ID. @@ -84,7 +84,7 @@ type Config[M adk.MessageType] struct { // HandleIterator overrides iterator consumption. // Optional. Default: nil. - HandleIterator HandleIterator + HandleIterator HandleIterator[M] } // ScheduleConfig controls middleware-triggered runs. diff --git a/adk/middlewares/automemory/dream/dream.go b/adk/middlewares/automemory/dream/dream.go index e05645a72..afc4dcb17 100644 --- a/adk/middlewares/automemory/dream/dream.go +++ b/adk/middlewares/automemory/dream/dream.go @@ -41,7 +41,7 @@ type middleware[M adk.MessageType] struct { cfg *Config[M] resolvedMemoryDir string - fsHandler adk.ChatModelAgentMiddleware + fsHandler adk.TypedChatModelAgentMiddleware[M] sessionSearchTool tool.BaseTool now func() time.Time } @@ -99,7 +99,7 @@ func newMiddleware[M adk.MessageType](ctx context.Context, cfg *Config[M]) (*mid if err != nil { return nil, err } - fsHandler, err := fsmw.New(ctx, &fsmw.MiddlewareConfig{ + fsHandler, err := fsmw.NewTyped[M](ctx, &fsmw.MiddlewareConfig{ Backend: writeFSBackend, GrepToolConfig: &fsmw.ToolConfig{Disable: true}, }) @@ -215,7 +215,7 @@ func (m *middleware[M]) runDream(ctx context.Context, sessionID string, touchedS SessionID: sessionID, SearchSessionIDs: append([]string(nil), searchSessionIDs...), }) - iter := agent.Run(runCtx, &adk.AgentInput{Messages: []adk.Message{schema.UserMessage(prompt)}}) + iter := agent.Run(runCtx, &adk.TypedAgentInput[M]{Messages: []M{makeUserMsg[M](prompt)}}) if m.cfg.HandleIterator != nil { return m.cfg.HandleIterator(runCtx, iter) } @@ -231,16 +231,16 @@ func (m *middleware[M]) runDream(ctx context.Context, sessionID string, touchedS return nil } -func (m *middleware[M]) newDreamAgent(ctx context.Context) (*adk.ChatModelAgent, error) { +func (m *middleware[M]) newDreamAgent(ctx context.Context) (*adk.TypedChatModelAgent[M], error) { tools := make([]tool.BaseTool, 0, 1) if m.sessionSearchTool != nil { tools = append(tools, m.sessionSearchTool) } - agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + agent, err := adk.NewTypedChatModelAgent[M](ctx, &adk.TypedChatModelAgentConfig[M]{ Name: "automemory_dream", Description: "Internal auto dream consolidation agent", Model: m.cfg.Model, - Handlers: []adk.ChatModelAgentMiddleware{m.fsHandler}, + Handlers: []adk.TypedChatModelAgentMiddleware[M]{m.fsHandler}, ToolsConfig: adk.ToolsConfig{ToolsNodeConfig: compose.ToolsNodeConfig{Tools: tools}}, MaxIterations: 12, }) @@ -256,3 +256,15 @@ func (m *middleware[M]) onErr(ctx context.Context, stage string, err error) { } m.cfg.OnError(ctx, stage, err) } + +func makeUserMsg[M adk.MessageType](text string) M { + var zero M + switch any(zero).(type) { + case *schema.Message: + return any(schema.UserMessage(text)).(M) + case *schema.AgenticMessage: + return any(schema.UserAgenticMessage(text)).(M) + default: + panic("unreachable") + } +} From b856230aef62155218432003c93fdbd74b0f97dc Mon Sep 17 00:00:00 2001 From: mrh997 Date: Fri, 12 Jun 2026 18:26:37 +0800 Subject: [PATCH 5/8] feat: copy extra before set in automemory (#1073) --- adk/middlewares/automemory/automemory.go | 46 ++++++------------------ 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go index 72d1fb4d4..f8c023f95 100644 --- a/adk/middlewares/automemory/automemory.go +++ b/adk/middlewares/automemory/automemory.go @@ -985,43 +985,19 @@ func getMsgExtra[M adk.MessageType](msg M) map[string]any { } } -func setMsgExtra[M adk.MessageType](msg M, key string, value any) { - switch m := any(msg).(type) { - case *schema.Message: - if m.Extra == nil { - m.Extra = map[string]any{} - } - m.Extra[key] = value - case *schema.AgenticMessage: - if m.Extra == nil { - m.Extra = map[string]any{} - } - m.Extra[key] = value - default: - panic("unreachable") +func copyAndSetMsgExtra[M adk.MessageType](msg M, key string, value any) { + existing := getMsgExtra(msg) + newExtra := make(map[string]any, len(existing)+1) + for k, v := range existing { + newExtra[k] = v } -} + newExtra[key] = value -func copyMsgExtra[M adk.MessageType](dst, src M) { - srcExtra := getMsgExtra(src) - if len(srcExtra) == 0 { - return - } - switch d := any(dst).(type) { + switch m := any(msg).(type) { case *schema.Message: - if d.Extra == nil { - d.Extra = make(map[string]any, len(srcExtra)) - } - for k, v := range srcExtra { - d.Extra[k] = v - } + m.Extra = newExtra case *schema.AgenticMessage: - if d.Extra == nil { - d.Extra = make(map[string]any, len(srcExtra)) - } - for k, v := range srcExtra { - d.Extra[k] = v - } + m.Extra = newExtra default: panic("unreachable") } @@ -1151,7 +1127,7 @@ func hasInstructionInjected(instruction string) bool { func newMemoryMessage[M adk.MessageType](content string) M { msg := makeUserMsg[M](content) - setMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory"}) + copyAndSetMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory"}) return msg } @@ -1348,7 +1324,7 @@ func markWriteCursor[M adk.MessageType](state *adk.TypedChatModelAgentState[M], return state } - setMsgExtra(last, memoryExtraKey, &memoryExtra{ + copyAndSetMsgExtra(last, memoryExtraKey, &memoryExtra{ Type: "write_cursor", Cursor: cursor, UpdatedAt: time.Now().Format(time.RFC3339Nano), From 3e1d2b08a55ce665b62ec968236edf9e70b8c5a9 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Sat, 13 Jun 2026 01:01:54 +0800 Subject: [PATCH 6/8] fix(adk): auto memory transfer context (#1076) --- adk/middlewares/automemory/automemory.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go index f8c023f95..f8ad6adad 100644 --- a/adk/middlewares/automemory/automemory.go +++ b/adk/middlewares/automemory/automemory.go @@ -1274,7 +1274,7 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode } return ctx, nil } - go m.runExtractionDrain(context.Background(), sessionID, unlock, snap) + go m.runExtractionDrain(ctx, sessionID, unlock, snap) return ctx, nil default: From 0affdf6f9965bad78a79ff35482bb6f13c460704 Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Wed, 25 Mar 2026 15:36:52 +0800 Subject: [PATCH 7/8] feat(adk): auto memory middleware --- adk/middlewares/automemory/automemory.go | 34 ++++++ .../automemory/internal/backend.go | 2 +- uncommitted_comprehensive_review.md | 106 ------------------ 3 files changed, 35 insertions(+), 107 deletions(-) delete mode 100644 uncommitted_comprehensive_review.md diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go index f8ad6adad..6899c458b 100644 --- a/adk/middlewares/automemory/automemory.go +++ b/adk/middlewares/automemory/automemory.go @@ -1099,6 +1099,40 @@ func messageToolNames[M adk.MessageType](msg M) []string { } } +func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message { + out := make([]adk.Message, 0, len(msgs)) + for _, msg := range msgs { + if projected := projectMessageToSchema(msg); projected != nil { + out = append(out, projected) + } + } + return out +} + +func projectMessageToSchema[M adk.MessageType](msg M) adk.Message { + switch m := any(msg).(type) { + case *schema.Message: + return m + case *schema.AgenticMessage: + if m == nil { + return nil + } + text := m.String() + switch m.Role { + case schema.AgenticRoleTypeSystem: + return schema.SystemMessage(text) + case schema.AgenticRoleTypeAssistant: + return schema.AssistantMessage(text, messageToolCalls(msg)) + case schema.AgenticRoleTypeUser: + return schema.UserMessage(text) + default: + return schema.UserMessage(text) + } + default: + panic("unreachable") + } +} + func alreadyInjected[M adk.MessageType](msgs []M) bool { for _, m := range msgs { if isMemoryMessage(m) { diff --git a/adk/middlewares/automemory/internal/backend.go b/adk/middlewares/automemory/internal/backend.go index cd44c29cd..6f866764e 100644 --- a/adk/middlewares/automemory/internal/backend.go +++ b/adk/middlewares/automemory/internal/backend.go @@ -146,7 +146,7 @@ func (f *FSBackend) Read(ctx context.Context, req *adkfs.ReadRequest) (*adkfs.Fi } return nil, err } - return (*adkfs.FileContent)(content), nil + return content, nil } func (f *FSBackend) Write(ctx context.Context, req *adkfs.WriteRequest) error { diff --git a/uncommitted_comprehensive_review.md b/uncommitted_comprehensive_review.md deleted file mode 100644 index 30a410d06..000000000 --- a/uncommitted_comprehensive_review.md +++ /dev/null @@ -1,106 +0,0 @@ -# Comprehensive Review Summary: Uncommitted Changes - -## Overview - -- Total iterations: Stage 1: 1, Stage 2: 1, Stage 3: 1 -- Files modified by review: 2 -- Current diff size: +95 / -0 -- Baseline before review: `go test ./...` passed -- Final verification: `go test ./...` passed - -## Scope - -| File | Role | -| --- | --- | -| `adk/middlewares/permission/permission.go` | Permission gate resume routing | -| `adk/middlewares/permission/permission_test.go` | Resume pass-through and attack coverage | - -## Stage 1: Design Review - -### Scorecard - -| Dimension | Rating | Notes | -| --- | --- | --- | -| Concept Coherence | 4/5 | Passing through non-permission interrupt state is consistent with tool middleware acting as a conduit. | -| API Usability | 5/5 | No public API changes. | -| Minimum API Surface | 5/5 | No new exported types/functions. | -| Backward Compatibility | 4/5 | Permission `AskState` paths remain fail-closed; business interrupt paths now resume correctly. | -| Module Separation | 5/5 | Logic stays inside the permission middleware wrapper. | -| Cohesion vs. Tension | 4/5 | The middleware must distinguish its own persisted state from underlying tool state. | -| Elegance vs. Complexity | 4/5 | A single guard keeps the common pass-through path simple. | -| Naming | 5/5 | New helper/test names describe targeted and non-target resume semantics. | -| Readability | 4/5 | The critical branch is concise; tests document the scenario. | -| Duplication | 4/5 | Resume-context helpers share `genericResumeContext`; one extra non-target helper is acceptable. | -| Public API Documentation | N/A | No public API additions. | -| Internal Comments | 4/5 | Existing tests make intent explicit; no extra production comment needed. | - -### Finding Resolved - -| # | Dimension | Finding | Verdict | Fix | -| --- | --- | --- | --- | --- | -| 1 | Concept Coherence / Backward Compatibility | The original targeted-only pass-through let targeted business interrupts resume, but non-target replay of the same business interrupt still failed with `missing AskState` before the underlying tool could re-interrupt. | Fix | Generalized the pass-through to any resumed interrupt whose saved state is not a permission `AskState`. | - -### Validation and Counter-Argument - -| Finding | Validation | Counter-Argument | Decision | -| --- | --- | --- | --- | -| Business interrupt non-target replay fails | `TestAttack_BusinessInterruptNonTargetReplayPassesThrough` reproduced the failure before the production fix. | Passing through a non-`AskState` could theoretically hide corrupted permission state, but the existing change already accepts non-`AskState` for targeted business resumes. Consistent pass-through is required by the ADK explicit targeted resume contract. | Fix | - -## Stage 2: Attack Review - -### Attack Tests - -| Test | Category | Result | Notes | -| --- | --- | --- | --- | -| `TestWrapInvokableToolCall_PassesThroughBusinessInterruptResume` | Feature interaction | Passed | Verifies permission approval can be followed by an underlying business interrupt and targeted business resume. | -| `TestAttack_BusinessInterruptNonTargetReplayPassesThrough` | Conflict detection / feature interaction | Failed before fix, passed after fix | Verifies non-target replay preserves the underlying business interrupt instead of returning a permission `AskState` error. | -| `TestAttack_InvalidRespondDoesNotPersistDecisionEvent` | Validation gap | Passed | Existing attack coverage remains green. | - -### Bug Fixed - -| # | Severity | Bug | Fix | Test | -| --- | --- | --- | --- | --- | -| 1 | High | A permission-wrapped tool with non-permission interrupt state could not participate in sibling/non-target replay because the permission gate treated missing `AskState` as a permission error. | In `permissionGate`, return an allowed pass-through result whenever `wasInterrupted && !hasState`, allowing the underlying tool to inspect its own state and target status. | `TestAttack_BusinessInterruptNonTargetReplayPassesThrough` | - -## Stage 3: Test Audit - -### Audit Result - -| Category | Result | -| --- | --- | -| Duplicates | No true duplicates found in the changed tests. | -| Assertion Quality | Assertions check target flags, resume payloads, preserved state, error type, and call counts. | -| Boilerplate | `genericResumeContext` and `nonTargetResumeContext` keep setup explicit without over-abstracting. | -| Logical Grouping | New tests are placed near existing invokable resume tests. | -| Semantic Value | Both added tests cover distinct targeted and non-target business interrupt semantics. | -| Coverage | Package coverage is 81.9%; the changed production branch is directly covered by the new tests. | - -### Coverage - -- Command: `go test -coverprofile=/tmp/eino_permission_cover.out ./adk/middlewares/permission && go tool cover -func=/tmp/eino_permission_cover.out` -- Package coverage: 81.9% statements -- `permissionGate`: 72.7% statements -- Diff coverage: covered for the new `wasInterrupted && !hasState` branch -- Remaining package-level gap: existing functions such as `publicInfo` still report low coverage, but they are outside this review's diff. - -## Cumulative File Change List - -| File | Stage(s) | Summary | -| --- | --- | --- | -| `adk/middlewares/permission/permission.go` | 1, 2 | Generalized resumed non-`AskState` pass-through so underlying business interrupts handle both targeted and non-target replay. | -| `adk/middlewares/permission/permission_test.go` | 2, 3 | Added targeted business interrupt resume coverage, non-target attack coverage, and reusable resume-context helpers. | - -## Verification Commands - -| Command | Result | -| --- | --- | -| `go test ./...` | Passed before review | -| `go test ./adk/middlewares/permission -run 'TestAttack_BusinessInterruptNonTargetReplayPassesThrough|TestWrapInvokableToolCall_PassesThroughBusinessInterruptResume' -v -count=1` | Passed | -| `go test ./adk/middlewares/permission -run 'TestAttack_|TestWrapInvokableToolCall_PassesThroughBusinessInterruptResume' -v -count=1` | Passed | -| `go test -coverprofile=/tmp/eino_permission_cover.out ./adk/middlewares/permission && go tool cover -func=/tmp/eino_permission_cover.out` | Passed, 81.9% package coverage | -| `go test ./...` | Passed after review | - -## Remaining Items - -- No unresolved blockers. -- Package-level coverage remains below the skill's 85% target, but the uncovered regions are pre-existing and outside the uncommitted diff; the changed branch is covered. From dc8a96b61c6790ff7aa0ba3db6df27e329c94d2c Mon Sep 17 00:00:00 2001 From: xuzhaonan Date: Mon, 15 Jun 2026 11:32:17 +0800 Subject: [PATCH 8/8] feat(adk): updates auto memory --- adk/middlewares/automemory/automemory.go | 74 ++++++++-------- adk/middlewares/automemory/automemory_test.go | 84 +++++++++++++++++++ adk/middlewares/automemory/backend.go | 2 +- adk/middlewares/automemory/dream/config.go | 6 +- adk/middlewares/automemory/dream/dream.go | 7 +- .../automemory/dream/dream_test.go | 22 +---- adk/middlewares/automemory/dream/session.go | 20 +---- .../automemory/dream/session_test.go | 23 ++--- 8 files changed, 145 insertions(+), 93 deletions(-) diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go index 6899c458b..f2dd19f9b 100644 --- a/adk/middlewares/automemory/automemory.go +++ b/adk/middlewares/automemory/automemory.go @@ -153,6 +153,7 @@ type middleware[M adk.MessageType] struct { cfg *Config[M] resolvedMemoryDirectory string + boundedMemoryBackend Backend topicSelectionModel model.BaseModel[M] extractionHandler adk.TypedChatModelAgentMiddleware[M] @@ -178,11 +179,8 @@ const ( ) type memoryExtra struct { - Type string - Cursor int - UpdatedAt string - Visibility string - SchemaVer int + Type string + Cursor int } // New creates an automemory middleware from the provided configuration. @@ -200,6 +198,14 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh if err != nil { return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err) } + boundedMemoryBackend, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{ + BaseDir: resolvedMemoryDir, + NotFoundAsContent: true, + ErrorPrefix: "memory backend", + }) + if err != nil { + return nil, err + } if cfg.Read == nil { cfg.Read = &ReadConfig[M]{} } @@ -209,6 +215,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh TypedBaseChatModelAgentMiddleware: adk.TypedBaseChatModelAgentMiddleware[M]{}, cfg: cfg, resolvedMemoryDirectory: resolvedMemoryDir, + boundedMemoryBackend: boundedMemoryBackend, coordination: cfg.Coordination, } @@ -221,11 +228,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh } if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil { - writeFSBackend, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{ - BaseDir: resolvedMemoryDir, - NotFoundAsContent: true, - ErrorPrefix: "fs backend", - }) + writeFSBackend, err := newFSBackend(cfg.MemoryBackend, resolvedMemoryDir) if err != nil { return nil, err } @@ -486,14 +489,18 @@ func (m *middleware[M]) injectIndexIntoInstruction(ctx context.Context, baseInst memDesc = s } - indexPath := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + indexPath := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName) indexContent := "" totalLines := 0 - fc, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: indexPath}) + fc, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: indexPath}) if err == nil && fc != nil { - indexContent = fc.Content - totalLines = strings.Count(indexContent, "\n") + 1 + if isFileNotFoundContent(fc.Content) { + indexContent = "" + } else { + indexContent = fc.Content + totalLines = strings.Count(indexContent, "\n") + 1 + } } else { // Missing index is not fatal; keep empty. indexContent = "" @@ -555,6 +562,10 @@ func linesOrSizeTrunc(content string, lines, size int) (newContent string, reaso return } +func isFileNotFoundContent(content string) bool { + return strings.HasPrefix(strings.TrimSpace(content), "File not found: ") +} + func (m *middleware[M]) onErr(ctx context.Context, stage string, err error) { if err == nil { return @@ -657,15 +668,15 @@ func (m *middleware[M]) listTopicCandidates(ctx context.Context) (map[string]top } func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]FileInfo, error) { - files, err := m.cfg.MemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ + files, err := m.boundedMemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: m.cfg.Read.TopicSelection.CandidateGlob, - Path: m.cfg.MemoryDirectory, + Path: m.resolvedMemoryDirectory, }) if err != nil || len(files) == 0 { return nil, err } - indexAbs := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + indexAbs := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName) candidates := make([]FileInfo, 0, len(files)) for _, fi := range files { if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) { @@ -687,17 +698,17 @@ func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]FileInf } func (m *middleware[M]) buildTopicCandidateBundle(ctx context.Context, fi FileInfo) (topicCandidateBundle, string, bool) { - rel, relErr := filepath.Rel(m.cfg.MemoryDirectory, fi.Path) + rel, relErr := filepath.Rel(m.resolvedMemoryDirectory, fi.Path) if relErr != nil { rel = filepath.Base(fi.Path) } rel = filepath.ToSlash(rel) - preview, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{ + preview, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{ FilePath: fi.Path, Limit: m.cfg.Read.TopicSelection.CandidatePreviewLines, }) - if err != nil || preview == nil { + if err != nil || preview == nil || isFileNotFoundContent(preview.Content) { return topicCandidateBundle{}, "", false } @@ -824,8 +835,8 @@ func (m *middleware[M]) renderTopicMemories( } func (m *middleware[M]) renderTopicMemory(ctx context.Context, bundle topicCandidateBundle) (string, bool) { - full, err := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: bundle.AbsPath}) - if err != nil || full == nil { + full, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: bundle.AbsPath}) + if err != nil || full == nil || isFileNotFoundContent(full.Content) { return "", false } @@ -1359,11 +1370,8 @@ func markWriteCursor[M adk.MessageType](state *adk.TypedChatModelAgentState[M], } copyAndSetMsgExtra(last, memoryExtraKey, &memoryExtra{ - Type: "write_cursor", - Cursor: cursor, - UpdatedAt: time.Now().Format(time.RFC3339Nano), - Visibility: "internal", - SchemaVer: 1, + Type: "write_cursor", + Cursor: cursor, }) return state @@ -1565,17 +1573,17 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [ } func (m *middleware[M]) buildMemoryManifest(ctx context.Context) (string, error) { - files, err := m.cfg.MemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ + files, err := m.boundedMemoryBackend.GlobInfo(ctx, &GlobInfoRequest{ Pattern: CandidateGlobPattern, - Path: m.cfg.MemoryDirectory, + Path: m.resolvedMemoryDirectory, }) if err != nil { return "", err } - indexAbs := filepath.Join(m.cfg.MemoryDirectory, m.cfg.Read.Index.FileName) + indexAbs := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName) lines := make([]string, 0, len(files)) for _, fi := range files { - rel, relErr := filepath.Rel(m.cfg.MemoryDirectory, fi.Path) + rel, relErr := filepath.Rel(m.resolvedMemoryDirectory, fi.Path) if relErr != nil { rel = filepath.Base(fi.Path) } @@ -1584,8 +1592,8 @@ func (m *middleware[M]) buildMemoryManifest(ctx context.Context) (string, error) rel = m.cfg.Read.Index.FileName } desc := "" - preview, rerr := m.cfg.MemoryBackend.Read(ctx, &ReadRequest{FilePath: fi.Path, Limit: defaultCandidatePreviewLine}) - if rerr == nil && preview != nil { + preview, rerr := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: fi.Path, Limit: defaultCandidatePreviewLine}) + if rerr == nil && preview != nil && !isFileNotFoundContent(preview.Content) { if fm, ok := parseFrontmatter(preview.Content); ok { desc = strings.TrimSpace(fm.Description) } diff --git a/adk/middlewares/automemory/automemory_test.go b/adk/middlewares/automemory/automemory_test.go index 1cbc5dd71..6304bbc7f 100644 --- a/adk/middlewares/automemory/automemory_test.go +++ b/adk/middlewares/automemory/automemory_test.go @@ -266,6 +266,39 @@ type countingBackend struct { paths []string } +type outOfBoundsCandidateBackend struct { + outsideReadCalled int32 +} + +func (b *outOfBoundsCandidateBackend) Read(_ context.Context, req *ReadRequest) (*FileContent, error) { + if req == nil { + return nil, fmt.Errorf("read: invalid request") + } + if filepath.Clean(req.FilePath) == filepath.Clean("/outside/secret.md") { + atomic.StoreInt32(&b.outsideReadCalled, 1) + return &FileContent{Content: "secret"}, nil + } + return nil, fmt.Errorf("file not found: %s", req.FilePath) +} + +func (b *outOfBoundsCandidateBackend) GlobInfo(_ context.Context, req *GlobInfoRequest) ([]FileInfo, error) { + if req == nil { + return nil, fmt.Errorf("glob: invalid request") + } + return []FileInfo{{ + Path: "/outside/secret.md", + ModifiedAt: time.Now().Format(time.RFC3339Nano), + }}, nil +} + +func (b *outOfBoundsCandidateBackend) Write(context.Context, *WriteRequest) error { + return nil +} + +func (b *outOfBoundsCandidateBackend) Edit(context.Context, *EditRequest) error { + return nil +} + func (b *countingBackend) Write(ctx context.Context, req *WriteRequest) error { atomic.AddInt32(&b.writeCalls, 1) b.mu.Lock() @@ -929,6 +962,36 @@ func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T require.Equal(t, "remember relative", string(raw)) } +func TestMiddleware_BeforeAgent_RelativeMemoryDirReadsResolvedDirectoryAfterCWDChange(t *testing.T) { + ctx := context.Background() + tmp := t.TempDir() + oldwd, err := os.Getwd() + require.NoError(t, err) + require.NoError(t, os.Chdir(tmp)) + defer func() { + _ = os.Chdir(oldwd) + }() + + require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("persisted index\n"), 0o644)) + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: ".", + MemoryBackend: NewLocalBackend(), + }) + require.NoError(t, err) + + other := t.TempDir() + require.NoError(t, os.Chdir(other)) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}}, + } + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Contains(t, out.Instruction, "persisted index") +} + func TestFSBackend_ReadMissingFileReturnsContentInsteadOfError(t *testing.T) { ctx := context.Background() tmp := t.TempDir() @@ -943,6 +1006,27 @@ func TestFSBackend_ReadMissingFileReturnsContentInsteadOfError(t *testing.T) { require.Contains(t, content.Content, filepath.Join(tmp, "missing.md")) } +func TestMiddleware_TopicSelection_IgnoresOutOfBoundsCandidatePaths(t *testing.T) { + ctx := context.Background() + backend := &outOfBoundsCandidateBackend{} + + mw, err := New(ctx, &Config[*schema.Message]{ + MemoryDirectory: "/mem", + MemoryBackend: backend, + Model: &panicModel{}, + }) + require.NoError(t, err) + + runCtx := &adk.ChatModelAgentContext[*schema.Message]{ + Instruction: "base", + AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("show memories")}}, + } + _, out, err := mw.BeforeAgent(ctx, runCtx) + require.NoError(t, err) + require.Len(t, out.AgentInput.Messages, 1) + require.Equal(t, int32(0), atomic.LoadInt32(&backend.outsideReadCalled)) +} + func TestMiddleware_AfterAgent_AsyncSetsPendingSnapshotWhenLockHeld(t *testing.T) { ctx := context.Background() b := NewInMemoryBackend() diff --git a/adk/middlewares/automemory/backend.go b/adk/middlewares/automemory/backend.go index bf58e4e8f..9d217351b 100644 --- a/adk/middlewares/automemory/backend.go +++ b/adk/middlewares/automemory/backend.go @@ -37,7 +37,7 @@ type FileInfo = filesystem.FileInfo type WriteRequest = filesystem.WriteRequest type EditRequest = filesystem.EditRequest -func newFSBackend(backend Backend, baseDir string) (ainternal.Backend, error) { +func newFSBackend(backend Backend, baseDir string) (*ainternal.FSBackend, error) { return ainternal.NewFSBackend(backend, ainternal.FSBackendConfig{ BaseDir: baseDir, NotFoundAsContent: true, diff --git a/adk/middlewares/automemory/dream/config.go b/adk/middlewares/automemory/dream/config.go index a74b1c6c0..21f9c6bd6 100644 --- a/adk/middlewares/automemory/dream/config.go +++ b/adk/middlewares/automemory/dream/config.go @@ -76,7 +76,7 @@ type Config[M adk.MessageType] struct { // the current run scope: // - middleware-triggered runs search the touched sessions selected by the scheduler // - manual `Run(...)` searches the provided/current session only - SessionStore *SessionStoreProvider[M] + SessionStore adk.SessionEventStore[M] // Schedule controls middleware-triggered runs only. // Optional. `Run(...)` ignores it. @@ -133,10 +133,6 @@ func cloneConfig[M adk.MessageType](cfg *Config[M]) *Config[M] { } cp := *cfg - if cfg.SessionStore != nil { - sessionStoreCopy := *cfg.SessionStore - cp.SessionStore = &sessionStoreCopy - } if cfg.Schedule != nil { scheduleCopy := *cfg.Schedule cp.Schedule = &scheduleCopy diff --git a/adk/middlewares/automemory/dream/dream.go b/adk/middlewares/automemory/dream/dream.go index afc4dcb17..e9925ef0a 100644 --- a/adk/middlewares/automemory/dream/dream.go +++ b/adk/middlewares/automemory/dream/dream.go @@ -108,10 +108,7 @@ func newMiddleware[M adk.MessageType](ctx context.Context, cfg *Config[M]) (*mid } var sessionSearchTool tool.BaseTool if cfg.SessionStore != nil { - sessionSearchTool, err = newSessionHistoryGrepTool[M](cfg.SessionStore) - if err != nil { - return nil, err - } + sessionSearchTool, err = newSessionHistoryGrepTool(cfg.SessionStore) } m := &middleware[M]{ TypedBaseChatModelAgentMiddleware: adk.TypedBaseChatModelAgentMiddleware[M]{}, @@ -205,7 +202,7 @@ func (m *middleware[M]) runDream(ctx context.Context, sessionID string, touchedS if err != nil { return err } - prompt := buildConsolidationPrompt(m.resolvedMemoryDir, touchedSessions, m.cfg.SessionStore != nil) + prompt := buildConsolidationPrompt(m.resolvedMemoryDir, touchedSessions, m.sessionSearchTool != nil) searchSessionIDs := touchedSessions if len(searchSessionIDs) == 0 && sessionID != "" { searchSessionIDs = []string{sessionID} diff --git a/adk/middlewares/automemory/dream/dream_test.go b/adk/middlewares/automemory/dream/dream_test.go index b3aae2393..3d9188359 100644 --- a/adk/middlewares/automemory/dream/dream_test.go +++ b/adk/middlewares/automemory/dream/dream_test.go @@ -175,11 +175,8 @@ func TestNew_DoesNotMutateConfig(t *testing.T) { MemoryDirectory: "/mem", MemoryBackend: automemory.NewInMemoryBackend(), Model: &dreamModel{}, - SessionStore: &SessionStoreProvider[*schema.Message]{ - SessionStore: adksession.NewInMemoryStore[*schema.Message](nil), - Serializer: &schema.HumanReadableSerializer{}, - }, - Schedule: &ScheduleConfig{}, + SessionStore: adksession.NewInMemoryStore[*schema.Message](nil), + Schedule: &ScheduleConfig{}, } _, err := New(ctx, cfg) @@ -199,14 +196,7 @@ func TestMiddleware_AfterAgent_RunInlineWithSessionStore(t *testing.T) { store := NewLocalStore() model := &dreamModel{} eventStore := &countingSessionStore{SessionEventStore: adksession.NewInMemoryStore[*schema.Message](nil)} - serializer := &schema.HumanReadableSerializer{} - payload, err := serializer.Marshal(&adk.SessionEvent[*schema.Message]{ - EventID: "e1", - Kind: adk.SessionEventMessage, - Message: schema.AssistantMessage("build failure: missing dependency", nil), - }) - require.NoError(t, err) - _, err = eventStore.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ + _, err := eventStore.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ SessionID: "session-a", Events: []*adk.SessionEvent[*schema.Message]{{ EventID: "e1", @@ -215,15 +205,11 @@ func TestMiddleware_AfterAgent_RunInlineWithSessionStore(t *testing.T) { }}, }) require.NoError(t, err) - _ = payload mw, err := New(ctx, &Config[*schema.Message]{ MemoryDirectory: tmp, MemoryBackend: automemory.NewLocalBackend(), Model: model, - SessionStore: &SessionStoreProvider[*schema.Message]{ - SessionStore: eventStore, - Serializer: serializer, - }, + SessionStore: eventStore, Schedule: &ScheduleConfig{ RunInline: true, Store: store, diff --git a/adk/middlewares/automemory/dream/session.go b/adk/middlewares/automemory/dream/session.go index 465b0e7a9..bb5c0cb90 100644 --- a/adk/middlewares/automemory/dream/session.go +++ b/adk/middlewares/automemory/dream/session.go @@ -28,20 +28,6 @@ import ( "github.com/cloudwego/eino/schema" ) -// SessionStoreProvider supplies the session timeline dependencies required by -// dream's optional session-history search tool. -// -// When configured, dream exposes a narrow grep-like tool that scans persisted -// message events for the current session only. -type SessionStoreProvider[M adk.MessageType] struct { - // SessionStore loads persisted session events for the current session. - // Required when SessionStoreProvider is configured. - SessionStore adk.SessionEventStore[M] - // Serializer decodes SessionEvent payload bytes returned by SessionStore. - // It must match the serializer used when the events were persisted. - Serializer schema.Serializer -} - type grepSessionHistoryInput struct { Query string `json:"query" jsonschema:"required,description=the narrow term to search in current session history"` Limit int `json:"limit,omitempty" jsonschema:"description=maximum number of matching lines to return"` @@ -68,8 +54,8 @@ func getDreamRunMeta(ctx context.Context) *dreamRunMeta { return nil } -func newSessionHistoryGrepTool[M adk.MessageType](provider *SessionStoreProvider[M]) (tool.BaseTool, error) { - if provider == nil || provider.SessionStore == nil || provider.Serializer == nil { +func newSessionHistoryGrepTool[M adk.MessageType](store adk.SessionEventStore[M]) (tool.BaseTool, error) { + if store == nil { return nil, nil } @@ -107,7 +93,7 @@ func newSessionHistoryGrepTool[M adk.MessageType](provider *SessionStoreProvider for _, sessionID := range sessionIDs { after = "" for len(found) < limit { - result, err := provider.SessionStore.LoadEvents(ctx, &adk.LoadSessionEventsRequest{ + result, err := store.LoadEvents(ctx, &adk.LoadSessionEventsRequest{ SessionID: sessionID, After: after, Limit: pageSize, diff --git a/adk/middlewares/automemory/dream/session_test.go b/adk/middlewares/automemory/dream/session_test.go index 69df4fc7c..1ba2b9c8f 100644 --- a/adk/middlewares/automemory/dream/session_test.go +++ b/adk/middlewares/automemory/dream/session_test.go @@ -31,7 +31,6 @@ import ( func TestNewSessionHistoryGrepTool(t *testing.T) { ctx := context.Background() store := adksession.NewInMemoryStore[*schema.Message](nil) - serializer := &schema.HumanReadableSerializer{} sessionID := "session-1" tail := "" @@ -53,10 +52,7 @@ func TestNewSessionHistoryGrepTool(t *testing.T) { appendEvent("e2", schema.AssistantMessage("build failure: missing dependency", nil)) appendEvent("e3", schema.ToolMessage("Build Failure: retry later", "call-1")) - bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ - SessionStore: store, - Serializer: serializer, - }) + bt, err := newSessionHistoryGrepTool[*schema.Message](store) require.NoError(t, err) result, err := bt.(tool.InvokableTool).InvokableRun( @@ -70,7 +66,6 @@ func TestNewSessionHistoryGrepTool(t *testing.T) { func TestNewSessionHistoryGrepTool_SearchesRunScopedSessions(t *testing.T) { ctx := context.Background() store := adksession.NewInMemoryStore[*schema.Message](nil) - serializer := &schema.HumanReadableSerializer{} appendEvent := func(sessionID, eventID string, msg *schema.Message) { _, err := store.AppendEvents(ctx, &adk.AppendSessionEventsRequest[*schema.Message]{ @@ -88,10 +83,7 @@ func TestNewSessionHistoryGrepTool_SearchesRunScopedSessions(t *testing.T) { appendEvent("session-b", "b1", schema.ToolMessage("build failure: retry later", "call-1")) appendEvent("session-c", "c1", schema.AssistantMessage("build failure: should not be searched", nil)) - bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ - SessionStore: store, - Serializer: serializer, - }) + bt, err := newSessionHistoryGrepTool[*schema.Message](store) require.NoError(t, err) result, err := bt.(tool.InvokableTool).InvokableRun( @@ -113,13 +105,16 @@ func TestNewSessionHistoryGrepTool_InfoUsesChineseDescription(t *testing.T) { require.NoError(t, adk.SetLanguage(adk.LanguageEnglish)) }() - bt, err := newSessionHistoryGrepTool(&SessionStoreProvider[*schema.Message]{ - SessionStore: adksession.NewInMemoryStore[*schema.Message](nil), - Serializer: &schema.HumanReadableSerializer{}, - }) + bt, err := newSessionHistoryGrepTool[*schema.Message](adksession.NewInMemoryStore[*schema.Message](nil)) require.NoError(t, err) info, err := bt.Info(context.Background()) require.NoError(t, err) require.Contains(t, info.Desc, "在当前 dream 运行范围内的会话历史中按精确关键词搜索") } + +func TestNewSessionHistoryGrepTool_AllowsNilStore(t *testing.T) { + bt, err := newSessionHistoryGrepTool[*schema.Message](nil) + require.NoError(t, err) + require.Nil(t, bt) +}