Skip to content

Commit 71feefa

Browse files
committed
Tighten lazy If runtime edge-case handling
1 parent ce234be commit 71feefa

3 files changed

Lines changed: 101 additions & 16 deletions

File tree

invokeai/app/services/shared/README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
9696
### 4.1 Data
9797

9898
- `graph: Graph` - source graph for the run; treated as stable during normal execution.
99-
- `execution_graph: Graph` - materialized runtime nodes/edges.
99+
- `execution_graph: Graph` - materialized runtime nodes/edges. This is mutable runtime state, not an immutable audit
100+
log. Lazy `If` pruning may remove unselected input edges during execution, so persisted failed/completed session
101+
snapshots can contain a structurally pruned execution graph. Retry paths rebuild from `graph`, not from a previously
102+
persisted `execution_graph`.
100103
- `executed: set[str]`, `executed_history: list[str]`.
101104
- `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`.
102105
- `prepared_source_mapping: dict[str, str]` - exec id -> source id.
@@ -179,9 +182,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
179182
- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so
180183
collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming
181184
`item` values are appended.
182-
- For **IfInvocation**: hydrate only `condition` and the selected branch input. If the selected branch's upstream exec
183-
node was skipped and therefore produced no runtime output, the branch input is left at its default value (typically
184-
`None`) instead of raising during hydration.
185+
- For **IfInvocation**: hydrate only `condition` and the selected branch input. As a defensive guard against
186+
inconsistent runtime or deserialized session state, the runtime raises if the selected input edge points at an exec
187+
node with no stored runtime output. In normal scheduling this path should be unreachable.
185188
- For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation
186189
through shared references.
187190

invokeai/app/services/shared/graph.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
243243
return False
244244

245245
def mark_exec_node_skipped(self, exec_node_id: str) -> None:
246+
if self._state._get_prepared_exec_metadata(exec_node_id).state == "executed":
247+
return
248+
246249
self._state._remove_from_ready_queues(exec_node_id)
247250
self._state._set_prepared_exec_state(exec_node_id, "skipped")
248251
self._state.executed.add(exec_node_id)
@@ -356,7 +359,7 @@ def _initialize_execution_node(self, exec_node_id: str) -> None:
356359
def _get_collect_iteration_mappings(self, parent_node_ids: list[str]) -> list[tuple[str, str]]:
357360
all_iteration_mappings: list[tuple[str, str]] = []
358361
for source_node_id in parent_node_ids:
359-
prepared_nodes = self._state.source_prepared_mapping[source_node_id]
362+
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
360363
all_iteration_mappings.extend((source_node_id, prepared_id) for prepared_id in prepared_nodes)
361364
return all_iteration_mappings
362365

@@ -475,10 +478,12 @@ def get_iteration_node(
475478
prepared_iterator_nodes: list[str],
476479
) -> Optional[str]:
477480
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
478-
if len(prepared_nodes) == 1:
479-
return next(iter(prepared_nodes))
480-
481481
parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes)
482+
if len(prepared_nodes) == 1:
483+
prepared_node_id = next(iter(prepared_nodes))
484+
if self._matches_parent_iterators(prepared_node_id, parent_iterators, execution_graph):
485+
return prepared_node_id
486+
return None
482487

