Skip to content

Commit 73b67fc

Browse files
committed
Polish lazy If runtime diagnostics and idempotency
1 parent fdc0f39 commit 73b67fc

2 files changed

Lines changed: 31 additions & 4 deletions

File tree

invokeai/app/services/shared/graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
253253
return False
254254

255255
def mark_exec_node_skipped(self, exec_node_id: str) -> None:
256-
if self._state._get_prepared_exec_metadata(exec_node_id).state == "executed":
256+
state = self._state._get_prepared_exec_metadata(exec_node_id).state
257+
if state in ("executed", "skipped"):
257258
return
258259

259260
self._state._remove_from_ready_queues(exec_node_id)
@@ -488,6 +489,9 @@ def get_iteration_node(
488489
prepared_iterator_nodes: list[str],
489490
) -> Optional[str]:
490491
prepared_nodes = self._get_prepared_nodes_for_source(source_node_id)
492+
if len(prepared_nodes) == 1 and not prepared_iterator_nodes:
493+
return next(iter(prepared_nodes))
494+
491495
parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes)
492496
if len(prepared_nodes) == 1:
493497
prepared_node_id = next(iter(prepared_nodes))
@@ -793,8 +797,10 @@ def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> Non
793797

794798
found_value, copied_value = self._try_get_copied_result_value(edge)
795799
if not found_value:
800+
iteration_path = self._state._get_iteration_path(node.id)
796801
raise RuntimeError(
797-
"IfInvocation selected input edge points at an exec node with no stored result output"
802+
"IfInvocation selected input edge points at an exec node with no stored result output: "
803+
f"if_exec_id={node.id}, source_exec_id={edge.source.node_id}, iteration_path={iteration_path}"
798804
)
799805

800806
setattr(node, edge.destination.field, copied_value)

tests/test_graph_execution_state.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# This import must happen before other invoke imports or test in other files(!!) break
2424
from tests.test_nodes import (
2525
AnyTypeTestInvocation,
26+
AnyTypeTestInvocationOutput,
2627
PromptCollectionTestInvocation,
2728
PromptTestInvocation,
2829
TextToImageTestInvocation,
@@ -847,9 +848,14 @@ def test_prepare_if_inputs_raises_when_selected_branch_source_has_no_result():
847848
g._resolved_if_exec_branches[if_exec_id] = "true_input"
848849

849850
if_node = g.execution_graph.get_node(if_exec_id)
850-
with pytest.raises(RuntimeError, match="selected input edge"):
851+
with pytest.raises(RuntimeError) as exc_info:
851852
g._prepare_inputs(if_node)
852853

854+
message = str(exc_info.value)
855+
assert if_exec_id in message
856+
assert true_value_exec_id in message
857+
assert "iteration_path=()" in message
858+
853859

854860
def test_get_collect_iteration_mappings_ignores_skipped_prepared_exec_nodes():
855861
graph = Graph()
@@ -957,7 +963,7 @@ def test_mark_exec_node_skipped_does_not_hide_already_executed_results():
957963
g = GraphExecutionState(graph=graph)
958964

959965
exec_id = g._create_execution_node("value", [])[0]
960-
g.results[exec_id] = AnyTypeTestInvocation(id="result", value="value").invoke(Mock(InvocationContext))
966+
g.results[exec_id] = AnyTypeTestInvocationOutput(value="value")
961967
g.executed.add(exec_id)
962968
g._set_prepared_exec_state(exec_id, "executed")
963969

@@ -967,6 +973,21 @@ def test_mark_exec_node_skipped_does_not_hide_already_executed_results():
967973
assert g.results[exec_id].value == "value"
968974

969975

976+
def test_mark_exec_node_skipped_is_idempotent_for_skipped_state():
977+
graph = Graph()
978+
graph.add_node(AnyTypeTestInvocation(id="value", value="value"))
979+
980+
g = GraphExecutionState(graph=graph)
981+
982+
exec_id = g._create_execution_node("value", [])[0]
983+
984+
g._if_scheduler().mark_exec_node_skipped(exec_id)
985+
g._if_scheduler().mark_exec_node_skipped(exec_id)
986+
987+
assert g._get_prepared_exec_metadata(exec_id).state == "skipped"
988+
assert g.executed_history.count("value") == 1
989+
990+
970991
def test_are_connection_types_compatible_accepts_subclass_to_base():
971992
"""A subclass output should be connectable to a base-class input.
972993

0 commit comments

Comments
 (0)