Skip to content

Commit 9a895f7

Browse files
authored
Merge branch 'strands-agents:main' into main
2 parents 29f246e + d77f08b commit 9a895f7

8 files changed

Lines changed: 385 additions & 33 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ docs = [
6868
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
6969
]
7070
litellm = [
71-
"litellm>=1.73.1,<2.0.0",
72-
# https://github.com/BerriAI/litellm/issues/13711
73-
"openai<1.100.0",
71+
"litellm>=1.75.9,<2.0.0",
72+
"openai>=1.68.0,<1.102.0",
7473
]
7574
llamaapi = [
7675
"llama-api-client>=0.1.0,<1.0.0",

src/strands/event_loop/streaming.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T
289289
"text": "",
290290
"current_tool_use": {},
291291
"reasoningText": "",
292-
"signature": "",
293292
"citationsContent": [],
294293
}
295294
state["content"] = state["message"]["content"]

src/strands/models/bedrock.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
"too many total text bytes",
3838
]
3939

40+
# Models that should include tool result status (include_tool_result_status = True)
41+
_MODELS_INCLUDE_STATUS = [
42+
"anthropic.claude",
43+
]
44+
4045
T = TypeVar("T", bound=BaseModel)
4146

4247

@@ -71,6 +76,8 @@ class BedrockConfig(TypedDict, total=False):
7176
guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message.
7277
max_tokens: Maximum number of tokens to generate in the response
7378
model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
79+
include_tool_result_status: Flag to include status field in tool results.
80+
True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto".
7481
stop_sequences: List of sequences that will stop generation when encountered
7582
streaming: Flag to enable/disable streaming. Defaults to True.
7683
temperature: Controls randomness in generation (higher = more random)
@@ -92,6 +99,7 @@ class BedrockConfig(TypedDict, total=False):
9299
guardrail_redact_output_message: Optional[str]
93100
max_tokens: Optional[int]
94101
model_id: str
102+
include_tool_result_status: Optional[Literal["auto"] | bool]
95103
stop_sequences: Optional[list[str]]
96104
streaming: Optional[bool]
97105
temperature: Optional[float]
@@ -119,7 +127,7 @@ def __init__(
119127
if region_name and boto_session:
120128
raise ValueError("Cannot specify both `region_name` and `boto_session`.")
121129

122-
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID)
130+
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
123131
self.update_config(**model_config)
124132

