Skip to content

Commit 2281a20

Browse files
authored
fix: propagate reasoningSignature on Gemini tool use (strands-agents#1703)
1 parent 25c2aa4 commit 2281a20

6 files changed

Lines changed: 170 additions & 13 deletions

File tree

src/strands/event_loop/streaming.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
186186
current_tool_use["toolUseId"] = tool_use_data["toolUseId"]
187187
current_tool_use["name"] = tool_use_data["name"]
188188
current_tool_use["input"] = ""
189+
if "reasoningSignature" in tool_use_data:
190+
current_tool_use["reasoningSignature"] = tool_use_data["reasoningSignature"]
189191

190192
return current_tool_use
191193

@@ -286,6 +288,8 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
286288
name=tool_use_name,
287289
input=current_tool_use["input"],
288290
)
291+
if "reasoningSignature" in current_tool_use:
292+
tool_use["reasoningSignature"] = current_tool_use["reasoningSignature"]
289293
content.append({"toolUse": tool_use})
290294
state["current_tool_use"] = {}
291295

src/strands/models/gemini.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Docs: https://ai.google.dev/api
44
"""
55

6+
import base64
67
import json
78
import logging
89
import mimetypes
@@ -14,7 +15,7 @@
1415
from google import genai
1516
from typing_extensions import Required, Unpack, override
1617

17-
from ..types.content import ContentBlock, Messages
18+
from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages
1819
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
1920
from ..types.streaming import StreamEvent
2021
from ..types.tools import ToolChoice, ToolSpec
@@ -173,7 +174,7 @@ def _format_request_content_part(
173174
return genai.types.Part(
174175
text=content["reasoningContent"]["reasoningText"]["text"],
175176
thought=True,
176-
thought_signature=thought_signature.encode("utf-8") if thought_signature else None,
177+
thought_signature=base64.b64decode(thought_signature) if thought_signature else None,
177178
)
178179

179180
if "text" in content:
@@ -202,14 +203,18 @@ def _format_request_content_part(
202203
)
203204

204205
if "toolUse" in content:
205-
tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"]
206+
tool_use_id = content["toolUse"]["toolUseId"]
207+
tool_use_id_to_name[tool_use_id] = content["toolUse"]["name"]
208+
209+
reasoning_signature = content["toolUse"].get("reasoningSignature")
206210

207211
return genai.types.Part(
208212
function_call=genai.types.FunctionCall(
209213
args=content["toolUse"]["input"],
210-
id=content["toolUse"]["toolUseId"],
214+
id=tool_use_id,
211215
name=content["toolUse"]["name"],
212216
),
217+
thought_signature=base64.b64decode(reasoning_signature) if reasoning_signature else None,
213218
)
214219

215220
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
@@ -349,13 +354,18 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
349354
# Use Gemini's provided ID or generate one if missing
350355
tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}"
351356

357+
tool_use_start: ContentBlockStartToolUse = {
358+
"name": function_call.name,
359+
"toolUseId": tool_use_id,
360+
}
361+
if event["data"].thought_signature:
362+
tool_use_start["reasoningSignature"] = base64.b64encode(
363+
event["data"].thought_signature
364+
).decode("ascii")
352365
return {
353366
"contentBlockStart": {
354367
"start": {
355-
"toolUse": {
356-
"name": function_call.name,
357-
"toolUseId": tool_use_id,
358-
},
368+
"toolUse": tool_use_start,
359369
},
360370
},
361371
}
@@ -379,7 +389,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
379389
"reasoningContent": {
380390
"text": event["data"].text,
381391
**(
382-
{"signature": event["data"].thought_signature.decode("utf-8")}
392+
{
393+
"signature": base64.b64encode(event["data"].thought_signature).decode(
394+
"ascii"
395+
)
396+
}
383397
if event["data"].thought_signature
384398
else {}
385399
),

src/strands/types/content.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from typing import Literal
1010

11-
from typing_extensions import TypedDict
11+
from typing_extensions import NotRequired, TypedDict
1212

1313
from .citations import CitationsContentBlock
1414
from .media import DocumentContent, ImageContent, VideoContent
@@ -129,10 +129,12 @@ class ContentBlockStartToolUse(TypedDict):
129129
Attributes:
130130
name: The name of the tool that the model is requesting to use.
131131
toolUseId: The ID for the tool request.
132+
reasoningSignature: Token that ties the model's reasoning to this tool call.
132133
"""
133134

134135
name: str
135136
toolUseId: str
137+
reasoningSignature: NotRequired[str]
136138

137139

138140
class ContentBlockStart(TypedDict, total=False):

