Skip to content

Commit 1bf4617

Browse files
Preserve Gemini thought_signature across OpenAI-compatible tool-call round trips (#1631)
## Summary

This fixes a loss of Gemini thinking-model state in the OpenAI-compatible adapter. When kagent is configured to use Gemini through an OpenAI-compatible endpoint, the adapter was flattening tool calls and tool results into standard OpenAI chat messages without preserving the Gemini-specific `extra_content.google.thought_signature`. That caused thinking-capable Gemini models to fail on the post-tool turn, even though the tool call itself succeeded.

This change preserves the thought signature in both directions:

- OpenAI-compatible response -> ADK Part
- ADK Content -> OpenAI-compatible follow-up tool/result messages

It also applies the same preservation in the streaming tool-call accumulation path.

## Why

We traced a production stream where:

1. The first model turn succeeded,
2. The model issued a tool call,
3. The tool executed successfully,
4. The follow-up request failed because the original function call's `thought_signature` was no longer present.

The lossy step was in kagent's OpenAI-compatible adapter. The adapter was reconstructing only standard OpenAI fields such as text, tool calls, and tool results. Gemini thinking models require the opaque thought signature to survive the round trip, so dropping it breaks the second turn.

## What Changed

- Added logic to extract `extra_content.google.thought_signature` from OpenAI-compatible tool calls.
- Mapped that value onto ADK `Part.thought_signature`.
- Preserved `Part.thought_signature` when converting function-call parts back into OpenAI-compatible `tool_calls`.
- Preserved the same signature on the paired tool-result message sent after tool execution.
- Applied the same preservation in the streaming tool-call aggregation path.
- Added regression tests for:
  - response conversion preserving thought signatures,
  - outbound message conversion preserving thought signatures,
  - full round-trip preservation across tool call + tool result.

## Scope

This PR does not change public config or CRD APIs. It only fixes internal adapter behavior for OpenAI-compatible model flows.

## Validation

Ran: `uv run pytest python/packages/kagent-adk/tests/unittests/models/test_openai.py`
Result: `26 passed`

---------

Signed-off-by: adminturneddevops <mlevan1992@gmail.com>
Co-authored-by: Eitan Yarmush <eitan.yarmush@solo.io>
1 parent 85e038c commit 1bf4617

File tree

4 files changed

+527
-23
lines changed

4 files changed

+527
-23
lines changed

go/adk/pkg/models/openai_adk.go

Lines changed: 122 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/kagent-dev/kagent/go/adk/pkg/telemetry"
1515
"github.com/openai/openai-go/v3"
1616
"github.com/openai/openai-go/v3/packages/param"
17+
"github.com/openai/openai-go/v3/packages/respjson"
1718
"github.com/openai/openai-go/v3/shared"
1819
"github.com/openai/openai-go/v3/shared/constant"
1920
"google.golang.org/adk/model"
@@ -28,6 +29,7 @@ const (
2829
openAIFinishLength = "length"
2930
openAIFinishContentFilter = "content_filter"
3031
openAIToolTypeFunction = "function"
32+
openAIExtraContentKey = "extra_content"
3133
)
3234

3335
// openAIFinishReasonToGenai maps OpenAI finish_reason to genai.FinishReason.
@@ -42,6 +44,82 @@ func openAIFinishReasonToGenai(reason string) genai.FinishReason {
4244
}
4345
}
4446

47+
// openAIThoughtSignatureExtra models the JSON shape of the OpenAI-compatible
// "extra_content" payload that Gemini endpoints use to carry an opaque
// thought signature alongside a tool call:
//
//	{"google": {"thought_signature": "<base64>"}}
type openAIThoughtSignatureExtra struct {
	Google struct {
		ThoughtSignature string `json:"thought_signature"`
	} `json:"google"`
}