125133
logger.debug("config=<%s> | initializing", self.config)
@@ -169,6 +177,17 @@ def get_config(self) -> BedrockConfig:
169177
"""
170178
return self.config
171179

180+
def _should_include_tool_result_status(self) -> bool:
181+
"""Determine whether to include tool result status based on current config."""
182+
include_status = self.config.get("include_tool_result_status", "auto")
183+
184+
if include_status is True:
185+
return True
186+
elif include_status is False:
187+
return False
188+
else: # "auto"
189+
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)
190+
172191
def format_request(
173192
self,
174193
messages: Messages,
@@ -256,6 +275,7 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
256275
"""Format messages for Bedrock API compatibility.
257276
258277
This function ensures messages conform to Bedrock's expected format by:
278+
- Filtering out SDK_UNKNOWN_MEMBER content blocks
259279
- Cleaning tool result content blocks by removing additional fields that may be
260280
useful for retaining information in hooks but would cause Bedrock validation
261281
exceptions when presented with unexpected fields
@@ -273,19 +293,33 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
273293
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
274294
"""
275295
cleaned_messages = []
296+
filtered_unknown_members = False
276297

277298
for message in messages:
278299
cleaned_content: list[ContentBlock] = []
279300

280301
for content_block in message["content"]:
302+
# Filter out SDK_UNKNOWN_MEMBER content blocks
303+
if "SDK_UNKNOWN_MEMBER" in content_block:
304+
filtered_unknown_members = True
305+
continue
306+
281307
if "toolResult" in content_block:
282308
# Create a new content block with only the cleaned toolResult
283309
tool_result: ToolResult = content_block["toolResult"]
284310

285-
# Keep only the required fields for Bedrock
286-
cleaned_tool_result = ToolResult(
287-
content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
288-
)
311+
if self._should_include_tool_result_status():
312+
# Include status field
313+
cleaned_tool_result = ToolResult(
314+
content=tool_result["content"],
315+
toolUseId=tool_result["toolUseId"],
316+
status=tool_result["status"],
317+
)
318+
else:
319+
# Remove status field
320+
cleaned_tool_result = ToolResult( # type: ignore[typeddict-item]
321+
toolUseId=tool_result["toolUseId"], content=tool_result["content"]
322+
)
289323

290324
cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
291325
cleaned_content.append(cleaned_block)
@@ -297,6 +331,11 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
297331
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
298332
cleaned_messages.append(cleaned_message)
299333

334+
if filtered_unknown_members:
335+
logger.warning(
336+
"Filtered out SDK_UNKNOWN_MEMBER content blocks from messages, consider upgrading boto3 version"
337+
)
338+
300339
return cleaned_messages
301340

302341
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:

src/strands/types/_events.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,24 +275,19 @@ def is_callback_event(self) -> bool:
275275
class ToolStreamEvent(TypedEvent):
276276
"""Event emitted when a tool yields sub-events as part of tool execution."""
277277

278-
def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None:
278+
def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None:
279279
"""Initialize with tool streaming data.
280280
281281
Args:
282282
tool_use: The tool invocation producing the stream
283-
tool_sub_event: The yielded event from the tool execution
283+
tool_stream_data: The yielded event from the tool execution
284284
"""
285-
super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event})
285+
super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}})
286286

287287
@property
288288
def tool_use_id(self) -> str:
289289
"""The toolUseId associated with this stream."""
290-
return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId"))
291-
292-
@property
293-
@override
294-
def is_callback_event(self) -> bool:
295-
return False
290+
return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId"))
296291

297292

298293
class ModelMessageEvent(TypedEvent):

tests/strands/agent/hooks/test_agent_events.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,22 @@ async def test_stream_e2e_success(alist):
260260
"role": "assistant",
261261
}
262262
},
263+
{
264+
"tool_stream_event": {
265+
"data": {"tool_streaming": True},
266+
"tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
267+
}
268+
},
269+
{
270+
"tool_stream_event": {
271+
"data": "Final result",
272+
"tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
273+
}
274+
},
263275
{
264276
"message": {
265277
"content": [
266-
{
267-
"toolResult": {
268-
# TODO update this text when we get tool streaming implemented; right now this
269-
# TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
270-
"content": [{"text": ANY}],
271-
"status": "success",
272-
"toolUseId": "12345",
273-
}
274-
},
278+
{"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}}
275279
],
276280
"role": "user",
277281
}

tests/strands/event_loop/test_streaming.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import unittest.mock
2+
from typing import cast
23

34
import pytest
45

56
import strands
67
import strands.event_loop
7-
from strands.types._events import TypedEvent
8+
from strands.types._events import ModelStopReason, TypedEvent
9+
from strands.types.content import Message
810
from strands.types.streaming import (
911
ContentBlockDeltaEvent,
1012
ContentBlockStartEvent,
@@ -565,6 +567,88 @@ async def test_process_stream(response, exp_events, agenerator, alist):
565567
assert non_typed_events == []
566568

567569

570+
def _get_message_from_event(event: ModelStopReason) -> Message:
571+
return cast(Message, event["stop"][1])
572+
573+
574+
@pytest.mark.asyncio
575+
async def test_process_stream_with_no_signature(agenerator, alist):
576+
response = [
577+
{"messageStart": {"role": "assistant"}},
578+
{
579+
"contentBlockDelta": {
580+
"delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}},
581+
"contentBlockIndex": 0,
582+
}
583+
},
584+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}},
585+
{"contentBlockStop": {"contentBlockIndex": 0}},
586+
{
587+
"contentBlockDelta": {
588+
"delta": {"text": "Sure! Let’s do it"},
589+
"contentBlockIndex": 1,
590+
}
591+
},
592+
{"contentBlockStop": {"contentBlockIndex": 1}},
593+
{"messageStop": {"stopReason": "end_turn"}},
594+
{
595+
"metadata": {
596+
"usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876},
597+
"metrics": {"latencyMs": 2970},
598+
}
599+
},
600+
]
601+
602+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
603+
604+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
605+
606+
message = _get_message_from_event(last_event)
607+
608+
assert "signature" not in message["content"][0]["reasoningContent"]["reasoningText"]
609+
assert message["content"][1]["text"] == "Sure! Let’s do it"
610+
611+
612+
@pytest.mark.asyncio
613+
async def test_process_stream_with_signature(agenerator, alist):
614+
response = [
615+
{"messageStart": {"role": "assistant"}},
616+
{
617+
"contentBlockDelta": {
618+
"delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}},
619+
"contentBlockIndex": 0,
620+
}
621+
},
622+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}},
623+
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "test-"}}, "contentBlockIndex": 0}},
624+
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "signature"}}, "contentBlockIndex": 0}},
625+
{"contentBlockStop": {"contentBlockIndex": 0}},
626+
{
627+
"contentBlockDelta": {
628+
"delta": {"text": "Sure! Let’s do it"},
629+
"contentBlockIndex": 1,
630+
}
631+
},
632+
{"contentBlockStop": {"contentBlockIndex": 1}},
633+
{"messageStop": {"stopReason": "end_turn"}},
634+
{
635+
"metadata": {
636+
"usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876},
637+
"metrics": {"latencyMs": 2970},
638+
}
639+
},
640+
]
641+
642+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
643+
644+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
645+
646+
message = _get_message_from_event(last_event)
647+
648+
assert message["content"][0]["reasoningContent"]["reasoningText"]["signature"] == "test-signature"
649+
assert message["content"][1]["text"] == "Sure! Let’s do it"
650+
651+
568652
@pytest.mark.asyncio
569653
async def test_stream_messages(agenerator, alist):
570654
mock_model = unittest.mock.MagicMock()

0 commit comments

Comments
 (0)