483488
direct_iterator_match = self._get_direct_prepared_iterator_match(
484489
prepared_nodes, prepared_iterator_nodes, parent_iterators, execution_graph
@@ -778,9 +783,9 @@ def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> Non
778783

779784
found_value, copied_value = self._try_get_copied_result_value(edge)
780785
if not found_value:
781-
# A skipped branch-local exec node is considered executed for scheduling purposes, but it does not
782-
# produce an output payload. Leave the optional branch input at its default None instead of crashing.
783-
continue
786+
raise RuntimeError(
787+
"IfInvocation selected input edge points at an exec node with no stored result output"
788+
)
784789

785790
setattr(node, edge.destination.field, copied_value)
786791

tests/test_graph_execution_state.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from invokeai.app.invocations.collections import RangeInvocation
99
from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput
1010
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
11-
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation, BooleanOutput
11+
from invokeai.app.invocations.primitives import (
12+
BooleanCollectionInvocation,
13+
BooleanCollectionOutput,
14+
BooleanInvocation,
15+
BooleanOutput,
16+
)
1217
from invokeai.app.services.shared.graph import (
1318
CollectInvocation,
1419
Graph,
@@ -886,7 +891,7 @@ def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_t
886891
assert "true_collect" not in executed_source_ids
887892

888893

889-
def test_prepare_if_inputs_ignores_selected_branch_sources_without_results():
894+
def test_prepare_if_inputs_raises_when_selected_branch_source_has_no_result():
890895
graph = Graph()
891896
graph.add_node(BooleanInvocation(id="condition", value=True))
892897
graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch"))
@@ -910,10 +915,23 @@ def test_prepare_if_inputs_ignores_selected_branch_sources_without_results():
910915
g._resolved_if_exec_branches[if_exec_id] = "true_input"
911916

912917
if_node = g.execution_graph.get_node(if_exec_id)
913-
g._prepare_inputs(if_node)
918+
with pytest.raises(RuntimeError, match="selected input edge"):
919+
g._prepare_inputs(if_node)
920+
921+
922+
def test_get_collect_iteration_mappings_ignores_skipped_prepared_exec_nodes():
923+
graph = Graph()
924+
graph.add_node(AnyTypeTestInvocation(id="parent", value="value"))
925+
926+
g = GraphExecutionState(graph=graph)
927+
928+
skipped_exec_id = g._create_execution_node("parent", [])[0]
929+
active_exec_id = g._create_execution_node("parent", [])[0]
930+
g._set_prepared_exec_state(skipped_exec_id, "skipped")
931+
932+
mappings = g._materializer()._get_collect_iteration_mappings(["parent"])
914933

915-
assert if_node.condition is True
916-
assert if_node.true_input is None
934+
assert mappings == [("parent", active_exec_id)]
917935

918936

919937
def test_get_iteration_node_ignores_skipped_prepared_exec_nodes():
@@ -958,6 +976,65 @@ def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_e
958976
assert selected_exec_id is None
959977

960978

979+
def test_get_iteration_node_does_not_reuse_wrong_iterator_when_only_other_iteration_is_live():
980+
graph = Graph()
981+
graph.add_node(BooleanCollectionInvocation(id="conditions", collection=[True, False]))
982+
graph.add_node(IterateInvocation(id="condition_iter"))
983+
graph.add_node(AnyTypeTestInvocation(id="value"))
984+
985+
graph.add_edge(create_edge("conditions", "collection", "condition_iter", "collection"))
986+
graph.add_edge(create_edge("condition_iter", "item", "value", "value"))
987+
988+
g = GraphExecutionState(graph=graph)
989+
990+
conditions_exec_id = g._create_execution_node("conditions", [])[0]
991+
g.executed.add(conditions_exec_id)
992+
g.results[conditions_exec_id] = BooleanCollectionOutput(collection=[True, False])
993+
994+
iterator_exec_ids = g._create_execution_node("condition_iter", [("conditions", conditions_exec_id)])
995+
assert len(iterator_exec_ids) == 2
996+
iterator_exec_ids_by_index = {g.execution_graph.get_node(exec_id).index: exec_id for exec_id in iterator_exec_ids}
997+
first_iter_exec_id = iterator_exec_ids_by_index[0]
998+
second_iter_exec_id = iterator_exec_ids_by_index[1]
999+
1000+
value_exec_ids = []
1001+
value_exec_ids.extend(g._create_execution_node("value", [("condition_iter", first_iter_exec_id)]))
1002+
value_exec_ids.extend(g._create_execution_node("value", [("condition_iter", second_iter_exec_id)]))
1003+
assert len(value_exec_ids) == 2
1004+
1005+
for exec_id in value_exec_ids:
1006+
if g._get_iteration_path(exec_id) == (1,):
1007+
active_value_exec_id = exec_id
1008+
else:
1009+
skipped_value_exec_id = exec_id
1010+
1011+
g._set_prepared_exec_state(skipped_value_exec_id, "skipped")
1012+
1013+
selected_exec_id = g._get_iteration_node(
1014+
"value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), [first_iter_exec_id]
1015+
)
1016+
1017+
assert selected_exec_id is None
1018+
assert active_value_exec_id != skipped_value_exec_id
1019+
1020+
1021+
def test_mark_exec_node_skipped_does_not_hide_already_executed_results():
1022+
graph = Graph()
1023+
graph.add_node(AnyTypeTestInvocation(id="value", value="value"))
1024+
1025+
g = GraphExecutionState(graph=graph)
1026+
1027+
exec_id = g._create_execution_node("value", [])[0]
1028+
g.results[exec_id] = AnyTypeTestInvocation(id="result", value="value").invoke(Mock(InvocationContext))
1029+
g.executed.add(exec_id)
1030+
g._set_prepared_exec_state(exec_id, "executed")
1031+
1032+
g._if_scheduler().mark_exec_node_skipped(exec_id)
1033+
1034+
assert g._get_prepared_exec_metadata(exec_id).state == "executed"
1035+
assert g.results[exec_id].value == "value"
1036+
1037+
9611038
def test_are_connection_types_compatible_accepts_subclass_to_base():
9621039
"""A subclass output should be connectable to a base-class input.
9631040

0 commit comments

Comments
 (0)