Skip to content

Commit ba37df9

Browse files
committed
feat: auto memory middleware
1 parent 664af32 commit ba37df9

12 files changed

Lines changed: 2925 additions & 713 deletions

File tree

adk/chatmodel.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -791,9 +791,10 @@ type execContext struct {
791791
toolUpdated bool // whether needs to pass a compose.WithToolList option to ToolsNode due to tool list change
792792
}
793793

794-
func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext) (context.Context, *execContext, error) {
795-
runCtx := &ChatModelAgentContext{
794+
func (a *TypedChatModelAgent[M]) applyBeforeAgent(ctx context.Context, ec *execContext, agentInput *TypedAgentInput[M]) (context.Context, *execContext, error) {
795+
runCtx := &ChatModelAgentContext[M]{
796796
Instruction: ec.instruction,
797+
AgentInput: agentInput,
797798
Tools: cloneSlice(ec.unwrappedTools),
798799
ReturnDirectly: copyMap(ec.returnDirectly),
799800
}
@@ -1394,7 +1395,7 @@ func (a *TypedChatModelAgent[M]) buildRunFunc(ctx context.Context) typedRunFunc[
13941395
return a.run
13951396
}
13961397

1397-
func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Context, typedRunFunc[M], *execContext, error) {
1398+
func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context, agentInput *TypedAgentInput[M]) (context.Context, typedRunFunc[M], *execContext, error) {
13981399
defaultRun := a.buildRunFunc(ctx)
13991400
bc := a.exeCtx
14001401

@@ -1412,7 +1413,7 @@ func (a *TypedChatModelAgent[M]) getRunFunc(ctx context.Context) (context.Contex
14121413
return ctx, defaultRun, runtimeBC, nil
14131414
}
14141415

1415-
ctx, runtimeBC, err := a.applyBeforeAgent(ctx, bc)
1416+
ctx, runtimeBC, err := a.applyBeforeAgent(ctx, bc, agentInput)
14161417
if err != nil {
14171418
return ctx, nil, nil, err
14181419
}
@@ -1447,7 +1448,7 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput
14471448
cancelCtx = getCancelContext(ctx)
14481449
}
14491450

1450-
ctx, run, bc, err := a.getRunFunc(ctx)
1451+
ctx, run, bc, err := a.getRunFunc(ctx, input)
14511452
if err != nil {
14521453
go func() {
14531454
if cancelCtxOwned && cancelCtx != nil {
@@ -1523,7 +1524,7 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o
15231524
cancelCtx = getCancelContext(ctx)
15241525
}
15251526

1526-
ctx, run, bc, err := a.getRunFunc(ctx)
1527+
ctx, run, bc, err := a.getRunFunc(ctx, nil)
15271528
if err != nil {
15281529
go func() {
15291530
if cancelCtxOwned && cancelCtx != nil {

adk/handler.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,16 @@ type ModelContext = TypedModelContext[*schema.Message]
8484
// Handlers can modify Instruction, Tools, and ReturnDirectly to customize agent behavior.
8585
//
8686
// This type is specific to ChatModelAgent. Other agent types may define their own context types.
87-
type ChatModelAgentContext struct {
87+
type ChatModelAgentContext[M MessageType] struct {
8888
// Instruction is the current instruction for the Agent execution.
8989
// It includes the instruction configured for the agent, additional instructions appended by framework
9090
// and AgentMiddleware, and modifications applied by previous BeforeAgent handlers.
9191
// The finalized instruction after all BeforeAgent handlers are then passed to GenModelInput,
9292
// to be (optionally) formatted with SessionValues and converted to system message.
9393
Instruction string
9494

95+
AgentInput *TypedAgentInput[M]
96+
9597
// Tools are the raw tools (without any wrapper or tool middleware) currently configured for the Agent execution.
9698
// They includes tools passed in AgentConfig, implicit tools added by framework such as transfer / exit tools,
9799
// and other tools already added by middlewares.
@@ -139,7 +141,7 @@ type ChatModelAgentContext struct {
139141
type TypedChatModelAgentMiddleware[M MessageType] interface {
140142
// BeforeAgent is called before each agent run, allowing modification of
141143
// the agent's instruction and tools configuration.
142-
BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error)
144+
BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[M]) (context.Context, *ChatModelAgentContext[M], error)
143145

144146
// AfterAgent is called after the agent run reaches a successful terminal state.
145147
// Successful terminal states are: final answer (model response with no tool calls),
@@ -296,7 +298,7 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) WrapModel(_ context.Context, m mo
296298
return m, nil
297299
}
298300

299-
func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
301+
func (b *TypedBaseChatModelAgentMiddleware[M]) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[M]) (context.Context, *ChatModelAgentContext[M], error) {
300302
return ctx, runCtx, nil
301303
}
302304

adk/handler_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type testInstructionHandler struct {
3737
text string
3838
}
3939

40-
func (h *testInstructionHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
40+
func (h *testInstructionHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
4141
if runCtx.Instruction == "" {
4242
runCtx.Instruction = h.text
4343
} else if h.text != "" {
@@ -51,7 +51,7 @@ type testInstructionFuncHandler struct {
5151
fn func(ctx context.Context, instruction string) (context.Context, string, error)
5252
}
5353

54-
func (h *testInstructionFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
54+
func (h *testInstructionFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
5555
newCtx, newInstruction, err := h.fn(ctx, runCtx.Instruction)
5656
if err != nil {
5757
return ctx, runCtx, err
@@ -65,7 +65,7 @@ type testToolsHandler struct {
6565
tools []tool.BaseTool
6666
}
6767

68-
func (h *testToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
68+
func (h *testToolsHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
6969
runCtx.Tools = append(runCtx.Tools, h.tools...)
7070
return ctx, runCtx, nil
7171
}
@@ -75,7 +75,7 @@ type testToolsFuncHandler struct {
7575
fn func(ctx context.Context, tools []tool.BaseTool, returnDirectly map[string]bool) (context.Context, []tool.BaseTool, map[string]bool, error)
7676
}
7777

78-
func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
78+
func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
7979
newCtx, newTools, newReturnDirectly, err := h.fn(ctx, runCtx.Tools, runCtx.ReturnDirectly)
8080
if err != nil {
8181
return ctx, runCtx, err
@@ -87,10 +87,10 @@ func (h *testToolsFuncHandler) BeforeAgent(ctx context.Context, runCtx *ChatMode
8787

8888
type testBeforeAgentHandler struct {
8989
*BaseChatModelAgentMiddleware
90-
fn func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error)
90+
fn func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error)
9191
}
9292

93-
func (h *testBeforeAgentHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
93+
func (h *testBeforeAgentHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
9494
return h.fn(ctx, runCtx)
9595
}
9696

@@ -894,10 +894,10 @@ func TestContextPropagation(t *testing.T) {
894894
Description: "Test agent",
895895
Model: cm,
896896
Handlers: []ChatModelAgentMiddleware{
897-
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
897+
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
898898
return context.WithValue(ctx, key1, "value1"), runCtx, nil
899899
}},
900-
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
900+
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
901901
handler2ReceivedValue = ctx.Value(key1)
902902
return ctx, runCtx, nil
903903
}},
@@ -962,7 +962,7 @@ func TestHandlerErrorHandling(t *testing.T) {
962962
Description: "Test agent",
963963
Model: cm,
964964
Handlers: []ChatModelAgentMiddleware{
965-
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
965+
&testBeforeAgentHandler{fn: func(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
966966
return ctx, runCtx, assert.AnError
967967
}},
968968
},
@@ -1042,7 +1042,7 @@ type countingHandler struct {
10421042
mu sync.Mutex
10431043
}
10441044

1045-
func (h *countingHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext) (context.Context, *ChatModelAgentContext, error) {
1045+
func (h *countingHandler) BeforeAgent(ctx context.Context, runCtx *ChatModelAgentContext[*schema.Message]) (context.Context, *ChatModelAgentContext[*schema.Message], error) {
10461046
h.mu.Lock()
10471047
h.beforeAgentCount++
10481048
h.mu.Unlock()

0 commit comments

Comments
 (0)