Skip to content

Commit 90b124a

Browse files
committed
fix(gemini): preserve thought signatures and token usage
1 parent d0cf913 commit 90b124a

6 files changed

Lines changed: 144 additions & 9 deletions

File tree

Sources/AgentRunKit/LLM/GeminiClient.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ extension GeminiClient {
222222
}()
223223
let arguments = try encodeFunctionCallArgs(functionCall.args)
224224
toolCalls.append(ToolCall(id: callId, name: functionCall.name, arguments: arguments))
225+
if let signature = part.thoughtSignature, !signature.isEmpty {
226+
reasoningDetails.append(GeminiReasoningDetail.functionCallSignature(
227+
toolCallID: callId,
228+
signature: signature
229+
))
230+
}
225231
} else if let text = part.text {
226232
if part.thought == true {
227233
reasoningText = reasoningText.map { $0 + "\n" + text } ?? text

Sources/AgentRunKit/LLM/GeminiClientStreaming.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ extension GeminiClient {
6161
await state.flushThinkingBlock()
6262
let toolIndex = await state.incrementToolCallCount()
6363
let callId = functionCall.id ?? "gemini_call_\(toolIndex)"
64+
if let signature = part.thoughtSignature, !signature.isEmpty {
65+
continuation.yield(.reasoningDetails([
66+
GeminiReasoningDetail.functionCallSignature(
67+
toolCallID: callId,
68+
signature: signature
69+
)
70+
]))
71+
}
6472
continuation.yield(.toolCallStart(
6573
index: toolIndex, id: callId, name: functionCall.name
6674
))

Sources/AgentRunKit/LLM/GeminiClientTypes.swift

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,10 @@ struct GeminiUsageMetadata: Decodable {
170170
let cachedContentTokenCount: Int?
171171

172172
var tokenUsage: TokenUsage {
173-
let thoughts = thoughtsTokenCount ?? 0
174-
let candidates = candidatesTokenCount ?? 0
175-
return TokenUsage(
173+
TokenUsage(
176174
input: promptTokenCount ?? 0,
177-
output: max(0, candidates - thoughts),
178-
reasoning: thoughts,
175+
output: candidatesTokenCount ?? 0,
176+
reasoning: thoughtsTokenCount ?? 0,
179177
cacheRead: cachedContentTokenCount
180178
)
181179
}
@@ -191,6 +189,38 @@ struct GeminiErrorDetail: Decodable {
191189
let status: String
192190
}
193191

192+
/// Helpers for the reasoning-detail JSON object that carries the thought
/// signature attached to a Gemini function call.
enum GeminiReasoningDetail {
    /// Value stored under the `"type"` key of a function-call signature detail.
    private static let functionCallSignatureType = "gemini.function_call"

    /// Builds the reasoning-detail object for one function call.
    ///
    /// - Parameters:
    ///   - toolCallID: Stored under `"tool_call_id"`.
    ///   - signature: Stored under `"thought_signature"`.
    /// - Returns: A JSON object with the three keys `"type"`, `"tool_call_id"`
    ///   and `"thought_signature"`.
    static func functionCallSignature(toolCallID: String, signature: String) -> JSONValue {
        let typeTag = Self.functionCallSignatureType
        return .object([
            "type": .string(typeTag),
            "tool_call_id": .string(toolCallID),
            "thought_signature": .string(signature)
        ])
    }

    /// Extracts the function-call signatures from a list of reasoning details.
    ///
    /// Entries that are not objects, whose `"type"` is not the function-call
    /// tag, or whose ID / signature are not strings are skipped. When the same
    /// tool call ID appears more than once, the last entry wins.
    ///
    /// - Parameter details: Reasoning details to scan.
    /// - Returns: A dictionary from tool call ID to thought signature.
    static func functionCallSignatures(from details: [JSONValue]) -> [String: String] {
        details.reduce(into: [:]) { signaturesByCallID, entry in
            guard case let .object(fields) = entry,
                  case let .string(kind) = fields["type"],
                  kind == Self.functionCallSignatureType,
                  case let .string(callID) = fields["tool_call_id"],
                  case let .string(thoughtSignature) = fields["thought_signature"]
            else {
                return
            }
            signaturesByCallID[callID] = thoughtSignature
        }
    }
}
223+
194224
enum GeminiMessageMapper {
195225
static func mapMessages(
196226
_ messages: [ChatMessage]
@@ -253,6 +283,9 @@ enum GeminiMessageMapper {
253283
_ msg: AssistantMessage
254284
) throws -> GeminiContent {
255285
var parts: [GeminiPart] = []
286+
let functionCallSignatures = GeminiReasoningDetail.functionCallSignatures(
287+
from: msg.reasoningDetails ?? []
288+
)
256289

257290
if let details = msg.reasoningDetails {
258291
for detail in details {
@@ -282,7 +315,8 @@ enum GeminiMessageMapper {
282315
parts.append(GeminiPart(
283316
functionCall: GeminiFunctionCall(
284317
id: call.id, name: call.name, args: args
285-
)
318+
),
319+
thoughtSignature: functionCallSignatures[call.id]
286320
))
287321
}
288322

Tests/AgentRunKitTests/GeminiClientTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,28 @@ struct GeminiMessageMapperTests {
432432
#expect(mapped[0].parts[0].thoughtSignature == nil)
433433
}
434434

435+
/// An assistant message whose reasoning details hold a function-call
/// signature must map to a single function-call part carrying that signature.
@Test
func assistantFunctionCallThoughtSignatureRoundTrips() throws {
    let searchCall = ToolCall(id: "call_sig", name: "search", arguments: "{\"q\":\"test\"}")
    let signatureDetail = GeminiReasoningDetail.functionCallSignature(
        toolCallID: "call_sig",
        signature: "sig_fc"
    )
    let assistant = AssistantMessage(
        content: "",
        toolCalls: [searchCall],
        reasoningDetails: [signatureDetail]
    )

    let (_, contents) = try GeminiMessageMapper.mapMessages([.assistant(assistant)])

    #expect(contents.count == 1)
    #expect(contents[0].parts.count == 1)
    #expect(contents[0].parts[0].functionCall?.id == "call_sig")
    #expect(contents[0].parts[0].thoughtSignature == "sig_fc")
}
456+
435457
@Test
436458
func multimodalThrows() {
437459
do {

Tests/AgentRunKitTests/GeminiResponseParsingTests.swift

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ struct GeminiResponseParsingTests {
9393
Issue.record("Expected object in reasoning details")
9494
}
9595
#expect(msg.tokenUsage?.reasoning == 50)
96-
#expect(msg.tokenUsage?.output == 150)
96+
#expect(msg.tokenUsage?.output == 200)
9797
}
9898

9999
@Test
@@ -285,7 +285,7 @@ struct GeminiResponseParsingTests {
285285
#expect(msg.reasoning?.content == "Think first\nThink again")
286286
#expect(msg.reasoningDetails?.count == 2)
287287
#expect(msg.tokenUsage?.reasoning == 40)
288-
#expect(msg.tokenUsage?.output == 80)
288+
#expect(msg.tokenUsage?.output == 120)
289289
}
290290

291291
@Test
@@ -398,3 +398,45 @@ struct GeminiResponseParsingTests {
398398
#expect(inputDict["units"] == .string("celsius"))
399399
}
400400
}
401+
402+
/// Response-parsing tests for thought signatures attached to function calls.
struct GeminiFunctionCallReasoningDetailsTests {
    /// Client used only for parsing; the key and model are placeholder values.
    private func makeClient() -> GeminiClient {
        GeminiClient(apiKey: "test-key", model: "gemini-2.5-pro")
    }

    /// A function-call part with a `thoughtSignature` must produce one tool
    /// call plus one reasoning detail tagged `gemini.function_call`.
    @Test
    func functionCallThoughtSignatureIsPreservedInReasoningDetails() throws {
        let responseJSON = """
        {
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "functionCall": {"id": "call_sig", "name": "search", "args": {"q": "test"}},
                            "thoughtSignature": "sig_fc"
                        }
                    ]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 40,
                "candidatesTokenCount": 20
            }
        }
        """
        let parsed = try makeClient().parseResponse(Data(responseJSON.utf8))

        #expect(parsed.toolCalls.count == 1)
        #expect(parsed.toolCalls[0].id == "call_sig")
        #expect(parsed.reasoningDetails?.count == 1)

        guard case let .object(fields) = parsed.reasoningDetails?[0] else {
            Issue.record("Expected function-call reasoning detail")
            return
        }
        #expect(fields["type"] == .string("gemini.function_call"))
        #expect(fields["tool_call_id"] == .string("call_sig"))
        #expect(fields["thought_signature"] == .string("sig_fc"))
    }
}

Tests/AgentRunKitTests/GeminiStreamingTests.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,29 @@ struct GeminiStreamingTests {
8585
}
8686
}
8787

88+
/// A streamed function-call part with a `thoughtSignature` must yield exactly
/// one `.reasoningDetails` delta holding the function-call signature detail.
@Test
func functionCallStreamingEmitsThoughtSignatureDetail() async throws {
    let ssePayload =
        "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"functionCall\":{\"id\":\"call_01\",\"name\":\"get_weather\",\"args\":{\"city\":\"NYC\"}},\"thoughtSignature\":\"sig_fc\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":20,\"candidatesTokenCount\":15}}"
    let sseLines = [ssePayload]
    let events = try await collectDeltas(from: sseLines)

    let reasoningEvents = events.filter { event in
        guard case .reasoningDetails = event else { return false }
        return true
    }
    #expect(reasoningEvents.count == 1)

    // The filter above keeps only `.reasoningDetails`, so this binding always succeeds.
    guard case let .reasoningDetails(details) = reasoningEvents[0] else { return }
    #expect(details.count == 1)

    guard case let .object(fields) = details[0] else {
        Issue.record("Expected function-call reasoning detail")
        return
    }
    #expect(fields["type"] == .string("gemini.function_call"))
    #expect(fields["tool_call_id"] == .string("call_01"))
    #expect(fields["thought_signature"] == .string("sig_fc"))
}
110+
88111
@Test
89112
func thinkingStreaming() async throws {
90113
let lines = [
@@ -148,7 +171,7 @@ struct GeminiStreamingTests {
148171
#expect(finishedDeltas.count == 1)
149172
if case let .finished(usage) = finishedDeltas[0] {
150173
#expect(usage?.input == 100)
151-
#expect(usage?.output == 40)
174+
#expect(usage?.output == 50)
152175
#expect(usage?.reasoning == 10)
153176
#expect(usage?.cacheRead == 20)
154177
}

0 commit comments

Comments
 (0)