Skip to content

Commit b8a09a6

Browse files
committed
Merge branch 'main' of https://github.com/strands-agents/sdk-python into avoid-event-loop-block
2 parents af315e3 + b568864 commit b8a09a6

5 files changed

Lines changed: 353 additions & 81 deletions

File tree

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,14 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
132132
model_id=model_id,
133133
)
134134
with trace_api.use_span(model_invoke_span):
135-
tool_specs = agent.tool_registry.get_all_tool_specs()
136-
137135
agent.hooks.invoke_callbacks(
138136
BeforeModelInvocationEvent(
139137
agent=agent,
140138
)
141139
)
142140

141+
tool_specs = agent.tool_registry.get_all_tool_specs()
142+
143143
try:
144144
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
145145
if not isinstance(event, ModelStopReason):

src/strands/event_loop/streaming.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T
289289
"text": "",
290290
"current_tool_use": {},
291291
"reasoningText": "",
292-
"signature": "",
293292
"citationsContent": [],
294293
}
295294
state["content"] = state["message"]["content"]

src/strands/multiagent/graph.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -469,41 +469,32 @@ async def _execute_graph(self) -> None:
469469
ready_nodes.clear()
470470

471471
# Execute current batch of ready nodes concurrently
472-
tasks = [
473-
asyncio.create_task(self._execute_node(node))
474-
for node in current_batch
475-
if node not in self.state.completed_nodes
476-
]
472+
tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch]
477473

478474
for task in tasks:
479475
await task
480476

481477
# Find newly ready nodes after batch execution
482-
ready_nodes.extend(self._find_newly_ready_nodes())
478+
# We treat all nodes in the current batch as completed,
479+
# because a failure would raise an exception and execution would not reach this point
480+
ready_nodes.extend(self._find_newly_ready_nodes(current_batch))
483481

484-
def _find_newly_ready_nodes(self) -> list["GraphNode"]:
482+
def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]:
485483
"""Find nodes that became ready after the last execution."""
486484
newly_ready = []
487485
for _node_id, node in self.nodes.items():
488-
if (
489-
node not in self.state.completed_nodes
490-
and node not in self.state.failed_nodes
491-
and self._is_node_ready_with_conditions(node)
492-
):
486+
if self._is_node_ready_with_conditions(node, completed_batch):
493487
newly_ready.append(node)
494488
return newly_ready
495489

496-
def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
490+
def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool:
497491
"""Check if a node is ready considering conditional edges."""
498492
# Get incoming edges to this node
499493
incoming_edges = [edge for edge in self.edges if edge.to_node == node]
500494

501-
if not incoming_edges:
502-
return node in self.entry_points
503-
504495
# Check if at least one incoming edge condition is satisfied
505496
for edge in incoming_edges:
506-
if edge.from_node in self.state.completed_nodes:
497+
if edge.from_node in completed_batch:
507498
if edge.should_traverse(self.state):
508499
logger.debug(
509500
"from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id

tests/strands/event_loop/test_streaming.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import unittest.mock
2+
from typing import cast
23

34
import pytest
45

56
import strands
67
import strands.event_loop
7-
from strands.types._events import TypedEvent
8+
from strands.types._events import ModelStopReason, TypedEvent
9+
from strands.types.content import Message
810
from strands.types.streaming import (
911
ContentBlockDeltaEvent,
1012
ContentBlockStartEvent,
@@ -565,6 +567,88 @@ async def test_process_stream(response, exp_events, agenerator, alist):
565567
assert non_typed_events == []
566568

567569

570+
def _get_message_from_event(event: ModelStopReason) -> Message:
571+
return cast(Message, event["stop"][1])
572+
573+
574+
@pytest.mark.asyncio
575+
async def test_process_stream_with_no_signature(agenerator, alist):
576+
response = [
577+
{"messageStart": {"role": "assistant"}},
578+
{
579+
"contentBlockDelta": {
580+
"delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}},
581+
"contentBlockIndex": 0,
582+
}
583+
},
584+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}},
585+
{"contentBlockStop": {"contentBlockIndex": 0}},
586+
{
587+
"contentBlockDelta": {
588+
"delta": {"text": "Sure! Let’s do it"},
589+
"contentBlockIndex": 1,
590+
}
591+
},
592+
{"contentBlockStop": {"contentBlockIndex": 1}},
593+
{"messageStop": {"stopReason": "end_turn"}},
594+
{
595+
"metadata": {
596+
"usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876},
597+
"metrics": {"latencyMs": 2970},
598+
}
599+
},
600+
]
601+
602+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
603+
604+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
605+
606+
message = _get_message_from_event(last_event)
607+
608+
assert "signature" not in message["content"][0]["reasoningContent"]["reasoningText"]
609+
assert message["content"][1]["text"] == "Sure! Let’s do it"
610+
611+
612+
@pytest.mark.asyncio
613+
async def test_process_stream_with_signature(agenerator, alist):
614+
response = [
615+
{"messageStart": {"role": "assistant"}},
616+
{
617+
"contentBlockDelta": {
618+
"delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}},
619+
"contentBlockIndex": 0,
620+
}
621+
},
622+
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}},
623+
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "test-"}}, "contentBlockIndex": 0}},
624+
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "signature"}}, "contentBlockIndex": 0}},
625+
{"contentBlockStop": {"contentBlockIndex": 0}},
626+
{
627+
"contentBlockDelta": {
628+
"delta": {"text": "Sure! Let’s do it"},
629+
"contentBlockIndex": 1,
630+
}
631+
},
632+
{"contentBlockStop": {"contentBlockIndex": 1}},
633+
{"messageStop": {"stopReason": "end_turn"}},
634+
{
635+
"metadata": {
636+
"usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876},
637+
"metrics": {"latencyMs": 2970},
638+
}
639+
},
640+
]
641+
642+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
643+
644+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
645+
646+
message = _get_message_from_event(last_event)
647+
648+
assert message["content"][0]["reasoningContent"]["reasoningText"]["signature"] == "test-signature"
649+
assert message["content"][1]["text"] == "Sure! Let’s do it"
650+
651+
568652
@pytest.mark.asyncio
569653
async def test_stream_messages(agenerator, alist):
570654
mock_model = unittest.mock.MagicMock()

0 commit comments

Comments (0)