src/strands/types/tools.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@ class ToolUse(TypedDict):
5858
Can be any JSON-serializable type.
5959
name: The name of the tool to invoke.
6060
toolUseId: A unique identifier for this specific tool use request.
61+
reasoningSignature: Token that ties the model's reasoning to this tool call.
6162
"""
6263

6364
input: Any
6465
name: str
6566
toolUseId: str
67+
reasoningSignature: NotRequired[str]
6668

6769

6870
class ToolResultContent(TypedDict, total=False):

tests/strands/event_loop/test_streaming.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def test_handle_message_start():
125125
{"start": {"toolUse": {"toolUseId": "test", "name": "test"}}},
126126
{"toolUseId": "test", "name": "test", "input": ""},
127127
),
128+
(
129+
{"start": {"toolUse": {"toolUseId": "test", "name": "test", "reasoningSignature": "YWJj"}}},
130+
{"toolUseId": "test", "name": "test", "input": "", "reasoningSignature": "YWJj"},
131+
),
128132
],
129133
)
130134
def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use):
@@ -310,6 +314,39 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s
310314
"redactedContent": b"",
311315
},
312316
),
317+
# Tool Use - With reasoningSignature
318+
(
319+
{
320+
"content": [],
321+
"current_tool_use": {
322+
"toolUseId": "123",
323+
"name": "test",
324+
"input": '{"key": "value"}',
325+
"reasoningSignature": "YWJj",
326+
},
327+
"text": "",
328+
"reasoningText": "",
329+
"citationsContent": [],
330+
"redactedContent": b"",
331+
},
332+
{
333+
"content": [
334+
{
335+
"toolUse": {
336+
"toolUseId": "123",
337+
"name": "test",
338+
"input": {"key": "value"},
339+
"reasoningSignature": "YWJj",
340+
}
341+
}
342+
],
343+
"current_tool_use": {},
344+
"text": "",
345+
"reasoningText": "",
346+
"citationsContent": [],
347+
"redactedContent": b"",
348+
},
349+
),
313350
# Tool Use - Missing input
314351
(
315352
{

tests/strands/models/test_gemini.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id):
203203
{
204204
"reasoningContent": {
205205
"reasoningText": {
206-
"signature": "abc",
206+
"signature": "YWJj", # base64 of "abc"
207207
"text": "reasoning_text",
208208
},
209209
},
@@ -260,6 +260,51 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too
260260

261261
@pytest.mark.asyncio
262262
async def test_stream_request_with_tool_use(gemini_client, model, model_id):
263+
"""Test toolUse with reasoningSignature is sent as function_call with thought_signature."""
264+
messages = [
265+
{
266+
"role": "assistant",
267+
"content": [
268+
{
269+
"toolUse": {
270+
"toolUseId": "c1",
271+
"name": "calculator",
272+
"input": {"expression": "2+2"},
273+
"reasoningSignature": "YWJj", # base64 of "abc"
274+
},
275+
},
276+
],
277+
},
278+
]
279+
await anext(model.stream(messages))
280+
281+
exp_request = {
282+
"config": {
283+
"tools": [{"function_declarations": []}],
284+
},
285+
"contents": [
286+
{
287+
"parts": [
288+
{
289+
"function_call": {
290+
"args": {"expression": "2+2"},
291+
"id": "c1",
292+
"name": "calculator",
293+
},
294+
"thought_signature": "YWJj",
295+
},
296+
],
297+
"role": "model",
298+
},
299+
],
300+
"model": model_id,
301+
}
302+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
303+
304+
305+
@pytest.mark.asyncio
306+
async def test_stream_request_with_tool_use_no_reasoning_signature(gemini_client, model, model_id):
307+
"""Test toolUse without reasoningSignature is sent as function_call without thought_signature."""
263308
messages = [
264309
{
265310
"role": "assistant",
@@ -532,6 +577,55 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat
532577
assert tru_chunks == exp_chunks
533578

534579

580+
@pytest.mark.asyncio
581+
async def test_stream_response_tool_use_with_thought_signature(gemini_client, model, messages, agenerator, alist):
582+
"""Test that tool use responses with thought_signature include reasoningSignature."""
583+
gemini_client.aio.models.generate_content_stream.return_value = agenerator(
584+
[
585+
genai.types.GenerateContentResponse(
586+
candidates=[
587+
genai.types.Candidate(
588+
content=genai.types.Content(
589+
parts=[
590+
genai.types.Part(
591+
function_call=genai.types.FunctionCall(
592+
args={"expression": "2+2"},
593+
id="c1",
594+
name="calculator",
595+
),
596+
thought_signature=b"abc",
597+
),
598+
],
599+
),
600+
finish_reason="STOP",
601+
),
602+
],
603+
usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
604+
prompt_token_count=1,
605+
total_token_count=3,
606+
),
607+
),
608+
]
609+
)
610+
611+
tru_chunks = await alist(model.stream(messages))
612+
exp_chunks = [
613+
{"messageStart": {"role": "assistant"}},
614+
{
615+
"contentBlockStart": {
616+
"start": {
617+
"toolUse": {"name": "calculator", "toolUseId": "c1", "reasoningSignature": "YWJj"},
618+
},
619+
},
620+
},
621+
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
622+
{"contentBlockStop": {}},
623+
{"messageStop": {"stopReason": "tool_use"}},
624+
{"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}},
625+
]
626+
assert tru_chunks == exp_chunks
627+
628+
535629
@pytest.mark.asyncio
536630
async def test_stream_response_reasoning(gemini_client, model, messages, agenerator, alist):
537631
gemini_client.aio.models.generate_content_stream.return_value = agenerator(
@@ -563,7 +657,7 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera
563657
exp_chunks = [
564658
{"messageStart": {"role": "assistant"}},
565659
{"contentBlockStart": {"start": {}}},
566-
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}},
660+
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}},
567661
{"contentBlockStop": {}},
568662
{"messageStop": {"stopReason": "end_turn"}},
569663
{"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}},
@@ -622,7 +716,11 @@ async def test_stream_response_reasoning_and_text(gemini_client, model, messages
622716
exp_chunks = [
623717
{"messageStart": {"role": "assistant"}},
624718
{"contentBlockStart": {"start": {}}},
625-
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "sig1", "text": "thinking about math"}}}},
719+
{
720+
"contentBlockDelta": {
721+
"delta": {"reasoningContent": {"signature": "c2lnMQ==", "text": "thinking about math"}}
722+
}
723+
},
626724
{"contentBlockStop": {}},
627725
{"contentBlockStart": {"start": {}}},
628726
{"contentBlockDelta": {"delta": {"text": "2 + 2 = 4"}}},

0 commit comments

Comments
 (0)