Skip to content

Commit 25091ae

Browse files
committed
Polish lazy If runtime diagnostics and idempotency
1 parent 71feefa commit 25091ae

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

invokeai/app/services/shared/graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
243243
return False
244244

245245
def mark_exec_node_skipped(self, exec_node_id: str) -> None:
246-
if self._state._get_prepared_exec_metadata(exec_node_id).state == "executed":
246+
state = self._state._get_prepared_exec_metadata(exec_node_id).state
247+
if state in ("executed", "skipped"):
247248
return
248249

249250
self._state._remove_from_ready_queues(exec_node_id)
@@ -478,6 +479,9 @@ def get_iteration_node(
478479
prepared_iterator_nodes: list[str],
479480
) -> Optional[str]:
480481
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
482+
if len(prepared_nodes) == 1 and not prepared_iterator_nodes:
483+
return next(iter(prepared_nodes))
484+
481485
parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes)
482486
if len(prepared_nodes) == 1:
483487
prepared_node_id = next(iter(prepared_nodes))
@@ -783,8 +787,10 @@ def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> Non
783787

784788
found_value, copied_value = self._try_get_copied_result_value(edge)
785789
if not found_value:
790+
iteration_path = self._state._get_iteration_path(node.id)
786791
raise RuntimeError(
787-
"IfInvocation selected input edge points at an exec node with no stored result output"
792+
"IfInvocation selected input edge points at an exec node with no stored result output: "
793+
f"if_exec_id={node.id}, source_exec_id={edge.source.node_id}, iteration_path={iteration_path}"
788794
)
789795

790796
setattr(node, edge.destination.field, copied_value)

tests/test_graph_execution_state.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# This import must happen before other invoke imports or test in other files(!!) break
2525
from tests.test_nodes import (
2626
AnyTypeTestInvocation,
27+
AnyTypeTestInvocationOutput,
2728
PromptCollectionTestInvocation,
2829
PromptTestInvocation,
2930
TextToImageTestInvocation,
@@ -915,9 +916,14 @@ def test_prepare_if_inputs_raises_when_selected_branch_source_has_no_result():
915916
g._resolved_if_exec_branches[if_exec_id] = "true_input"
916917

917918
if_node = g.execution_graph.get_node(if_exec_id)
918-
with pytest.raises(RuntimeError, match="selected input edge"):
919+
with pytest.raises(RuntimeError) as exc_info:
919920
g._prepare_inputs(if_node)
920921

922+
message = str(exc_info.value)
923+
assert if_exec_id in message
924+
assert true_value_exec_id in message
925+
assert "iteration_path=()" in message
926+
921927

922928
def test_get_collect_iteration_mappings_ignores_skipped_prepared_exec_nodes():
923929
graph = Graph()
@@ -1025,7 +1031,7 @@ def test_mark_exec_node_skipped_does_not_hide_already_executed_results():
10251031
g = GraphExecutionState(graph=graph)
10261032

10271033
exec_id = g._create_execution_node("value", [])[0]
1028-
g.results[exec_id] = AnyTypeTestInvocation(id="result", value="value").invoke(Mock(InvocationContext))
1034+
g.results[exec_id] = AnyTypeTestInvocationOutput(value="value")
10291035
g.executed.add(exec_id)
10301036
g._set_prepared_exec_state(exec_id, "executed")
10311037

@@ -1035,6 +1041,21 @@ def test_mark_exec_node_skipped_does_not_hide_already_executed_results():
10351041
assert g.results[exec_id].value == "value"
10361042

10371043

1044+
def test_mark_exec_node_skipped_is_idempotent_for_skipped_state():
1045+
graph = Graph()
1046+
graph.add_node(AnyTypeTestInvocation(id="value", value="value"))
1047+
1048+
g = GraphExecutionState(graph=graph)
1049+
1050+
exec_id = g._create_execution_node("value", [])[0]
1051+
1052+
g._if_scheduler().mark_exec_node_skipped(exec_id)
1053+
g._if_scheduler().mark_exec_node_skipped(exec_id)
1054+
1055+
assert g._get_prepared_exec_metadata(exec_id).state == "skipped"
1056+
assert g.executed_history.count("value") == 1
1057+
1058+
10381059
def test_are_connection_types_compatible_accepts_subclass_to_base():
10391060
"""A subclass output should be connectable to a base-class input.
10401061

0 commit comments

Comments
 (0)