88from invokeai .app .invocations .collections import RangeInvocation
99from invokeai .app .invocations .logic import IfInvocation , IfInvocationOutput
1010from invokeai .app .invocations .math import AddInvocation , MultiplyInvocation
11- from invokeai .app .invocations .primitives import BooleanCollectionInvocation , BooleanInvocation , BooleanOutput
11+ from invokeai .app .invocations .primitives import (
12+ BooleanCollectionInvocation ,
13+ BooleanCollectionOutput ,
14+ BooleanInvocation ,
15+ BooleanOutput ,
16+ )
1217from invokeai .app .services .shared .graph import (
1318 CollectInvocation ,
1419 Graph ,
@@ -886,7 +891,7 @@ def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_t
886891 assert "true_collect" not in executed_source_ids
887892
888893
889- def test_prepare_if_inputs_ignores_selected_branch_sources_without_results ():
894+ def test_prepare_if_inputs_raises_when_selected_branch_source_has_no_result ():
890895 graph = Graph ()
891896 graph .add_node (BooleanInvocation (id = "condition" , value = True ))
892897 graph .add_node (PromptTestInvocation (id = "true_value" , prompt = "true branch" ))
@@ -910,10 +915,23 @@ def test_prepare_if_inputs_ignores_selected_branch_sources_without_results():
910915 g ._resolved_if_exec_branches [if_exec_id ] = "true_input"
911916
912917 if_node = g .execution_graph .get_node (if_exec_id )
913- g ._prepare_inputs (if_node )
918+ with pytest .raises (RuntimeError , match = "selected input edge" ):
919+ g ._prepare_inputs (if_node )
920+
921+
922+ def test_get_collect_iteration_mappings_ignores_skipped_prepared_exec_nodes ():
923+ graph = Graph ()
924+ graph .add_node (AnyTypeTestInvocation (id = "parent" , value = "value" ))
925+
926+ g = GraphExecutionState (graph = graph )
927+
928+ skipped_exec_id = g ._create_execution_node ("parent" , [])[0 ]
929+ active_exec_id = g ._create_execution_node ("parent" , [])[0 ]
930+ g ._set_prepared_exec_state (skipped_exec_id , "skipped" )
931+
932+ mappings = g ._materializer ()._get_collect_iteration_mappings (["parent" ])
914933
915- assert if_node .condition is True
916- assert if_node .true_input is None
934+ assert mappings == [("parent" , active_exec_id )]
917935
918936
919937def test_get_iteration_node_ignores_skipped_prepared_exec_nodes ():
@@ -958,6 +976,65 @@ def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_e
958976 assert selected_exec_id is None
959977
960978
979+ def test_get_iteration_node_does_not_reuse_wrong_iterator_when_only_other_iteration_is_live ():
980+ graph = Graph ()
981+ graph .add_node (BooleanCollectionInvocation (id = "conditions" , collection = [True , False ]))
982+ graph .add_node (IterateInvocation (id = "condition_iter" ))
983+ graph .add_node (AnyTypeTestInvocation (id = "value" ))
984+
985+ graph .add_edge (create_edge ("conditions" , "collection" , "condition_iter" , "collection" ))
986+ graph .add_edge (create_edge ("condition_iter" , "item" , "value" , "value" ))
987+
988+ g = GraphExecutionState (graph = graph )
989+
990+ conditions_exec_id = g ._create_execution_node ("conditions" , [])[0 ]
991+ g .executed .add (conditions_exec_id )
992+ g .results [conditions_exec_id ] = BooleanCollectionOutput (collection = [True , False ])
993+
994+ iterator_exec_ids = g ._create_execution_node ("condition_iter" , [("conditions" , conditions_exec_id )])
995+ assert len (iterator_exec_ids ) == 2
996+ iterator_exec_ids_by_index = {g .execution_graph .get_node (exec_id ).index : exec_id for exec_id in iterator_exec_ids }
997+ first_iter_exec_id = iterator_exec_ids_by_index [0 ]
998+ second_iter_exec_id = iterator_exec_ids_by_index [1 ]
999+
1000+ value_exec_ids = []
1001+ value_exec_ids .extend (g ._create_execution_node ("value" , [("condition_iter" , first_iter_exec_id )]))
1002+ value_exec_ids .extend (g ._create_execution_node ("value" , [("condition_iter" , second_iter_exec_id )]))
1003+ assert len (value_exec_ids ) == 2
1004+
1005+ for exec_id in value_exec_ids :
1006+ if g ._get_iteration_path (exec_id ) == (1 ,):
1007+ active_value_exec_id = exec_id
1008+ else :
1009+ skipped_value_exec_id = exec_id
1010+
1011+ g ._set_prepared_exec_state (skipped_value_exec_id , "skipped" )
1012+
1013+ selected_exec_id = g ._get_iteration_node (
1014+ "value" , graph .nx_graph_flat (), g .execution_graph .nx_graph_flat (), [first_iter_exec_id ]
1015+ )
1016+
1017+ assert selected_exec_id is None
1018+ assert active_value_exec_id != skipped_value_exec_id
1019+
1020+
1021+ def test_mark_exec_node_skipped_does_not_hide_already_executed_results ():
1022+ graph = Graph ()
1023+ graph .add_node (AnyTypeTestInvocation (id = "value" , value = "value" ))
1024+
1025+ g = GraphExecutionState (graph = graph )
1026+
1027+ exec_id = g ._create_execution_node ("value" , [])[0 ]
1028+ g .results [exec_id ] = AnyTypeTestInvocation (id = "result" , value = "value" ).invoke (Mock (InvocationContext ))
1029+ g .executed .add (exec_id )
1030+ g ._set_prepared_exec_state (exec_id , "executed" )
1031+
1032+ g ._if_scheduler ().mark_exec_node_skipped (exec_id )
1033+
1034+ assert g ._get_prepared_exec_metadata (exec_id ).state == "executed"
1035+ assert g .results [exec_id ].value == "value"
1036+
1037+
9611038def test_are_connection_types_compatible_accepts_subclass_to_base ():
9621039 """A subclass output should be connectable to a base-class input.
9631040
0 commit comments