// extractThoughtSignatureFromRaw decodes the base64 thought signature carried
// in a raw "extra_content" JSON string. It returns nil when the input is
// empty, is not valid JSON, carries no signature, or the signature is not
// valid standard base64 — extraction is deliberately best-effort and never
// fails the surrounding conversion.
func extractThoughtSignatureFromRaw(raw string) []byte {
	if raw == "" {
		return nil
	}

	var payload openAIThoughtSignatureExtra
	if json.Unmarshal([]byte(raw), &payload) != nil {
		return nil
	}

	encoded := payload.Google.ThoughtSignature
	if encoded == "" {
		return nil
	}

	signature, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil
	}
	return signature
}
72+
73+
func extractThoughtSignatureFromExtraFields(extraFields map[string]respjson.Field) []byte {
74+
if len(extraFields) == 0 {
75+
return nil
76+
}
77+
field, ok := extraFields[openAIExtraContentKey]
78+
if !ok {
79+
return nil
80+
}
81+
return extractThoughtSignatureFromRaw(field.Raw())
82+
}
83+
84+
// openAIExtraContentForThoughtSignature builds the "extra_content" value that
// re-attaches a Gemini thought signature to an outbound tool call or tool
// message, base64-encoding the raw bytes. It returns nil when there is no
// signature to preserve, so callers can skip setting the field entirely.
func openAIExtraContentForThoughtSignature(sig []byte) map[string]any {
	if len(sig) == 0 {
		return nil
	}

	google := map[string]any{
		"thought_signature": base64.StdEncoding.EncodeToString(sig),
	}
	return map[string]any{"google": google}
}
95+
96+
func thoughtSignaturesByToolCallID(contents []*genai.Content) map[string][]byte {
97+
thoughtSignatures := make(map[string][]byte)
98+
for _, content := range contents {
99+
if content == nil || content.Parts == nil {
100+
continue
101+
}
102+
for _, part := range content.Parts {
103+
if part == nil || part.FunctionCall == nil || len(part.ThoughtSignature) == 0 {
104+
continue
105+
}
106+
thoughtSignatures[part.FunctionCall.ID] = part.ThoughtSignature
107+
}
108+
}
109+
return thoughtSignatures
110+
}
111+
112+
func newFunctionCallPart(name string, args map[string]any, id string, thoughtSignature []byte) *genai.Part {
113+
part := genai.NewPartFromFunctionCall(name, args)
114+
if part.FunctionCall != nil {
115+
part.FunctionCall.ID = id
116+
}
117+
if len(thoughtSignature) > 0 {
118+
part.ThoughtSignature = thoughtSignature
119+
}
120+
return part
121+
}
122+
45123
// Name implements model.LLM.
46124
func (m *OpenAIModel) Name() string {
47125
return m.Config.Model
@@ -126,6 +204,7 @@ func genaiContentsToOpenAIMessages(contents []*genai.Content, config *genai.Gene
126204
systemInstruction := strings.TrimSpace(systemBuilder.String())
127205

128206
functionResponses := make(map[string]*genai.FunctionResponse)
207+
thoughtSignatures := thoughtSignaturesByToolCallID(contents)
129208
for _, c := range contents {
130209
if c == nil || c.Parts == nil {
131210
continue
@@ -167,21 +246,35 @@ func genaiContentsToOpenAIMessages(contents []*genai.Content, config *genai.Gene
167246
var toolResponseMessages []openai.ChatCompletionMessageParamUnion
168247
for _, fc := range functionCalls {
169248
argsJSON, _ := json.Marshal(fc.Args)
170-
toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{
171-
OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
172-
ID: fc.ID,
173-
Type: constant.Function(openAIToolTypeFunction),
174-
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
175-
Name: fc.Name,
176-
Arguments: string(argsJSON),
177-
},
249+
toolCall := openai.ChatCompletionMessageFunctionToolCallParam{
250+
ID: fc.ID,
251+
Type: constant.Function(openAIToolTypeFunction),
252+
Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
253+
Name: fc.Name,
254+
Arguments: string(argsJSON),
178255
},
256+
}
257+
if extraContent := openAIExtraContentForThoughtSignature(thoughtSignatures[fc.ID]); extraContent != nil {
258+
toolCall.SetExtraFields(map[string]any{openAIExtraContentKey: extraContent})
259+
}
260+
toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallUnionParam{
261+
OfFunction: &toolCall,
179262
})
180263
contentStr := "No response available for this function call."
181264
if fr := functionResponses[fc.ID]; fr != nil {
182265
contentStr = functionResponseContentString(fr.Response)
183266
}
184-
toolResponseMessages = append(toolResponseMessages, openai.ToolMessage(contentStr, fc.ID))
267+
toolMessage := openai.ChatCompletionToolMessageParam{
268+
Content: openai.ChatCompletionToolMessageParamContentUnion{
269+
OfString: param.NewOpt(contentStr),
270+
},
271+
ToolCallID: fc.ID,
272+
Role: constant.Tool("tool"),
273+
}
274+
if extraContent := openAIExtraContentForThoughtSignature(thoughtSignatures[fc.ID]); extraContent != nil {
275+
toolMessage.SetExtraFields(map[string]any{openAIExtraContentKey: extraContent})
276+
}
277+
toolResponseMessages = append(toolResponseMessages, openai.ChatCompletionMessageParamUnion{OfTool: &toolMessage})
185278
}
186279
textContent := strings.Join(textParts, "\n")
187280
asst := openai.ChatCompletionAssistantMessageParam{
@@ -357,7 +450,7 @@ func runStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComplet
357450
for _, tc := range delta.ToolCalls {
358451
idx := tc.Index
359452
if toolCallsAcc[idx] == nil {
360-
toolCallsAcc[idx] = map[string]any{"id": "", "name": "", "arguments": ""}
453+
toolCallsAcc[idx] = map[string]any{"id": "", "name": "", "arguments": "", "thought_signature": []byte(nil)}
361454
}
362455
if tc.ID != "" {
363456
toolCallsAcc[idx]["id"] = tc.ID
@@ -369,6 +462,9 @@ func runStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComplet
369462
prev, _ := toolCallsAcc[idx]["arguments"].(string)
370463
toolCallsAcc[idx]["arguments"] = prev + tc.Function.Arguments
371464
}
465+
if thoughtSignature := extractThoughtSignatureFromExtraFields(tc.JSON.ExtraFields); len(thoughtSignature) > 0 {
466+
toolCallsAcc[idx]["thought_signature"] = thoughtSignature
467+
}
372468
}
373469
if choice.FinishReason != "" {
374470
finishReason = choice.FinishReason
@@ -404,8 +500,8 @@ func runStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComplet
404500
name, _ := tc["name"].(string)
405501
id, _ := tc["id"].(string)
406502
if name != "" || id != "" {
407-
p := genai.NewPartFromFunctionCall(name, args)
408-
p.FunctionCall.ID = id
503+
thoughtSignature, _ := tc["thought_signature"].([]byte)
504+
p := newFunctionCallPart(name, args, id, thoughtSignature)
409505
finalParts = append(finalParts, p)
410506
}
411507
}
@@ -438,6 +534,12 @@ func runNonStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComp
438534
yield(&model.LLMResponse{ErrorCode: "API_ERROR", ErrorMessage: "No choices in response"}, nil)
439535
return
440536
}
537+
resp := chatCompletionToLLMResponse(completion)
538+
telemetry.SetLLMResponseAttributes(ctx, resp)
539+
yield(resp, nil)
540+
}
541+
542+
func chatCompletionToLLMResponse(completion *openai.ChatCompletion) *model.LLMResponse {
441543
choice := completion.Choices[0]
442544
msg := choice.Message
443545
nParts := 0
@@ -455,8 +557,13 @@ func runNonStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComp
455557
if tc.Function.Arguments != "" {
456558
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
457559
}
458-
p := genai.NewPartFromFunctionCall(tc.Function.Name, args)
459-
p.FunctionCall.ID = tc.ID
560+
functionToolCall := tc.AsFunction()
561+
p := newFunctionCallPart(
562+
tc.Function.Name,
563+
args,
564+
tc.ID,
565+
extractThoughtSignatureFromExtraFields(functionToolCall.JSON.ExtraFields),
566+
)
460567
parts = append(parts, p)
461568
}
462569
}
@@ -467,13 +574,11 @@ func runNonStreaming(ctx context.Context, m *OpenAIModel, params openai.ChatComp
467574
CandidatesTokenCount: int32(completion.Usage.CompletionTokens),
468575
}
469576
}
470-
resp := &model.LLMResponse{
577+
return &model.LLMResponse{
471578
Partial: false,
472579
TurnComplete: true,
473580
FinishReason: openAIFinishReasonToGenai(choice.FinishReason),
474581
UsageMetadata: usage,
475582
Content: &genai.Content{Role: string(genai.RoleModel), Parts: parts},
476583
}
477-
telemetry.SetLLMResponseAttributes(ctx, resp)
478-
yield(resp, nil)
479584
}

