diff --git a/invokeai/app/services/shared/README.md b/invokeai/app/services/shared/README.md index 113b7a41e54..f92b1f1ea2e 100644 --- a/invokeai/app/services/shared/README.md +++ b/invokeai/app/services/shared/README.md @@ -96,7 +96,10 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre ### 4.1 Data - `graph: Graph` - source graph for the run; treated as stable during normal execution. -- `execution_graph: Graph` - materialized runtime nodes/edges. +- `execution_graph: Graph` - materialized runtime nodes/edges. This is mutable runtime state, not an immutable audit + log. Lazy `If` pruning may remove unselected input edges during execution, so persisted failed/completed session + snapshots can contain a structurally pruned execution graph. Retry paths rebuild from `graph`, not from a previously + persisted `execution_graph`. - `executed: set[str]`, `executed_history: list[str]`. - `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`. - `prepared_source_mapping: dict[str, str]` - exec id -> source id. @@ -123,7 +126,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre - `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus cached metadata such as iteration path and runtime state. - `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of - ready work. + ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and + cannot be selected as live inputs. - `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion. - `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes. - `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then @@ -178,7 +182,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done. - For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming `item` values are appended. -- For **IfInvocation**: hydrate only `condition` and the selected branch input. +- For **IfInvocation**: hydrate only `condition` and the selected branch input. As a defensive guard against + inconsistent runtime or deserialized session state, the runtime raises if the selected input edge points at an exec + node with no stored runtime output. In normal scheduling this path should be unreachable. - For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation through shared references. @@ -191,7 +197,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done. - Once the prepared `If` node resolves its condition: - the selected branch is released - the unselected branch is marked skipped + - unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer + participate in downstream indegree accounting - branch-exclusive ancestors of the unselected branch are never executed +- Skipped branch-local exec nodes may still be treated as executed for scheduling purposes, but they do not create + entries in `results`. - Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph. This behavior is implemented in the runtime scheduler, not in the invocation body itself. diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 24c1dd1fe4f..74e58e42a8c 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -194,11 +194,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]: def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None: for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field): - if edge.source.node_id in self._state.executed: - continue - if self._state.indegree[exec_node_id] == 0: - raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}") - self._state.indegree[exec_node_id] -= 1 + if edge.source.node_id not in self._state.executed: + if self._state.indegree[exec_node_id] == 0: + raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}") + self._state.indegree[exec_node_id] -= 1 + self._state.execution_graph.delete_edge(edge) def _apply_branch_resolution( self, @@ -253,6 +253,10 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool: return False def mark_exec_node_skipped(self, exec_node_id: str) -> None: + state = self._state._get_prepared_exec_metadata(exec_node_id).state + if state in ("executed", "skipped"): + return + self._state._remove_from_ready_queues(exec_node_id) self._state._set_prepared_exec_state(exec_node_id, "skipped") self._state.executed.add(exec_node_id) @@ -366,7 +370,7 @@ def _initialize_execution_node(self, exec_node_id: str) -> None: def _get_collect_iteration_mappings(self, parent_node_ids: list[str]) -> list[tuple[str, str]]: all_iteration_mappings: list[tuple[str, str]] = [] for source_node_id in parent_node_ids: - prepared_nodes = self._state.source_prepared_mapping[source_node_id] + prepared_nodes = self._get_prepared_nodes_for_source(source_node_id) all_iteration_mappings.extend((source_node_id, prepared_id) for prepared_id in prepared_nodes) return all_iteration_mappings @@ -424,7 +428,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)] def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]: - return self._state.source_prepared_mapping[source_node_id] + return { + exec_node_id + for exec_node_id in self._state.source_prepared_mapping[source_node_id] + if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped" + } def _get_parent_iterator_exec_nodes( self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str] @@ -481,10 +489,15 @@ def get_iteration_node( prepared_iterator_nodes: list[str], ) -> Optional[str]: prepared_nodes = self._get_prepared_nodes_for_source(source_node_id) - if len(prepared_nodes) == 1: + if len(prepared_nodes) == 1 and not prepared_iterator_nodes: return next(iter(prepared_nodes)) parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes) + if len(prepared_nodes) == 1: + prepared_node_id = next(iter(prepared_nodes)) + if self._matches_parent_iterators(prepared_node_id, parent_iterators, execution_graph): + return prepared_node_id + return None direct_iterator_match = self._get_direct_prepared_iterator_match( prepared_nodes, prepared_iterator_nodes, parent_iterators, execution_graph @@ -743,6 +756,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) -> def _get_copied_result_value(self, edge: Edge) -> Any: return copydeep(getattr(self._state.results[edge.source.node_id], edge.source.field)) + def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]: + source_output = self._state.results.get(edge.source.node_id) + if source_output is None: + return False, None + return True, copydeep(getattr(source_output, edge.source.field)) + def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]: item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD) collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD) @@ -771,7 +790,20 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None: selected_field = self._state._resolved_if_exec_branches.get(node.id) allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"} - self._set_node_inputs(node, input_edges, allowed_fields) + + for edge in input_edges: + if edge.destination.field not in allowed_fields: + continue + + found_value, copied_value = self._try_get_copied_result_value(edge) + if not found_value: + iteration_path = self._state._get_iteration_path(node.id) + raise RuntimeError( + "IfInvocation selected input edge points at an exec node with no stored result output: " + f"if_exec_id={node.id}, source_exec_id={edge.source.node_id}, iteration_path={iteration_path}" + ) + + setattr(node, edge.destination.field, copied_value) def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None: self._set_node_inputs(node, input_edges) diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index ffd0ca1559d..7cba90d1677 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -7,7 +7,12 @@ from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation +from invokeai.app.invocations.primitives import ( + BooleanCollectionInvocation, + BooleanCollectionOutput, + BooleanInvocation, + BooleanOutput, +) from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -18,6 +23,7 @@ # This import must happen before other invoke imports or test in other files(!!) break from tests.test_nodes import ( AnyTypeTestInvocation, + AnyTypeTestInvocationOutput, PromptCollectionTestInvocation, PromptTestInvocation, TextToImageTestInvocation, @@ -750,6 +756,238 @@ def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration() assert executed_source_ids.count("false_branch") == 2 +def test_if_graph_optimized_behavior_handles_selected_true_branch_with_shared_false_input_ancestor(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=True)) + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) + graph.add_node(CollectInvocation(id="shared_collect")) + graph.add_node(CollectInvocation(id="true_collect")) + graph.add_node(IfInvocation(id="if")) + graph.add_node(AnyTypeTestInvocation(id="selected_output")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) + graph.add_edge(create_edge("if", "value", "selected_output", "value")) + + g = GraphExecutionState(graph=graph) + executed_source_ids = execute_all_nodes(g) + + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) + assert g.results[prepared_selected_output_id].value == ["shared", "true"] + assert set(executed_source_ids) == { + "condition", + "shared_item", + "true_item", + "shared_collect", + "true_collect", + "if", + "selected_output", + } + + +def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_true_input_ancestor(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=False)) + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) + graph.add_node(CollectInvocation(id="shared_collect")) + graph.add_node(CollectInvocation(id="true_collect")) + graph.add_node(IfInvocation(id="if")) + graph.add_node(AnyTypeTestInvocation(id="selected_output")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) + graph.add_edge(create_edge("if", "value", "selected_output", "value")) + + g = GraphExecutionState(graph=graph) + executed_source_ids = execute_all_nodes(g) + + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) + assert g.results[prepared_selected_output_id].value == ["shared"] + assert set(executed_source_ids) == { + "condition", + "shared_item", + "shared_collect", + "if", + "selected_output", + } + assert "true_item" not in executed_source_ids + assert "true_collect" not in executed_source_ids + + +def test_prepare_if_inputs_raises_when_selected_branch_source_has_no_result(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=True)) + graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch")) + graph.add_node(IfInvocation(id="if")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("true_value", "prompt", "if", "true_input")) + + g = GraphExecutionState(graph=graph) + + condition_exec_id = g._create_execution_node("condition", [])[0] + true_value_exec_id = g._create_execution_node("true_value", [])[0] + if_exec_id = g._create_execution_node( + "if", + [("condition", condition_exec_id), ("true_value", true_value_exec_id)], + )[0] + + g.executed.add(condition_exec_id) + g.results[condition_exec_id] = BooleanOutput(value=True) + g.executed.add(true_value_exec_id) + g._resolved_if_exec_branches[if_exec_id] = "true_input" + + if_node = g.execution_graph.get_node(if_exec_id) + with pytest.raises(RuntimeError) as exc_info: + g._prepare_inputs(if_node) + + message = str(exc_info.value) + assert if_exec_id in message + assert true_value_exec_id in message + assert "iteration_path=()" in message + + +def test_get_collect_iteration_mappings_ignores_skipped_prepared_exec_nodes(): + graph = Graph() + graph.add_node(AnyTypeTestInvocation(id="parent", value="value")) + + g = GraphExecutionState(graph=graph) + + skipped_exec_id = g._create_execution_node("parent", [])[0] + active_exec_id = g._create_execution_node("parent", [])[0] + g._set_prepared_exec_state(skipped_exec_id, "skipped") + + mappings = g._materializer()._get_collect_iteration_mappings(["parent"]) + + assert mappings == [("parent", active_exec_id)] + + +def test_get_iteration_node_ignores_skipped_prepared_exec_nodes(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + skipped_exec_id = g._create_execution_node("value", [])[0] + active_exec_id = g._create_execution_node("value", [])[0] + g._set_prepared_exec_state(skipped_exec_id, "skipped") + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id == active_exec_id + + +def test_get_iteration_node_returns_single_active_prepared_exec_node(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + active_exec_id = g._create_execution_node("value", [])[0] + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id == active_exec_id + + +def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_exist(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + skipped_exec_id = g._create_execution_node("value", [])[0] + g._set_prepared_exec_state(skipped_exec_id, "skipped") + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id is None + + +def test_get_iteration_node_does_not_reuse_wrong_iterator_when_only_other_iteration_is_live(): + graph = Graph() + graph.add_node(BooleanCollectionInvocation(id="conditions", collection=[True, False])) + graph.add_node(IterateInvocation(id="condition_iter")) + graph.add_node(AnyTypeTestInvocation(id="value")) + + graph.add_edge(create_edge("conditions", "collection", "condition_iter", "collection")) + graph.add_edge(create_edge("condition_iter", "item", "value", "value")) + + g = GraphExecutionState(graph=graph) + + conditions_exec_id = g._create_execution_node("conditions", [])[0] + g.executed.add(conditions_exec_id) + g.results[conditions_exec_id] = BooleanCollectionOutput(collection=[True, False]) + + iterator_exec_ids = g._create_execution_node("condition_iter", [("conditions", conditions_exec_id)]) + assert len(iterator_exec_ids) == 2 + iterator_exec_ids_by_index = {g.execution_graph.get_node(exec_id).index: exec_id for exec_id in iterator_exec_ids} + first_iter_exec_id = iterator_exec_ids_by_index[0] + second_iter_exec_id = iterator_exec_ids_by_index[1] + + value_exec_ids = [] + value_exec_ids.extend(g._create_execution_node("value", [("condition_iter", first_iter_exec_id)])) + value_exec_ids.extend(g._create_execution_node("value", [("condition_iter", second_iter_exec_id)])) + assert len(value_exec_ids) == 2 + + for exec_id in value_exec_ids: + if g._get_iteration_path(exec_id) == (1,): + active_value_exec_id = exec_id + else: + skipped_value_exec_id = exec_id + + g._set_prepared_exec_state(skipped_value_exec_id, "skipped") + + selected_exec_id = g._get_iteration_node( + "value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), [first_iter_exec_id] + ) + + assert selected_exec_id is None + assert active_value_exec_id != skipped_value_exec_id + + +def test_mark_exec_node_skipped_does_not_hide_already_executed_results(): + graph = Graph() + graph.add_node(AnyTypeTestInvocation(id="value", value="value")) + + g = GraphExecutionState(graph=graph) + + exec_id = g._create_execution_node("value", [])[0] + g.results[exec_id] = AnyTypeTestInvocationOutput(value="value") + g.executed.add(exec_id) + g._set_prepared_exec_state(exec_id, "executed") + + g._if_scheduler().mark_exec_node_skipped(exec_id) + + assert g._get_prepared_exec_metadata(exec_id).state == "executed" + assert g.results[exec_id].value == "value" + + +def test_mark_exec_node_skipped_is_idempotent_for_skipped_state(): + graph = Graph() + graph.add_node(AnyTypeTestInvocation(id="value", value="value")) + + g = GraphExecutionState(graph=graph) + + exec_id = g._create_execution_node("value", [])[0] + + g._if_scheduler().mark_exec_node_skipped(exec_id) + g._if_scheduler().mark_exec_node_skipped(exec_id) + + assert g._get_prepared_exec_metadata(exec_id).state == "skipped" + assert g.executed_history.count("value") == 1 + + def test_are_connection_types_compatible_accepts_subclass_to_base(): """A subclass output should be connectable to a base-class input.