Skip to content

Commit 91925ac

Browse files
fix(summarization): use state.ToolInfos instead of deprecated mc.Tools
The token counter was reading tools from TypedModelContext.Tools which is deprecated, constructed once at graph-build time, and never reflects modifications by earlier handlers. Switch to state.ToolInfos which is the persisted source of truth. Add test asserting the correct source is used. Change-Id: I15a41ac2a0b990956bd1f814951a513023dbb55b
1 parent b2d21ce commit 91925ac

2 files changed

Lines changed: 52 additions & 8 deletions

File tree

adk/middlewares/summarization/summarization.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,11 @@ func SummarizeMessages(ctx context.Context, cfg *Config, messages []adk.Message)
368368
}
369369

370370
func (m *typedMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M],
371-
mtx *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) {
372-
373-
var tools []*schema.ToolInfo
374-
if mtx != nil {
375-
tools = mtx.Tools
376-
}
371+
_ *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) {
377372

378373
triggered, err := m.shouldSummarize(ctx, &TypedTokenCounterInput[M]{
379374
Messages: state.Messages,
380-
Tools: tools,
375+
Tools: state.ToolInfos,
381376
})
382377
if err != nil {
383378
return nil, nil, err

adk/middlewares/summarization/summarization_generic_test.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,13 @@ func TestSummarizationGeneric(t *testing.T) {
9999
t.Run("Helpers", testSummarizationHelpers[*schema.Message])
100100
t.Run("Flow", testSummarizationFlow[*schema.Message])
101101
t.Run("SummarizeMessages", testTypedSummarizeMessages[*schema.Message])
102+
t.Run("TokenCounterUsesStateToolInfos", testTokenCounterReceivesStateToolInfos[*schema.Message])
102103
})
103104
t.Run("AgenticMessage", func(t *testing.T) {
104105
t.Run("Helpers", testSummarizationHelpers[*schema.AgenticMessage])
105106
t.Run("Flow", testSummarizationFlow[*schema.AgenticMessage])
106107
t.Run("SummarizeMessages", testTypedSummarizeMessages[*schema.AgenticMessage])
108+
t.Run("TokenCounterUsesStateToolInfos", testTokenCounterReceivesStateToolInfos[*schema.AgenticMessage])
107109
})
108110
}
109111

@@ -291,7 +293,54 @@ func testSummarizationFlow[M adk.MessageType](t *testing.T) {
291293
assert.True(t, foundSummary, "should have a summary message")
292294
}
293295

294-
// testTypedSummarizeMessages tests the synchronous TypedSummarizeMessages API.
296+
// testTokenCounterReceivesStateToolInfos verifies that the token counter receives
297+
// state.ToolInfos (not the deprecated mc.Tools). We set up state.ToolInfos and mc.Tools
298+
// with different values and assert the token counter sees only state.ToolInfos.
299+
func testTokenCounterReceivesStateToolInfos[M adk.MessageType](t *testing.T) {
300+
ctx := context.Background()
301+
302+
stateTools := []*schema.ToolInfo{
303+
{Name: "state_tool_a"},
304+
{Name: "state_tool_b"},
305+
}
306+
mcTools := []*schema.ToolInfo{
307+
{Name: "mc_tool_should_not_appear"},
308+
}
309+
310+
var receivedTools []*schema.ToolInfo
311+
tokenCounter := func(_ context.Context, input *TypedTokenCounterInput[M]) (int, error) {
312+
receivedTools = input.Tools
313+
return 0, nil // below threshold — won't trigger summarization
314+
}
315+
316+
cfg := &TypedConfig[M]{
317+
Model: &genericMockModel[M]{
318+
response: smakeAssistantMsg[M]("unused"),
319+
},
320+
TokenCounter: tokenCounter,
321+
Trigger: &TriggerCondition{
322+
ContextTokens: 9999, // high threshold so summarization is not triggered
323+
},
324+
}
325+
326+
mw, err := NewTyped(ctx, cfg)
327+
require.NoError(t, err)
328+
329+
state := &adk.TypedChatModelAgentState[M]{
330+
Messages: []M{smakeUserMsg[M]("hello")},
331+
ToolInfos: stateTools,
332+
}
333+
mc := &adk.TypedModelContext[M]{Tools: mcTools}
334+
335+
_, _, err = mw.BeforeModelRewriteState(ctx, state, mc)
336+
require.NoError(t, err)
337+
338+
// The token counter must have received state.ToolInfos, not mc.Tools.
339+
require.NotNil(t, receivedTools, "token counter should have been called")
340+
require.Len(t, receivedTools, 2)
341+
assert.Equal(t, "state_tool_a", receivedTools[0].Name)
342+
assert.Equal(t, "state_tool_b", receivedTools[1].Name)
343+
}
295344
func testTypedSummarizeMessages[M adk.MessageType](t *testing.T) {
296345
ctx := context.Background()
297346

0 commit comments

Comments
 (0)