Skip to content

Commit 3785efb

Browse files
committed
Fix lazy If branch pruning and skipped-parent handling in graph runtime
1 parent 5a0818a commit 3785efb

3 files changed

Lines changed: 178 additions & 10 deletions

File tree

invokeai/app/services/shared/README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
123123
- `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus
124124
cached metadata such as iteration path and runtime state.
125125
- `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of
126-
ready work.
126+
ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and
127+
cannot be selected as live inputs.
127128
- `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion.
128129
- `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes.
129130
- `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then
@@ -178,7 +179,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
178179
- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so
179180
collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming
180181
`item` values are appended.
181-
- For **IfInvocation**: hydrate only `condition` and the selected branch input.
182+
- For **IfInvocation**: hydrate only `condition` and the selected branch input. If the selected branch's upstream exec
183+
node was skipped and therefore produced no runtime output, the branch input is left at its default value (typically
184+
`None`) instead of raising during hydration.
182185
- For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation
183186
through shared references.
184187

@@ -191,7 +194,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
191194
- Once the prepared `If` node resolves its condition:
192195
- the selected branch is released
193196
- the unselected branch is marked skipped
197+
- unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer
198+
contribute to the prepared `If` node's own indegree accounting
194199
- branch-exclusive ancestors of the unselected branch are never executed
200+
- Skipped branch-local exec nodes are treated as executed for scheduling purposes, but they do not create
201+
entries in `results`.
195202
- Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph.
196203

197204
This behavior is implemented in the runtime scheduler, not in the invocation body itself.

invokeai/app/services/shared/graph.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:
194194

195195
def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None:
196196
for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field):
197-
if edge.source.node_id in self._state.executed:
198-
continue
199-
if self._state.indegree[exec_node_id] == 0:
200-
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
201-
self._state.indegree[exec_node_id] -= 1
197+
if edge.source.node_id not in self._state.executed:
198+
if self._state.indegree[exec_node_id] == 0:
199+
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
200+
self._state.indegree[exec_node_id] -= 1
201+
self._state.execution_graph.delete_edge(edge)
202202

203203
def _apply_branch_resolution(
204204
self,
@@ -424,7 +424,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None
424424
return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)]
425425

426426
def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]:
427-
return self._state.source_prepared_mapping[source_node_id]
427+
return {
428+
exec_node_id
429+
for exec_node_id in self._state.source_prepared_mapping[source_node_id]
430+
if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped"
431+
}
428432

429433
def _get_parent_iterator_exec_nodes(
430434
self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str]
@@ -743,6 +747,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) ->
743747
def _get_copied_result_value(self, edge: Edge) -> Any:
744748
return copydeep(getattr(self._state.results[edge.source.node_id], edge.source.field))
745749

750+
def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]:
751+
source_output = self._state.results.get(edge.source.node_id)
752+
if source_output is None:
753+
return False, None
754+
return True, copydeep(getattr(source_output, edge.source.field))
755+
746756
def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]:
747757
item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD)
748758
collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD)
@@ -771,7 +781,18 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E
771781
def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None:
772782
selected_field = self._state._resolved_if_exec_branches.get(node.id)
773783
allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"}
774-
self._set_node_inputs(node, input_edges, allowed_fields)
784+
785+
for edge in input_edges:
786+
if edge.destination.field not in allowed_fields:
787+
continue
788+
789+
found_value, copied_value = self._try_get_copied_result_value(edge)
790+
if not found_value:
791+
# A skipped branch-local exec node is considered executed for scheduling purposes, but it does not
792+
# produce an output payload. Leave the optional branch input at its default None instead of crashing.
793+
continue
794+
795+
setattr(node, edge.destination.field, copied_value)
775796

776797
def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None:
777798
self._set_node_inputs(node, input_edges)

tests/test_graph_execution_state.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from invokeai.app.invocations.collections import RangeInvocation
88
from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput
99
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
10-
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation
10+
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation, BooleanOutput
1111
from invokeai.app.services.shared.graph import (
1212
CollectInvocation,
1313
Graph,
@@ -750,6 +750,146 @@ def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration()
750750
assert executed_source_ids.count("false_branch") == 2
751751

752752

753+
def test_if_graph_optimized_behavior_handles_selected_true_branch_with_shared_false_input_ancestor():
754+
graph = Graph()
755+
graph.add_node(BooleanInvocation(id="condition", value=True))
756+
graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared"))
757+
graph.add_node(AnyTypeTestInvocation(id="true_item", value="true"))
758+
graph.add_node(CollectInvocation(id="shared_collect"))
759+
graph.add_node(CollectInvocation(id="true_collect"))
760+
graph.add_node(IfInvocation(id="if"))
761+
graph.add_node(AnyTypeTestInvocation(id="selected_output"))
762+
763+
graph.add_edge(create_edge("condition", "value", "if", "condition"))
764+
graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item"))
765+
graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection"))
766+
graph.add_edge(create_edge("true_item", "value", "true_collect", "item"))
767+
graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input"))
768+
graph.add_edge(create_edge("true_collect", "collection", "if", "true_input"))
769+
graph.add_edge(create_edge("if", "value", "selected_output", "value"))
770+
771+
g = GraphExecutionState(graph=graph)
772+
executed_source_ids = execute_all_nodes(g)
773+
774+
prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"]))
775+
assert g.results[prepared_selected_output_id].value == ["shared", "true"]
776+
assert set(executed_source_ids) == {
777+
"condition",
778+
"shared_item",
779+
"true_item",
780+
"shared_collect",
781+
"true_collect",
782+
"if",
783+
"selected_output",
784+
}
785+
786+
787+
def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_true_input_ancestor():
788+
graph = Graph()
789+
graph.add_node(BooleanInvocation(id="condition", value=False))
790+
graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared"))
791+
graph.add_node(AnyTypeTestInvocation(id="true_item", value="true"))
792+
graph.add_node(CollectInvocation(id="shared_collect"))
793+
graph.add_node(CollectInvocation(id="true_collect"))
794+
graph.add_node(IfInvocation(id="if"))
795+
graph.add_node(AnyTypeTestInvocation(id="selected_output"))
796+
797+
graph.add_edge(create_edge("condition", "value", "if", "condition"))
798+
graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item"))
799+
graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection"))
800+
graph.add_edge(create_edge("true_item", "value", "true_collect", "item"))
801+
graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input"))
802+
graph.add_edge(create_edge("true_collect", "collection", "if", "true_input"))
803+
graph.add_edge(create_edge("if", "value", "selected_output", "value"))
804+
805+
g = GraphExecutionState(graph=graph)
806+
executed_source_ids = execute_all_nodes(g)
807+
808+
prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"]))
809+
assert g.results[prepared_selected_output_id].value == ["shared"]
810+
assert set(executed_source_ids) == {
811+
"condition",
812+
"shared_item",
813+
"shared_collect",
814+
"if",
815+
"selected_output",
816+
}
817+
assert "true_item" not in executed_source_ids
818+
assert "true_collect" not in executed_source_ids
819+
820+
821+
def test_prepare_if_inputs_ignores_selected_branch_sources_without_results():
822+
graph = Graph()
823+
graph.add_node(BooleanInvocation(id="condition", value=True))
824+
graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch"))
825+
graph.add_node(IfInvocation(id="if"))
826+
827+
graph.add_edge(create_edge("condition", "value", "if", "condition"))
828+
graph.add_edge(create_edge("true_value", "prompt", "if", "true_input"))
829+
830+
g = GraphExecutionState(graph=graph)
831+
832+
condition_exec_id = g._create_execution_node("condition", [])[0]
833+
true_value_exec_id = g._create_execution_node("true_value", [])[0]
834+
if_exec_id = g._create_execution_node(
835+
"if",
836+
[("condition", condition_exec_id), ("true_value", true_value_exec_id)],
837+
)[0]
838+
839+
g.executed.add(condition_exec_id)
840+
g.results[condition_exec_id] = BooleanOutput(value=True)
841+
g.executed.add(true_value_exec_id)
842+
g._resolved_if_exec_branches[if_exec_id] = "true_input"
843+
844+
if_node = g.execution_graph.get_node(if_exec_id)
845+
g._prepare_inputs(if_node)
846+
847+
assert if_node.condition is True
848+
assert if_node.true_input is None
849+
850+
851+
def test_get_iteration_node_ignores_skipped_prepared_exec_nodes():
852+
graph = Graph()
853+
graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))
854+
855+
g = GraphExecutionState(graph=graph)
856+
857+
skipped_exec_id = g._create_execution_node("value", [])[0]
858+
active_exec_id = g._create_execution_node("value", [])[0]
859+
g._set_prepared_exec_state(skipped_exec_id, "skipped")
860+
861+
selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), [])
862+
863+
assert selected_exec_id == active_exec_id
864+
865+
866+
def test_get_iteration_node_returns_single_active_prepared_exec_node():
867+
graph = Graph()
868+
graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))
869+
870+
g = GraphExecutionState(graph=graph)
871+
872+
active_exec_id = g._create_execution_node("value", [])[0]
873+
874+
selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), [])
875+
876+
assert selected_exec_id == active_exec_id
877+
878+
879+
def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_exist():
880+
graph = Graph()
881+
graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))
882+
883+
g = GraphExecutionState(graph=graph)
884+
885+
skipped_exec_id = g._create_execution_node("value", [])[0]
886+
g._set_prepared_exec_state(skipped_exec_id, "skipped")
887+
888+
selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), [])
889+
890+
assert selected_exec_id is None
891+
892+
753893
def test_are_connection_types_compatible_accepts_subclass_to_base():
754894
"""A subclass output should be connectable to a base-class input.
755895

0 commit comments

Comments
 (0)