Skip to content

Commit 96357c3

Browse files
refactor(adk): replace AfterToolCallsRewriteState with per-run WithAfterToolCallsHook option (#985)
1 parent ed3e2bf commit 96357c3

4 files changed

Lines changed: 186 additions & 134 deletions

File tree

adk/chatmodel.go

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ type typedChatModelAgentExecCtx[M messageType] struct {
5454
// Invariant: any code path that emits model output events MUST check this flag.
5555
suppressEventSend bool
5656
retryVerdictSignal *retryVerdictSignal
57+
58+
afterToolCallsHook func(ctx context.Context) error
5759
}
5860

5961
func (e *typedChatModelAgentExecCtx[M]) send(event *TypedAgentEvent[M]) {
@@ -87,6 +89,8 @@ type chatModelAgentRunOptions struct {
8789
agentToolOptions map[string][]AgentRunOption
8890

8991
historyModifier func(context.Context, []Message) []Message
92+
93+
afterToolCallsHook func(ctx context.Context) error
9094
}
9195

9296
// WithChatModelOptions sets options for the underlying chat model.
@@ -118,6 +122,17 @@ func WithHistoryModifier(f func(context.Context, []Message) []Message) AgentRunO
118122
})
119123
}
120124

125+
// WithAfterToolCallsHook registers a per-run hook that fires synchronously after
126+
// all tool calls in a react iteration complete, before the next ChatModel call.
127+
//
128+
// This is suitable for TurnLoop Push+Preempt patterns where the pushed item
129+
// must be visible to the next turn's GenInput.
130+
func WithAfterToolCallsHook(fn func(ctx context.Context) error) AgentRunOption {
131+
return WrapImplSpecificOptFn(func(t *chatModelAgentRunOptions) {
132+
t.afterToolCallsHook = fn
133+
})
134+
}
135+
121136
type ToolsConfig struct {
122137
compose.ToolsNodeConfig
123138

@@ -461,6 +476,8 @@ type typedRunParams[M messageType] struct {
461476
cancelCtx *cancelContext
462477
cancelCtxOwned bool
463478
composeOpts []compose.Option
479+
480+
afterToolCallsHook func(ctx context.Context) error
464481
}
465482

466483
type typedRunFunc[M messageType] func(ctx context.Context, p *typedRunParams[M])
@@ -1130,6 +1147,7 @@ func (a *TypedChatModelAgent[M]) buildMessageReActRunFunc(ctx context.Context, b
11301147
generator: mp.generator,
11311148
cancelCtx: cancelCtx,
11321149
failoverLastSuccessModel: msgModel,
1150+
afterToolCallsHook: mp.afterToolCallsHook,
11331151
})
11341152

11351153
// Pre-execution cancel check
@@ -1406,6 +1424,7 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput
14061424

14071425
co := getComposeOptions(opts)
14081426
co = append(co, compose.WithCheckPointID(bridgeCheckpointID))
1427+
runOps := GetImplSpecificOptions[chatModelAgentRunOptions](nil, opts...)
14091428

14101429
if bc != nil {
14111430
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
@@ -1439,14 +1458,15 @@ func (a *TypedChatModelAgent[M]) Run(ctx context.Context, input *TypedAgentInput
14391458
}
14401459

14411460
run(ctx, &typedRunParams[M]{
1442-
input: input,
1443-
generator: generator,
1444-
store: newBridgeStore(),
1445-
instruction: instruction,
1446-
returnDirectly: returnDirectly,
1447-
cancelCtx: cancelCtx,
1448-
cancelCtxOwned: cancelCtxOwned,
1449-
composeOpts: co,
1461+
input: input,
1462+
generator: generator,
1463+
store: newBridgeStore(),
1464+
instruction: instruction,
1465+
returnDirectly: returnDirectly,
1466+
cancelCtx: cancelCtx,
1467+
cancelCtxOwned: cancelCtxOwned,
1468+
composeOpts: co,
1469+
afterToolCallsHook: runOps.afterToolCallsHook,
14501470
})
14511471
}()
14521472

@@ -1480,6 +1500,7 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o
14801500

14811501
co := getComposeOptions(opts)
14821502
co = append(co, compose.WithCheckPointID(bridgeCheckpointID))
1503+
resumeRunOps := GetImplSpecificOptions[chatModelAgentRunOptions](nil, opts...)
14831504

14841505
if bc != nil {
14851506
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
@@ -1563,14 +1584,15 @@ func (a *TypedChatModelAgent[M]) Resume(ctx context.Context, info *ResumeInfo, o
15631584
}
15641585

15651586
run(ctx, &typedRunParams[M]{
1566-
input: &TypedAgentInput[M]{EnableStreaming: info.EnableStreaming},
1567-
generator: generator,
1568-
store: newResumeBridgeStore(bridgeCheckpointID, stateByte),
1569-
instruction: instruction,
1570-
returnDirectly: returnDirectly,
1571-
cancelCtx: cancelCtx,
1572-
cancelCtxOwned: cancelCtxOwned,
1573-
composeOpts: co,
1587+
input: &TypedAgentInput[M]{EnableStreaming: info.EnableStreaming},
1588+
generator: generator,
1589+
store: newResumeBridgeStore(bridgeCheckpointID, stateByte),
1590+
instruction: instruction,
1591+
returnDirectly: returnDirectly,
1592+
cancelCtx: cancelCtx,
1593+
cancelCtxOwned: cancelCtxOwned,
1594+
composeOpts: co,
1595+
afterToolCallsHook: resumeRunOps.afterToolCallsHook,
15741596
})
15751597
}()
15761598

adk/handler.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,6 @@ type TypedChatModelAgentMiddleware[M messageType] interface {
161161
// - DeferredToolInfos: tools for server-side search (nil if unused)
162162
AfterModelRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], mc *ModelContext) (context.Context, *TypedChatModelAgentState[M], error)
163163

164-
// AfterToolCallsRewriteState is called after all concurrent tool calls in an iteration complete.
165-
// The input state includes all messages up to and including the tool call results.
166-
// The returned state is persisted to the agent's internal state.
167-
//
168-
// The ToolCallsContext provides metadata about the tool calls that just completed,
169-
// derived from the assistant message's ToolCalls field.
170-
AfterToolCallsRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], tc *ToolCallsContext) (context.Context, *TypedChatModelAgentState[M], error)
171-
172164
// WrapInvokableToolCall wraps a tool's synchronous execution with custom behavior.
173165
// Return the input endpoint unchanged and nil error if no wrapping is needed.
174166
//
@@ -297,10 +289,6 @@ func (b *TypedBaseChatModelAgentMiddleware[M]) AfterModelRewriteState(ctx contex
297289
return ctx, state, nil
298290
}
299291

300-
func (b *TypedBaseChatModelAgentMiddleware[M]) AfterToolCallsRewriteState(ctx context.Context, state *TypedChatModelAgentState[M], tc *ToolCallsContext) (context.Context, *TypedChatModelAgentState[M], error) {
301-
return ctx, state, nil
302-
}
303-
304292
func processTypedState(ctx context.Context, fn func(extra map[string]any) map[string]any) error {
305293
runCtx := getRunCtx(ctx)
306294
if runCtx != nil && runCtx.AgenticRootInput != nil {

0 commit comments

Comments
 (0)