Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions invokeai/app/services/shared/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
### 4.1 Data

- `graph: Graph` - source graph for the run; treated as stable during normal execution.
- `execution_graph: Graph` - materialized runtime nodes/edges.
- `execution_graph: Graph` - materialized runtime nodes/edges. This is mutable runtime state, not an immutable audit
log. Lazy `If` pruning may remove unselected input edges during execution, so persisted failed/completed session
snapshots can contain a structurally pruned execution graph. Retry paths rebuild from `graph`, not from a previously
persisted `execution_graph`.
- `executed: set[str]`, `executed_history: list[str]`.
- `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`.
- `prepared_source_mapping: dict[str, str]` - exec id -> source id.
Expand All @@ -123,7 +126,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
- `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus
cached metadata such as iteration path and runtime state.
- `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of
ready work.
ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and
cannot be selected as live inputs.
- `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion.
- `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes.
- `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then
Expand Down Expand Up @@ -178,7 +182,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so
collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming
`item` values are appended.
- For **IfInvocation**: hydrate only `condition` and the selected branch input.
- For **IfInvocation**: hydrate only `condition` and the selected branch input. As a defensive guard against
inconsistent runtime or deserialized session state, the runtime raises if the selected input edge points at an exec
node with no stored runtime output. In normal scheduling this path should be unreachable.
- For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation
through shared references.

Expand All @@ -191,7 +197,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
- Once the prepared `If` node resolves its condition:
- the selected branch is released
- the unselected branch is marked skipped
- unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer
participate in downstream indegree accounting
- branch-exclusive ancestors of the unselected branch are never executed
- Skipped branch-local exec nodes may still be treated as executed for scheduling purposes, but they do not create
entries in `results`.
- Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph.

This behavior is implemented in the runtime scheduler, not in the invocation body itself.
Expand Down
50 changes: 41 additions & 9 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:

def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None:
    """Delete the unselected branch's input edges on a prepared `If` exec node.

    Edges from sources that have not yet executed still count toward this
    node's indegree, so the indegree is decremented before the edge is
    removed; edges from already-executed sources were accounted for when the
    source completed, so only the edge itself is deleted.

    Raises:
        RuntimeError: if decrementing would drive the indegree below zero,
            which indicates inconsistent scheduler accounting.
    """
    # NOTE: the pasted diff contained both the pre-change (`continue`-based)
    # and post-change guard; only the post-change logic is kept here, since
    # the old form never reached `delete_edge` and would double-decrement.
    for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field):
        if edge.source.node_id not in self._state.executed:
            if self._state.indegree[exec_node_id] == 0:
                raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
            self._state.indegree[exec_node_id] -= 1
        self._state.execution_graph.delete_edge(edge)

def _apply_branch_resolution(
self,
Expand Down Expand Up @@ -253,6 +253,10 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
return False

def mark_exec_node_skipped(self, exec_node_id: str) -> None:
state = self._state._get_prepared_exec_metadata(exec_node_id).state
if state in ("executed", "skipped"):
return

self._state._remove_from_ready_queues(exec_node_id)
self._state._set_prepared_exec_state(exec_node_id, "skipped")
self._state.executed.add(exec_node_id)
Expand Down Expand Up @@ -366,7 +370,7 @@ def _initialize_execution_node(self, exec_node_id: str) -> None:
def _get_collect_iteration_mappings(self, parent_node_ids: list[str]) -> list[tuple[str, str]]:
    """Return (source_node_id, prepared_exec_node_id) pairs for every live
    prepared node of every parent source node.

    Uses `_get_prepared_nodes_for_source` so that skipped prepared exec
    nodes are excluded from collect-input mapping.
    """
    # The stale pre-diff assignment that read `source_prepared_mapping`
    # directly (leaked into the paste) is removed: it bypassed the
    # skipped-node filter and was immediately overwritten anyway.
    all_iteration_mappings: list[tuple[str, str]] = []
    for source_node_id in parent_node_ids:
        prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
        all_iteration_mappings.extend((source_node_id, prepared_id) for prepared_id in prepared_nodes)
    return all_iteration_mappings

Expand Down Expand Up @@ -424,7 +428,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None
return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)]

def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]:
    """Return the prepared exec node ids for *source_node_id*, excluding
    any prepared node whose runtime state is "skipped".

    Skipped nodes (e.g. the unselected branch of a lazy `If`) must not be
    matched as live parents when materializing downstream exec nodes.
    """
    # The unfiltered pre-diff `return self._state.source_prepared_mapping[...]`
    # line that leaked into the paste is removed: it returned early and made
    # the skipped-node filter below unreachable.
    return {
        exec_node_id
        for exec_node_id in self._state.source_prepared_mapping[source_node_id]
        if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped"
    }

def _get_parent_iterator_exec_nodes(
self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str]
Expand Down Expand Up @@ -481,10 +489,15 @@ def get_iteration_node(
prepared_iterator_nodes: list[str],
) -> Optional[str]:
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
if len(prepared_nodes) == 1:
if len(prepared_nodes) == 1 and not prepared_iterator_nodes:
return next(iter(prepared_nodes))

parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes)
if len(prepared_nodes) == 1:
prepared_node_id = next(iter(prepared_nodes))
if self._matches_parent_iterators(prepared_node_id, parent_iterators, execution_graph):
return prepared_node_id
return None

direct_iterator_match = self._get_direct_prepared_iterator_match(
prepared_nodes, prepared_iterator_nodes, parent_iterators, execution_graph
Expand Down Expand Up @@ -743,6 +756,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) ->
def _get_copied_result_value(self, edge: Edge) -> Any:
    """Deep-copy and return the stored output value that *edge* reads from.

    Raises KeyError if the source exec node has no stored result; callers
    that need a non-raising variant use `_try_get_copied_result_value`.
    """
    source_output = self._state.results[edge.source.node_id]
    field_value = getattr(source_output, edge.source.field)
    return copydeep(field_value)

def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]:
    """Attempt to deep-copy the stored output value that *edge* reads from.

    Returns:
        (True, copied_value) when the source exec node has a stored result,
        (False, None) when it does not.
    """
    output = self._state.results.get(edge.source.node_id)
    if output is not None:
        return True, copydeep(getattr(output, edge.source.field))
    return False, None

def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]:
item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD)
collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD)
Expand Down Expand Up @@ -771,7 +790,20 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E
def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None:
    """Hydrate an `If` exec node, touching only `condition` and (once the
    branch is resolved) the selected branch's input field.

    Raises:
        RuntimeError: if a selected input edge points at an exec node with
            no stored result output. In normal scheduling this is
            unreachable; it guards against inconsistent runtime state or a
            deserialized session whose execution graph was pruned.
    """
    selected_field = self._state._resolved_if_exec_branches.get(node.id)
    # Before the condition resolves, only `condition` may be hydrated.
    allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"}

    # NOTE: the pre-diff bulk call `self._set_node_inputs(node, input_edges,
    # allowed_fields)` that leaked into the paste is removed — the per-edge
    # loop below replaced it so missing source outputs fail loudly.
    for edge in input_edges:
        if edge.destination.field not in allowed_fields:
            continue

        found_value, copied_value = self._try_get_copied_result_value(edge)
        if not found_value:
            iteration_path = self._state._get_iteration_path(node.id)
            raise RuntimeError(
                "IfInvocation selected input edge points at an exec node with no stored result output: "
                f"if_exec_id={node.id}, source_exec_id={edge.source.node_id}, iteration_path={iteration_path}"
            )

        setattr(node, edge.destination.field, copied_value)

def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None:
    """Default hydration path: copy every incoming edge's value into the
    node via `_set_node_inputs`, with no field filtering (unlike the
    Collect/If-specific preparers above)."""
    self._set_node_inputs(node, input_edges)
Expand Down
Loading
Loading