Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,67 @@ def _prepare_until_node_ready(self) -> Optional[BaseInvocation]:

return next_node

def _reset_runtime_caches(self) -> None:
    """Drop every derived runtime cache so it can be rebuilt from persisted state.

    Dict-valued caches are reset to empty dicts; lazily-constructed runtime
    objects are reset to None so their accessors recreate them on demand.
    """
    # Caches that hold per-node bookkeeping: always start from an empty mapping.
    for dict_attr in (
        "_ready_queues",
        "_iteration_path_cache",
        "_if_branch_exclusive_sources",
        "_resolved_if_exec_branches",
        "_prepared_exec_metadata",
    ):
        setattr(self, dict_attr, {})
    # Lazily-built runtime collaborators: None means "not constructed yet".
    for none_attr in (
        "_active_class",
        "_prepared_exec_registry",
        "_if_branch_scheduler",
        "_execution_materializer",
        "_execution_scheduler",
        "_execution_runtime",
    ):
        setattr(self, none_attr, None)

def _rehydrate_prepared_exec_metadata(self) -> None:
    """Rebuild per-execution-node metadata records from persisted session state.

    For each prepared execution node, restores its source-node link and derives
    its state from what the serialized session shows: already executed (with or
    without a result), ready to run (no unsatisfied inputs), or still pending.
    """
    registry = self._prepared_registry()
    for exec_id, source_id in self.prepared_source_mapping.items():
        record = registry.get_metadata(exec_id)
        record.source_node_id = source_id
        if exec_id in self.executed:
            # An executed node with no stored result was skipped, not run.
            record.state = "executed" if exec_id in self.results else "skipped"
        else:
            # indegree 0 means every input is satisfied; anything else waits.
            record.state = "ready" if self.indegree.get(exec_id) == 0 else "pending"

def _rehydrate_resolved_if_exec_branches(self) -> None:
    """Re-resolve the taken branch of each If node whose condition inputs have all run.

    Replays the condition-edge values from stored results onto each
    ``IfInvocation`` and records whether it routes to ``true_input`` or
    ``false_input``, so branch decisions survive serialization round-trips.
    """
    # Fix: the original called `copydeep`, which is not a stdlib name — the
    # standard-library function is `copy.deepcopy`. Imported locally so the
    # (unseen) file header does not need to change.
    from copy import deepcopy

    for exec_node_id, node in self.execution_graph.nodes.items():
        if not isinstance(node, IfInvocation):
            continue

        condition_edges = self.execution_graph._get_input_edges(exec_node_id, "condition")
        # Only rehydrate branches whose condition inputs have all executed.
        if any(edge.source.node_id not in self.executed for edge in condition_edges):
            continue

        for edge in condition_edges:
            # NOTE(review): assumes every executed condition source also has an
            # entry in self.results (skipped sources would raise KeyError here)
            # — confirm skipped nodes can never feed an If condition.
            # Deep-copy so later mutation of the node input cannot corrupt the
            # stored result object.
            setattr(
                node,
                edge.destination.field,
                deepcopy(getattr(self.results[edge.source.node_id], edge.source.field)),
            )

        self._resolved_if_exec_branches[exec_node_id] = "true_input" if node.condition else "false_input"

def _rehydrate_ready_queues(self) -> None:
    """Re-enqueue every unexecuted node whose inputs are already satisfied.

    Walks the flattened execution graph in topological order and re-offers
    each not-yet-executed, zero-indegree node to the ready queues.
    """
    flat_graph = self.execution_graph.nx_graph_flat()
    # Topological order guarantees upstream nodes are considered first.
    runnable = (
        node_id
        for node_id in nx.topological_sort(flat_graph)
        # `.get(...) == 0` also skips ids missing from the indegree map
        # (get returns None), matching the original `!= 0` guard.
        if node_id not in self.executed and self.indegree.get(node_id) == 0
    )
    for node_id in runnable:
        self._enqueue_if_ready(node_id)

def _rehydrate_runtime_state(self) -> None:
    """Rebuild all non-serialized runtime state from the persisted session fields.

    Order matters: caches are cleared first, then metadata, branch decisions,
    and ready queues are reconstructed on top of them.
    """
    rehydration_steps = (
        self._reset_runtime_caches,
        self._rehydrate_prepared_exec_metadata,
        self._rehydrate_resolved_if_exec_branches,
        self._rehydrate_ready_queues,
    )
    for step in rehydration_steps:
        step()

def model_post_init(self, __context: Any) -> None:
    """Pydantic post-validation hook: rebuild runtime caches after construction.

    Runs after every model instantiation — including JSON deserialization —
    so a session restored from persisted state resumes with consistent
    ready queues, metadata, and resolved If-branch decisions.
    """
    self._rehydrate_runtime_state()

model_config = ConfigDict(
json_schema_extra={
"required": [
Expand Down
32 changes: 32 additions & 0 deletions tests/test_graph_execution_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock

import pytest
from pydantic import TypeAdapter

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
Expand Down Expand Up @@ -137,6 +138,37 @@ def test_graph_state_collects():
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)


def test_graph_state_resumes_partially_executed_session_after_json_round_trip():
    """A partially-executed session must survive a JSON round trip and finish correctly."""
    g = Graph()
    g.add_node(RangeInvocation(id="c", start=1, stop=5, step=1))
    g.add_node(IterateInvocation(id="iter"))
    g.add_node(AddInvocation(id="add", b=1))
    g.add_node(CollectInvocation(id="collect"))
    for edge in (
        create_edge("c", "collection", "iter", "collection"),
        create_edge("iter", "item", "add", "a"),
        create_edge("add", "value", "collect", "item"),
    ):
        g.add_edge(edge)

    session = GraphExecutionState(graph=g)

    # Run the first four nodes, then snapshot the session mid-flight.
    for _ in range(4):
        node, result = invoke_next(session)
        assert node is not None
        assert result is not None

    serialized = session.model_dump_json(warnings=False, exclude_none=True)
    restored = TypeAdapter(GraphExecutionState).validate_json(serialized, strict=False)

    # The restored session should be able to run the remaining nodes to completion.
    finished_source_ids = execute_all_nodes(restored)

    assert finished_source_ids
    assert "add" in finished_source_ids
    assert "collect" in restored.source_prepared_mapping

    # Each range item 1..4 was incremented by 1, so the collector holds 2..5.
    collect_exec_id = next(iter(restored.source_prepared_mapping["collect"]))
    assert restored.results[collect_exec_id].collection == [2, 3, 4, 5]


def test_graph_state_prepares_eagerly():
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
Expand Down
Loading