go/adk/pkg/models/openai_adk_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package models
22

33
import (
4+
"encoding/base64"
5+
"encoding/json"
46
"testing"
57

68
"github.com/openai/openai-go/v3"
@@ -246,3 +248,163 @@ func TestApplyOpenAIConfig(t *testing.T) {
246248
}
247249
})
248250
}
251+
252+
func TestGenaiContentsToOpenAIMessages_PreservesThoughtSignatureOnToolCallAndToolResult(t *testing.T) {
253+
thoughtSignature := []byte("abc")
254+
255+
functionCall := genai.NewPartFromFunctionCall("add", map[string]any{"a": 2, "b": 2})
256+
functionCall.FunctionCall.ID = "call_1"
257+
functionCall.ThoughtSignature = thoughtSignature
258+
259+
functionResponse := genai.NewPartFromFunctionResponse("add", map[string]any{"result": "4"})
260+
functionResponse.FunctionResponse.ID = "call_1"
261+
262+
messages, _ := genaiContentsToOpenAIMessages([]*genai.Content{
263+
{
264+
Role: string(genai.RoleModel),
265+
Parts: []*genai.Part{functionCall},
266+
},
267+
{
268+
Role: string(genai.RoleUser),
269+
Parts: []*genai.Part{functionResponse},
270+
},
271+
}, nil)
272+
273+
if len(messages) != 2 {
274+
t.Fatalf("len(messages) = %d, want 2", len(messages))
275+
}
276+
277+
assistantJSON, err := json.Marshal(messages[0].OfAssistant)
278+
if err != nil {
279+
t.Fatalf("json.Marshal(assistant) error = %v", err)
280+
}
281+
toolJSON, err := json.Marshal(messages[1].OfTool)
282+
if err != nil {
283+
t.Fatalf("json.Marshal(tool) error = %v", err)
284+
}
285+
286+
want := base64.StdEncoding.EncodeToString(thoughtSignature)
287+
assertThoughtSignature := func(name string, payload []byte) {
288+
t.Helper()
289+
var obj map[string]any
290+
if err := json.Unmarshal(payload, &obj); err != nil {
291+
t.Fatalf("%s json.Unmarshal error = %v", name, err)
292+
}
293+
extra, ok := obj["extra_content"].(map[string]any)
294+
if !ok {
295+
t.Fatalf("%s missing extra_content: %s", name, string(payload))
296+
}
297+
googleExtra, ok := extra["google"].(map[string]any)
298+
if !ok {
299+
t.Fatalf("%s missing google extra content: %s", name, string(payload))
300+
}
301+
if got, _ := googleExtra["thought_signature"].(string); got != want {
302+
t.Fatalf("%s thought_signature = %q, want %q", name, got, want)
303+
}
304+
}
305+
306+
var assistantObj map[string]any
307+
if err := json.Unmarshal(assistantJSON, &assistantObj); err != nil {
308+
t.Fatalf("assistant json.Unmarshal error = %v", err)
309+
}
310+
toolCalls, ok := assistantObj["tool_calls"].([]any)
311+
if !ok || len(toolCalls) != 1 {
312+
t.Fatalf("assistant tool_calls = %#v, want 1 tool call", assistantObj["tool_calls"])
313+
}
314+
firstToolCall, ok := toolCalls[0].(map[string]any)
315+
if !ok {
316+
t.Fatalf("assistant tool call = %#v, want object", toolCalls[0])
317+
}
318+
firstToolCallJSON, err := json.Marshal(firstToolCall)
319+
if err != nil {
320+
t.Fatalf("json.Marshal(firstToolCall) error = %v", err)
321+
}
322+
323+
assertThoughtSignature("assistant tool_call", firstToolCallJSON)
324+
assertThoughtSignature("tool message", toolJSON)
325+
}
326+
327+
func TestChatCompletionToLLMResponse_PreservesThoughtSignature(t *testing.T) {
328+
raw := []byte(`{
329+
"id":"chatcmpl-1",
330+
"object":"chat.completion",
331+
"created":123,
332+
"model":"gemini-2.5-flash",
333+
"choices":[{
334+
"index":0,
335+
"finish_reason":"tool_calls",
336+
"message":{
337+
"role":"assistant",
338+
"tool_calls":[{
339+
"id":"call_1",
340+
"type":"function",
341+
"function":{
342+
"name":"add",
343+
"arguments":"{\"a\":2,\"b\":2}"
344+
},
345+
"extra_content":{
346+
"google":{
347+
"thought_signature":"YWJj"
348+
}
349+
}
350+
}]
351+
}
352+
}],
353+
"usage":{
354+
"prompt_tokens":3,
355+
"completion_tokens":4,
356+
"total_tokens":7
357+
}
358+
}`)
359+
360+
var completion openai.ChatCompletion
361+
if err := json.Unmarshal(raw, &completion); err != nil {
362+
t.Fatalf("json.Unmarshal(ChatCompletion) error = %v", err)
363+
}
364+
365+
resp := chatCompletionToLLMResponse(&completion)
366+
if resp.Content == nil || len(resp.Content.Parts) != 1 {
367+
t.Fatalf("response parts = %#v, want 1 function-call part", resp.Content)
368+
}
369+
370+
part := resp.Content.Parts[0]
371+
if part.FunctionCall == nil {
372+
t.Fatalf("part.FunctionCall = nil, want function call")
373+
}
374+
if part.FunctionCall.Name != "add" {
375+
t.Fatalf("part.FunctionCall.Name = %q, want %q", part.FunctionCall.Name, "add")
376+
}
377+
if string(part.ThoughtSignature) != "abc" {
378+
t.Fatalf("part.ThoughtSignature = %q, want %q", string(part.ThoughtSignature), "abc")
379+
}
380+
if resp.UsageMetadata == nil || resp.UsageMetadata.PromptTokenCount != 3 || resp.UsageMetadata.CandidatesTokenCount != 4 {
381+
t.Fatalf("usage metadata = %#v, want prompt=3 completion=4", resp.UsageMetadata)
382+
}
383+
}
384+
385+
func TestExtractThoughtSignatureFromStreamingToolCallChunk(t *testing.T) {
386+
raw := []byte(`{
387+
"index":0,
388+
"id":"call_1",
389+
"type":"function",
390+
"function":{
391+
"name":"add",
392+
"arguments":"{\"a\":2"
393+
},
394+
"extra_content":{
395+
"google":{
396+
"thought_signature":"YWJj"
397+
}
398+
}
399+
}`)
400+
401+
var toolCall openai.ChatCompletionChunkChoiceDeltaToolCall
402+
if err := json.Unmarshal(raw, &toolCall); err != nil {
403+
t.Fatalf("json.Unmarshal(ChatCompletionChunkChoiceDeltaToolCall) error = %v", err)
404+
}
405+
406+
thoughtSignature := extractThoughtSignatureFromExtraFields(toolCall.JSON.ExtraFields)
407+
if string(thoughtSignature) != "abc" {
408+
t.Fatalf("thoughtSignature = %q, want %q", string(thoughtSignature), "abc")
409+
}
410+
}

0 commit comments

Comments
 (0)