diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8da8314ea..07d150f89 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -193,7 +193,7 @@ def reset_executor_state(self) -> None: if hasattr(self.executor, "messages"): self.executor.messages = copy.deepcopy(self._initial_messages) - if hasattr(self.executor, "state"): + if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"): self.executor.state = AgentState(self._initial_state.get()) if hasattr(self.executor, "_model_state"): diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index a6085627c..baf2bba8e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -794,6 +794,32 @@ async def test_node_reset_executor_state(): assert multi_agent_node.result is None +def test_node_reset_executor_state_does_not_corrupt_nested_graph_state(): + """Test that reset_executor_state does not overwrite a nested Graph's GraphState with AgentState.""" + inner_agent = create_mock_agent("inner_agent") + builder = GraphBuilder() + builder.add_node(inner_agent, "inner_a") + inner_graph = builder.build() + + # The nested Graph's .state is GraphState, which does not have a .get() method + assert isinstance(inner_graph.state, GraphState) + assert not hasattr(inner_graph.state, "get") + + # Wrap the nested graph in a GraphNode + outer_node = GraphNode("outer", inner_graph) + outer_node.execution_status = Status.COMPLETED + + # Before fix: reset_executor_state would call AgentState(self._initial_state.get()) + # and assign it to inner_graph.state, overwriting GraphState with AgentState. + outer_node.reset_executor_state() + + # Verify the nested graph's state was NOT replaced with AgentState + assert isinstance(inner_graph.state, GraphState), ( + "reset_executor_state must not replace a nested Graph's GraphState with AgentState" + ) + assert outer_node.execution_status == Status.PENDING + + def test_graph_dataclasses_and_enums(): """Test dataclass initialization, properties, and enum behavior.""" # Test Status enum