Skip to content

Commit f479a1d

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 274d22b + 29741dd commit f479a1d

3 files changed

Lines changed: 293 additions & 13 deletions

File tree

invokeai/app/services/shared/README.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
9696
### 4.1 Data
9797

9898
- `graph: Graph` - source graph for the run; treated as stable during normal execution.
99-
- `execution_graph: Graph` - materialized runtime nodes/edges.
99+
- `execution_graph: Graph` - materialized runtime nodes/edges. This is mutable runtime state, not an immutable audit
100+
log. Lazy `If` pruning may remove unselected input edges during execution, so persisted failed/completed session
101+
snapshots can contain a structurally pruned execution graph. Retry paths rebuild from `graph`, not from a previously
102+
persisted `execution_graph`.
100103
- `executed: set[str]`, `executed_history: list[str]`.
101104
- `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`.
102105
- `prepared_source_mapping: dict[str, str]` - exec id -> source id.
@@ -123,7 +126,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
123126
- `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus
124127
cached metadata such as iteration path and runtime state.
125128
- `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of
126-
ready work.
129+
ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and
130+
cannot be selected as live inputs.
127131
- `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion.
128132
- `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes.
129133
- `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then
@@ -178,7 +182,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
178182
- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so
179183
collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming
180184
`item` values are appended.
181-
- For **IfInvocation**: hydrate only `condition` and the selected branch input.
185+
- For **IfInvocation**: hydrate only `condition` and the selected branch input. As a defensive guard against
186+
inconsistent runtime or deserialized session state, the runtime raises if the selected input edge points at an exec
187+
node with no stored runtime output. In normal scheduling this path should be unreachable.
182188
- For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation
183189
through shared references.
184190

@@ -191,7 +197,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
191197
- Once the prepared `If` node resolves its condition:
192198
- the selected branch is released
193199
- the unselected branch is marked skipped
200+
- unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer
201+
participate in downstream indegree accounting
194202
- branch-exclusive ancestors of the unselected branch are never executed
203+
- Skipped branch-local exec nodes may still be treated as executed for scheduling purposes, but they do not create
204+
entries in `results`.
195205
- Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph.
196206

197207
This behavior is implemented in the runtime scheduler, not in the invocation body itself.

invokeai/app/services/shared/graph.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:
184184

185185
def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None:
186186
for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field):
187-
if edge.source.node_id in self._state.executed:
188-
continue
189-
if self._state.indegree[exec_node_id] == 0:
190-
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
191-
self._state.indegree[exec_node_id] -= 1
187+
if edge.source.node_id not in self._state.executed:
188+
if self._state.indegree[exec_node_id] == 0:
189+
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
190+
self._state.indegree[exec_node_id] -= 1
191+
self._state.execution_graph.delete_edge(edge)
192192

193193
def _apply_branch_resolution(
194194
self,
@@ -243,6 +243,10 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
243243
return False
244244

245245
def mark_exec_node_skipped(self, exec_node_id: str) -> None:
246+
state = self._state._get_prepared_exec_metadata(exec_node_id).state
247+
if state in ("executed", "skipped"):
248+
return
249+
246250
self._state._remove_from_ready_queues(exec_node_id)
247251
self._state._set_prepared_exec_state(exec_node_id, "skipped")
248252
self._state.executed.add(exec_node_id)
@@ -356,7 +360,7 @@ def _initialize_execution_node(self, exec_node_id: str) -> None:
356360
def _get_collect_iteration_mappings(self, parent_node_ids: list[str]) -> list[tuple[str, str]]:
357361
all_iteration_mappings: list[tuple[str, str]] = []
358362
for source_node_id in parent_node_ids:
359-
prepared_nodes = self._state.source_prepared_mapping[source_node_id]
363+
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
360364
all_iteration_mappings.extend((source_node_id, prepared_id) for prepared_id in prepared_nodes)
361365
return all_iteration_mappings
362366

@@ -414,7 +418,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None
414418
return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)]
415419

416420
def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]:
417-
return self._state.source_prepared_mapping[source_node_id]
421+
return {
422+
exec_node_id
423+
for exec_node_id in self._state.source_prepared_mapping[source_node_id]
424+
if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped"
425+
}
418426

419427
def _get_parent_iterator_exec_nodes(
420428
self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str]
@@ -471,10 +479,15 @@ def get_iteration_node(
471479
prepared_iterator_nodes: list[str],
472480
) -> Optional[str]:
473481
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
474-
if len(prepared_nodes) == 1:
482+
if len(prepared_nodes) == 1 and not prepared_iterator_nodes:
475483
return next(iter(prepared_nodes))
476484

477485
parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes)
486+
if len(prepared_nodes) == 1:
487+
prepared_node_id = next(iter(prepared_nodes))
488+
if self._matches_parent_iterators(prepared_node_id, parent_iterators, execution_graph):
489+
return prepared_node_id
490+
return None
478491

479492
direct_iterator_match = self._get_direct_prepared_iterator_match(
480493
prepared_nodes, prepared_iterator_nodes, parent_iterators, execution_graph
@@ -733,6 +746,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) ->
733746
def _get_copied_result_value(self, edge: Edge) -> Any:
734747
return copydeep(getattr(self._state.results[edge.source.node_id], edge.source.field))
735748

749+
def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]:
750+
source_output = self._state.results.get(edge.source.node_id)
751+
if source_output is None:
752+
return False, None
753+
return True, copydeep(getattr(source_output, edge.source.field))
754+
736755
def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]:
737756
item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD)
738757
collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD)
@@ -761,7 +780,20 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E
761780
def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None:
762781
selected_field = self._state._resolved_if_exec_branches.get(node.id)
763782
allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"}
764-
self._set_node_inputs(node, input_edges, allowed_fields)
783+
784+
for edge in input_edges:
785+
if edge.destination.field not in allowed_fields:
786+
continue
787+
788+
found_value, copied_value = self._try_get_copied_result_value(edge)
789+
if not found_value:
790+
iteration_path = self._state._get_iteration_path(node.id)
791+
raise RuntimeError(
792+
"IfInvocation selected input edge points at an exec node with no stored result output: "
793+
f"if_exec_id={node.id}, source_exec_id={edge.source.node_id}, iteration_path={iteration_path}"
794+
)
795+
796+
setattr(node, edge.destination.field, copied_value)
765797

766798
def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None:
767799
self._set_node_inputs(node, input_edges)

0 commit comments

Comments
 (0)