Skip to content

Commit fddde2e

Browse files
committed
fix(graph): treat cancel_node as control flow, not fatal error
Closes #2240
1 parent 980bc91 commit fddde2e

3 files changed

Lines changed: 63 additions & 14 deletions

File tree

src/strands/hooks/events.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -330,9 +330,9 @@ class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
330330
source: The multi-agent orchestrator instance
331331
node_id: ID of the node about to execute
332332
invocation_state: Configuration that user passes in
333-
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
334-
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
335-
node using a default cancel message.
333+
cancel_node: A user-defined message that, when set, will skip the node and mark it as completed, allowing
334+
downstream nodes to continue executing. The message will be emitted under a MultiAgentNodeCancel event.
335+
If set to `True`, Strands will skip the node using a default cancel message.
336336
"""
337337

338338
source: "MultiAgentBase"

src/strands/multiagent/graph.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -899,9 +899,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
899899
cancel_message = (
900900
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
901901
)
902-
logger.debug("reason=<%s> | cancelling execution", cancel_message)
902+
logger.debug("reason=<%s> | node skipped, graph continues", cancel_message)
903903
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
904-
raise RuntimeError(cancel_message)
904+
node_result = NodeResult(
905+
result=RuntimeError(cancel_message),
906+
execution_time=0,
907+
status=Status.COMPLETED,
908+
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
909+
accumulated_metrics=Metrics(latencyMs=0),
910+
execution_count=0,
911+
)
912+
node.result = node_result
913+
node.execution_time = 0
914+
node.execution_status = Status.COMPLETED
915+
self.state.completed_nodes.add(node)
916+
self.state.results[node.node_id] = node_result
917+
self.state.execution_order.append(node)
918+
self._accumulate_metrics(node_result)
919+
yield MultiAgentNodeStopEvent(node_id=node.node_id, node_result=node_result)
920+
return
905921

906922
# Build node input from satisfied dependencies
907923
node_input = self._build_node_input(node)

tests/strands/multiagent/test_graph.py

Lines changed: 42 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2113,20 +2113,53 @@ def cancel_callback(event):
21132113
graph = builder.build()
21142114
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)
21152115

2116-
stream = graph.stream_async("test task")
2117-
21182116
tru_cancel_event = None
2119-
with pytest.raises(RuntimeError, match=cancel_message):
2120-
async for event in stream:
2121-
if event.get("type") == "multiagent_node_cancel":
2122-
tru_cancel_event = event
2117+
async for event in graph.stream_async("test task"):
2118+
if event.get("type") == "multiagent_node_cancel":
2119+
tru_cancel_event = event
21232120

21242121
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message)
21252122
assert tru_cancel_event == exp_cancel_event
21262123

2127-
tru_status = graph.state.status
2128-
exp_status = Status.FAILED
2129-
assert tru_status == exp_status
2124+
assert graph.state.status == Status.COMPLETED
2125+
assert any(n.node_id == "test_agent" for n in graph.state.completed_nodes)
2126+
assert "test_agent" in graph.state.results
2127+
agent.__call__.assert_not_called()
2128+
2129+
2130+
@pytest.mark.asyncio
2131+
async def test_graph_cancel_node_downstream_executes():
2132+
"""Downstream nodes must run after an upstream node is skipped via cancel_node."""
2133+
cancelled_nodes: list[str] = []
2134+
2135+
def cancel_step_a(event):
2136+
if event.node_id == "step_a":
2137+
event.cancel_node = "step_a skipped"
2138+
return event
2139+
2140+
step_a = create_mock_agent("step_a", "Should not run")
2141+
step_b = create_mock_agent("step_b", "Step B completed")
2142+
2143+
builder = GraphBuilder()
2144+
builder.add_node(step_a, "step_a")
2145+
builder.add_node(step_b, "step_b")
2146+
builder.add_edge("step_a", "step_b")
2147+
builder.set_entry_point("step_a")
2148+
graph = builder.build()
2149+
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_step_a)
2150+
2151+
async for event in graph.stream_async("test task"):
2152+
if event.get("type") == "multiagent_node_cancel":
2153+
cancelled_nodes.append(event["node_id"])
2154+
2155+
assert cancelled_nodes == ["step_a"]
2156+
assert graph.state.status == Status.COMPLETED
2157+
step_a.__call__.assert_not_called()
2158+
step_b.__call__.assert_not_called() # graph.stream_async drives each agent via agent.stream_async, not __call__
2159+
assert any(n.node_id == "step_a" for n in graph.state.completed_nodes)
2160+
assert any(n.node_id == "step_b" for n in graph.state.completed_nodes)
2161+
assert "step_a" in graph.state.results
2162+
assert "step_b" in graph.state.results
21302163

21312164

21322165
def test_graph_interrupt_on_before_node_call_event(interrupt_hook):

0 commit comments

Comments
 (0)