@@ -2113,20 +2113,53 @@ def cancel_callback(event):
21132113 graph = builder .build ()
21142114 graph .hooks .add_callback (BeforeNodeCallEvent , cancel_callback )
21152115
2116- stream = graph .stream_async ("test task" )
2117-
21182116 tru_cancel_event = None
2119- with pytest .raises (RuntimeError , match = cancel_message ):
2120- async for event in stream :
2121- if event .get ("type" ) == "multiagent_node_cancel" :
2122- tru_cancel_event = event
2117+ async for event in graph .stream_async ("test task" ):
2118+ if event .get ("type" ) == "multiagent_node_cancel" :
2119+ tru_cancel_event = event
21232120
21242121 exp_cancel_event = MultiAgentNodeCancelEvent (node_id = "test_agent" , message = cancel_message )
21252122 assert tru_cancel_event == exp_cancel_event
21262123
2127- tru_status = graph .state .status
2128- exp_status = Status .FAILED
2129- assert tru_status == exp_status
2124+ assert graph .state .status == Status .COMPLETED
2125+ assert any (n .node_id == "test_agent" for n in graph .state .completed_nodes )
2126+ assert "test_agent" in graph .state .results
2127+ agent .__call__ .assert_not_called ()
2128+
2129+
2130+ @pytest .mark .asyncio
2131+ async def test_graph_cancel_node_downstream_executes ():
2132+ """Downstream nodes must run after an upstream node is skipped via cancel_node."""
2133+ cancelled_nodes : list [str ] = []
2134+
2135+ def cancel_step_a (event ):
2136+ if event .node_id == "step_a" :
2137+ event .cancel_node = "step_a skipped"
2138+ return event
2139+
2140+ step_a = create_mock_agent ("step_a" , "Should not run" )
2141+ step_b = create_mock_agent ("step_b" , "Step B completed" )
2142+
2143+ builder = GraphBuilder ()
2144+ builder .add_node (step_a , "step_a" )
2145+ builder .add_node (step_b , "step_b" )
2146+ builder .add_edge ("step_a" , "step_b" )
2147+ builder .set_entry_point ("step_a" )
2148+ graph = builder .build ()
2149+ graph .hooks .add_callback (BeforeNodeCallEvent , cancel_step_a )
2150+
2151+ async for event in graph .stream_async ("test task" ):
2152+ if event .get ("type" ) == "multiagent_node_cancel" :
2153+ cancelled_nodes .append (event ["node_id" ])
2154+
2155+ assert cancelled_nodes == ["step_a" ]
2156+ assert graph .state .status == Status .COMPLETED
2157+ step_a .__call__ .assert_not_called ()
2158+ step_b .__call__ .assert_not_called () # stream_async uses stream_async on agent, not __call__
2159+ assert any (n .node_id == "step_a" for n in graph .state .completed_nodes )
2160+ assert any (n .node_id == "step_b" for n in graph .state .completed_nodes )
2161+ assert "step_a" in graph .state .results
2162+ assert "step_b" in graph .state .results
21302163
21312164
21322165def test_graph_interrupt_on_before_node_call_event (interrupt_hook ):
0 commit comments