Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,7 @@ def _has_unresolved_matching_if(self, if_node_id: str, iteration_path: tuple[int
return not all(pid in self._state._resolved_if_exec_branches for pid in matching_prepared_if_ids)

def _apply_condition_inputs(self, exec_node_id: str, node: IfInvocation) -> bool:
    """Resolve and apply the condition inputs for an if-node.

    Returns True when every condition edge's source node has executed and its
    value was copied onto *node*; False when any upstream source is still
    pending.
    """
    # This span previously duplicated the edge-copying logic inline; the
    # deleted duplicate made the trailing delegation unreachable. Keep only
    # the delegation so the scheduler and the state's rehydration path share
    # a single implementation on GraphExecutionState.
    return self._state._apply_if_condition_inputs(exec_node_id, node)

def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:
selected_field = "true_input" if node.condition else "false_input"
Expand Down Expand Up @@ -1819,6 +1809,73 @@ def _prepare_until_node_ready(self) -> Optional[BaseInvocation]:

return next_node

def _reset_runtime_caches(self) -> None:
self._ready_queues = {}
self._active_class = None
self._iteration_path_cache = {}
self._if_branch_exclusive_sources = {}
self._resolved_if_exec_branches = {}
self._prepared_exec_metadata = {}
self._prepared_exec_registry = None
self._if_branch_scheduler = None
self._execution_materializer = None
self._execution_scheduler = None
self._execution_runtime = None

def _rehydrate_prepared_exec_metadata(self) -> None:
registry = self._prepared_registry()
for exec_node_id, source_node_id in self.prepared_source_mapping.items():
metadata = registry.get_metadata(exec_node_id)
metadata.source_node_id = source_node_id
metadata.iteration_path = self._get_iteration_path(exec_node_id)
if exec_node_id in self.executed:
metadata.state = "executed" if exec_node_id in self.results else "skipped"
elif self.indegree.get(exec_node_id) == 0:
metadata.state = "ready"
else:
metadata.state = "pending"

def _apply_if_condition_inputs(self, exec_node_id: str, node: IfInvocation) -> bool:
    """Copy resolved condition values onto *node*.

    Returns False when any condition edge's source has not executed yet;
    otherwise deep-copies each resolved value onto the node and returns True.
    """
    edges = self.execution_graph._get_input_edges(exec_node_id, "condition")
    if not all(edge.source.node_id in self.executed for edge in edges):
        # At least one upstream condition source is still pending.
        return False

    for edge in edges:
        source_output = self.results[edge.source.node_id]
        value = copydeep(getattr(source_output, edge.source.field))
        setattr(node, edge.destination.field, value)
    return True

def _rehydrate_resolved_if_exec_branches(self) -> None:
    """Re-resolve the chosen branch for every if-node whose condition inputs have executed."""
    if_nodes = (
        (node_id, node)
        for node_id, node in self.execution_graph.nodes.items()
        if isinstance(node, IfInvocation)
    )
    for node_id, node in if_nodes:
        # Leave unresolved any if-node whose condition sources have not all run.
        if self._apply_if_condition_inputs(node_id, node):
            branch = "true_input" if node.condition else "false_input"
            self._resolved_if_exec_branches[node_id] = branch

def _rehydrate_ready_queues(self) -> None:
    """Re-enqueue every unexecuted node whose dependencies are already satisfied."""
    flat_graph = self.execution_graph.nx_graph_flat()
    # Topological order ensures upstream nodes are considered before downstream ones.
    for node_id in nx.topological_sort(flat_graph):
        is_ready = node_id not in self.executed and self.indegree.get(node_id) == 0
        if is_ready:
            self._enqueue_if_ready(node_id)

def _rehydrate_runtime_state(self) -> None:
self._reset_runtime_caches()
self._rehydrate_prepared_exec_metadata()
self._rehydrate_resolved_if_exec_branches()
self._rehydrate_ready_queues()

def model_post_init(self, __context: Any) -> None:
    # Pydantic v2 post-validation hook: runs after the model is constructed,
    # including when loaded from JSON, so private runtime caches (which are
    # not serialized) are rebuilt before the state is used.
    self._rehydrate_runtime_state()

model_config = ConfigDict(
json_schema_extra={
"required": [
Expand Down
68 changes: 68 additions & 0 deletions tests/test_graph_execution_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock

import pytest
from pydantic import TypeAdapter

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
Expand Down Expand Up @@ -137,6 +138,73 @@ def test_graph_state_collects():
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)


def test_graph_state_resumes_partially_executed_session_after_json_round_trip():
    """A partially-executed session survives a JSON round trip and runs to completion."""
    g = Graph()
    for node in (
        RangeInvocation(id="c", start=1, stop=5, step=1),
        IterateInvocation(id="iter"),
        AddInvocation(id="add", b=1),
        CollectInvocation(id="collect"),
    ):
        g.add_node(node)
    for edge in (
        create_edge("c", "collection", "iter", "collection"),
        create_edge("iter", "item", "add", "a"),
        create_edge("add", "value", "collect", "item"),
    ):
        g.add_edge(edge)

    state = GraphExecutionState(graph=g)

    # Execute only the first few nodes, leaving the session mid-flight.
    for _ in range(4):
        invocation, output = invoke_next(state)
        assert invocation is not None
        assert output is not None

    serialized = state.model_dump_json(warnings=False, exclude_none=True)
    resumed = TypeAdapter(GraphExecutionState).validate_json(serialized, strict=False)
    registry = resumed._prepared_registry()

    # Every prepared node must have an iteration path after rehydration.
    for exec_node_id in resumed.prepared_source_mapping:
        assert registry.get_iteration_path(exec_node_id) is not None

    executed_source_ids = execute_all_nodes(resumed)

    assert executed_source_ids
    assert "add" in executed_source_ids
    assert "collect" in resumed.source_prepared_mapping

    collect_exec_id = next(iter(resumed.source_prepared_mapping["collect"]))
    assert resumed.results[collect_exec_id].collection == [2, 3, 4, 5]


def test_if_graph_state_resumes_resolved_branch_after_json_round_trip():
    """A resolved if-branch survives a JSON round trip: only the chosen branch runs on resume."""
    g = Graph()
    for node in (
        BooleanInvocation(id="condition", value=True),
        PromptTestInvocation(id="true_value", prompt="true branch"),
        PromptTestInvocation(id="false_value", prompt="false branch"),
        IfInvocation(id="if"),
        PromptTestInvocation(id="selected_output"),
    ):
        g.add_node(node)
    for edge in (
        create_edge("condition", "value", "if", "condition"),
        create_edge("true_value", "prompt", "if", "true_input"),
        create_edge("false_value", "prompt", "if", "false_input"),
        create_edge("if", "value", "selected_output", "prompt"),
    ):
        g.add_edge(edge)

    state = GraphExecutionState(graph=g)

    # Run far enough that the if-node's condition is resolved before serializing.
    for _ in range(2):
        invocation, output = invoke_next(state)
        assert invocation is not None
        assert output is not None

    serialized = state.model_dump_json(warnings=False, exclude_none=True)
    resumed = TypeAdapter(GraphExecutionState).validate_json(serialized, strict=False)

    executed_source_ids = execute_all_nodes(resumed)

    selected_exec_id = next(iter(resumed.source_prepared_mapping["selected_output"]))
    assert resumed.results[selected_exec_id].prompt == "true branch"
    assert set(executed_source_ids) == {"if", "selected_output"}
    assert "false_value" not in executed_source_ids


def test_graph_state_prepares_eagerly():
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
Expand Down
Loading