|
7 | 7 | from invokeai.app.invocations.collections import RangeInvocation |
8 | 8 | from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput |
9 | 9 | from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation |
10 | | -from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation |
| 10 | +from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation, BooleanOutput |
11 | 11 | from invokeai.app.services.shared.graph import ( |
12 | 12 | CollectInvocation, |
13 | 13 | Graph, |
@@ -750,6 +750,146 @@ def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration() |
750 | 750 | assert executed_source_ids.count("false_branch") == 2 |
751 | 751 |
|
752 | 752 |
|
| 753 | +def test_if_graph_optimized_behavior_handles_selected_true_branch_with_shared_false_input_ancestor(): |
| 754 | + graph = Graph() |
| 755 | + graph.add_node(BooleanInvocation(id="condition", value=True)) |
| 756 | + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) |
| 757 | + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) |
| 758 | + graph.add_node(CollectInvocation(id="shared_collect")) |
| 759 | + graph.add_node(CollectInvocation(id="true_collect")) |
| 760 | + graph.add_node(IfInvocation(id="if")) |
| 761 | + graph.add_node(AnyTypeTestInvocation(id="selected_output")) |
| 762 | + |
| 763 | + graph.add_edge(create_edge("condition", "value", "if", "condition")) |
| 764 | + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) |
| 765 | + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) |
| 766 | + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) |
| 767 | + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) |
| 768 | + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) |
| 769 | + graph.add_edge(create_edge("if", "value", "selected_output", "value")) |
| 770 | + |
| 771 | + g = GraphExecutionState(graph=graph) |
| 772 | + executed_source_ids = execute_all_nodes(g) |
| 773 | + |
| 774 | + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) |
| 775 | + assert g.results[prepared_selected_output_id].value == ["shared", "true"] |
| 776 | + assert set(executed_source_ids) == { |
| 777 | + "condition", |
| 778 | + "shared_item", |
| 779 | + "true_item", |
| 780 | + "shared_collect", |
| 781 | + "true_collect", |
| 782 | + "if", |
| 783 | + "selected_output", |
| 784 | + } |
| 785 | + |
| 786 | + |
| 787 | +def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_true_input_ancestor(): |
| 788 | + graph = Graph() |
| 789 | + graph.add_node(BooleanInvocation(id="condition", value=False)) |
| 790 | + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) |
| 791 | + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) |
| 792 | + graph.add_node(CollectInvocation(id="shared_collect")) |
| 793 | + graph.add_node(CollectInvocation(id="true_collect")) |
| 794 | + graph.add_node(IfInvocation(id="if")) |
| 795 | + graph.add_node(AnyTypeTestInvocation(id="selected_output")) |
| 796 | + |
| 797 | + graph.add_edge(create_edge("condition", "value", "if", "condition")) |
| 798 | + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) |
| 799 | + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) |
| 800 | + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) |
| 801 | + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) |
| 802 | + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) |
| 803 | + graph.add_edge(create_edge("if", "value", "selected_output", "value")) |
| 804 | + |
| 805 | + g = GraphExecutionState(graph=graph) |
| 806 | + executed_source_ids = execute_all_nodes(g) |
| 807 | + |
| 808 | + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) |
| 809 | + assert g.results[prepared_selected_output_id].value == ["shared"] |
| 810 | + assert set(executed_source_ids) == { |
| 811 | + "condition", |
| 812 | + "shared_item", |
| 813 | + "shared_collect", |
| 814 | + "if", |
| 815 | + "selected_output", |
| 816 | + } |
| 817 | + assert "true_item" not in executed_source_ids |
| 818 | + assert "true_collect" not in executed_source_ids |
| 819 | + |
| 820 | + |
| 821 | +def test_prepare_if_inputs_ignores_selected_branch_sources_without_results(): |
| 822 | + graph = Graph() |
| 823 | + graph.add_node(BooleanInvocation(id="condition", value=True)) |
| 824 | + graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch")) |
| 825 | + graph.add_node(IfInvocation(id="if")) |
| 826 | + |
| 827 | + graph.add_edge(create_edge("condition", "value", "if", "condition")) |
| 828 | + graph.add_edge(create_edge("true_value", "prompt", "if", "true_input")) |
| 829 | + |
| 830 | + g = GraphExecutionState(graph=graph) |
| 831 | + |
| 832 | + condition_exec_id = g._create_execution_node("condition", [])[0] |
| 833 | + true_value_exec_id = g._create_execution_node("true_value", [])[0] |
| 834 | + if_exec_id = g._create_execution_node( |
| 835 | + "if", |
| 836 | + [("condition", condition_exec_id), ("true_value", true_value_exec_id)], |
| 837 | + )[0] |
| 838 | + |
| 839 | + g.executed.add(condition_exec_id) |
| 840 | + g.results[condition_exec_id] = BooleanOutput(value=True) |
| 841 | + g.executed.add(true_value_exec_id) |
| 842 | + g._resolved_if_exec_branches[if_exec_id] = "true_input" |
| 843 | + |
| 844 | + if_node = g.execution_graph.get_node(if_exec_id) |
| 845 | + g._prepare_inputs(if_node) |
| 846 | + |
| 847 | + assert if_node.condition is True |
| 848 | + assert if_node.true_input is None |
| 849 | + |
| 850 | + |
| 851 | +def test_get_iteration_node_ignores_skipped_prepared_exec_nodes(): |
| 852 | + graph = Graph() |
| 853 | + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) |
| 854 | + |
| 855 | + g = GraphExecutionState(graph=graph) |
| 856 | + |
| 857 | + skipped_exec_id = g._create_execution_node("value", [])[0] |
| 858 | + active_exec_id = g._create_execution_node("value", [])[0] |
| 859 | + g._set_prepared_exec_state(skipped_exec_id, "skipped") |
| 860 | + |
| 861 | + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) |
| 862 | + |
| 863 | + assert selected_exec_id == active_exec_id |
| 864 | + |
| 865 | + |
| 866 | +def test_get_iteration_node_returns_single_active_prepared_exec_node(): |
| 867 | + graph = Graph() |
| 868 | + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) |
| 869 | + |
| 870 | + g = GraphExecutionState(graph=graph) |
| 871 | + |
| 872 | + active_exec_id = g._create_execution_node("value", [])[0] |
| 873 | + |
| 874 | + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) |
| 875 | + |
| 876 | + assert selected_exec_id == active_exec_id |
| 877 | + |
| 878 | + |
| 879 | +def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_exist(): |
| 880 | + graph = Graph() |
| 881 | + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) |
| 882 | + |
| 883 | + g = GraphExecutionState(graph=graph) |
| 884 | + |
| 885 | + skipped_exec_id = g._create_execution_node("value", [])[0] |
| 886 | + g._set_prepared_exec_state(skipped_exec_id, "skipped") |
| 887 | + |
| 888 | + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) |
| 889 | + |
| 890 | + assert selected_exec_id is None |
| 891 | + |
| 892 | + |
753 | 893 | def test_are_connection_types_compatible_accepts_subclass_to_base(): |
754 | 894 | """A subclass output should be connectable to a base-class input. |
755 | 895 |
|
|
0 commit comments