diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 24c1dd1fe4f..992a0c6dcf3 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -175,17 +175,7 @@ def _has_unresolved_matching_if(self, if_node_id: str, iteration_path: tuple[int return not all(pid in self._state._resolved_if_exec_branches for pid in matching_prepared_if_ids) def _apply_condition_inputs(self, exec_node_id: str, node: IfInvocation) -> bool: - condition_edges = self._state.execution_graph._get_input_edges(exec_node_id, "condition") - if any(edge.source.node_id not in self._state.executed for edge in condition_edges): - return False - - for edge in condition_edges: - setattr( - node, - edge.destination.field, - copydeep(getattr(self._state.results[edge.source.node_id], edge.source.field)), - ) - return True + return self._state._apply_if_condition_inputs(exec_node_id, node) def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]: selected_field = "true_input" if node.condition else "false_input" @@ -1819,6 +1809,73 @@ def _prepare_until_node_ready(self) -> Optional[BaseInvocation]: return next_node + def _reset_runtime_caches(self) -> None: + self._ready_queues = {} + self._active_class = None + self._iteration_path_cache = {} + self._if_branch_exclusive_sources = {} + self._resolved_if_exec_branches = {} + self._prepared_exec_metadata = {} + self._prepared_exec_registry = None + self._if_branch_scheduler = None + self._execution_materializer = None + self._execution_scheduler = None + self._execution_runtime = None + + def _rehydrate_prepared_exec_metadata(self) -> None: + registry = self._prepared_registry() + for exec_node_id, source_node_id in self.prepared_source_mapping.items(): + metadata = registry.get_metadata(exec_node_id) + metadata.source_node_id = source_node_id + metadata.iteration_path = self._get_iteration_path(exec_node_id) + if exec_node_id in self.executed: + metadata.state = "executed" if exec_node_id in self.results else "skipped" + elif self.indegree.get(exec_node_id) == 0: + metadata.state = "ready" + else: + metadata.state = "pending" + + def _apply_if_condition_inputs(self, exec_node_id: str, node: IfInvocation) -> bool: + condition_edges = self.execution_graph._get_input_edges(exec_node_id, "condition") + if any(edge.source.node_id not in self.executed for edge in condition_edges): + return False + + for edge in condition_edges: + setattr( + node, + edge.destination.field, + copydeep(getattr(self.results[edge.source.node_id], edge.source.field)), + ) + return True + + def _rehydrate_resolved_if_exec_branches(self) -> None: + for exec_node_id, node in self.execution_graph.nodes.items(): + if not isinstance(node, IfInvocation): + continue + + if not self._apply_if_condition_inputs(exec_node_id, node): + continue + + self._resolved_if_exec_branches[exec_node_id] = "true_input" if node.condition else "false_input" + + def _rehydrate_ready_queues(self) -> None: + execution_graph = self.execution_graph.nx_graph_flat() + for exec_node_id in nx.topological_sort(execution_graph): + if exec_node_id in self.executed: + continue + if self.indegree.get(exec_node_id) != 0: + continue + self._enqueue_if_ready(exec_node_id) + + def _rehydrate_runtime_state(self) -> None: + self._reset_runtime_caches() + self._rehydrate_prepared_exec_metadata() + self._rehydrate_resolved_if_exec_branches() + self._rehydrate_ready_queues() + + def model_post_init(self, __context: Any) -> None: + self._rehydrate_runtime_state() + model_config = ConfigDict( json_schema_extra={ "required": [ diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index ffd0ca1559d..fb5da1c45a0 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from pydantic import TypeAdapter from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.collections import RangeInvocation @@ -137,6 +138,73 @@ def test_graph_state_collects(): assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) +def test_graph_state_resumes_partially_executed_session_after_json_round_trip(): + graph = Graph() + graph.add_node(RangeInvocation(id="c", start=1, stop=5, step=1)) + graph.add_node(IterateInvocation(id="iter")) + graph.add_node(AddInvocation(id="add", b=1)) + graph.add_node(CollectInvocation(id="collect")) + + graph.add_edge(create_edge("c", "collection", "iter", "collection")) + graph.add_edge(create_edge("iter", "item", "add", "a")) + graph.add_edge(create_edge("add", "value", "collect", "item")) + + state = GraphExecutionState(graph=graph) + + for _ in range(4): + invocation, output = invoke_next(state) + assert invocation is not None + assert output is not None + + raw = state.model_dump_json(warnings=False, exclude_none=True) + resumed = TypeAdapter(GraphExecutionState).validate_json(raw, strict=False) + registry = resumed._prepared_registry() + + assert all( + registry.get_iteration_path(exec_node_id) is not None for exec_node_id in resumed.prepared_source_mapping + ) + + executed_source_ids = execute_all_nodes(resumed) + + assert executed_source_ids + assert "add" in executed_source_ids + assert "collect" in resumed.source_prepared_mapping + + prepared_collect_id = next(iter(resumed.source_prepared_mapping["collect"])) + assert resumed.results[prepared_collect_id].collection == [2, 3, 4, 5] + + +def test_if_graph_state_resumes_resolved_branch_after_json_round_trip(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=True)) + graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch")) + graph.add_node(PromptTestInvocation(id="false_value", prompt="false branch")) + graph.add_node(IfInvocation(id="if")) + graph.add_node(PromptTestInvocation(id="selected_output")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("true_value", "prompt", "if", "true_input")) + graph.add_edge(create_edge("false_value", "prompt", "if", "false_input")) + graph.add_edge(create_edge("if", "value", "selected_output", "prompt")) + + state = GraphExecutionState(graph=graph) + + for _ in range(2): + invocation, output = invoke_next(state) + assert invocation is not None + assert output is not None + + raw = state.model_dump_json(warnings=False, exclude_none=True) + resumed = TypeAdapter(GraphExecutionState).validate_json(raw, strict=False) + + executed_source_ids = execute_all_nodes(resumed) + + prepared_selected_output_id = next(iter(resumed.source_prepared_mapping["selected_output"])) + assert resumed.results[prepared_selected_output_id].prompt == "true branch" + assert set(executed_source_ids) == {"if", "selected_output"} + assert "false_value" not in executed_source_ids + + def test_graph_state_prepares_eagerly(): """Tests that all prepareable nodes are prepared""" graph = Graph()