-
Notifications
You must be signed in to change notification settings - Fork 185
Expand file tree
/
Copy pathtest_node_grouping.py
More file actions
154 lines (133 loc) · 6.37 KB
/
test_node_grouping.py
File metadata and controls
154 lines (133 loc) · 6.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from hamilton import ad_hoc_utils, graph, node
from hamilton.execution import grouping
from hamilton.execution.grouping import (
GroupByRepeatableBlocks,
GroupNodesAllAsOne,
GroupNodesByLevel,
GroupNodesIndividually,
NodeGroupPurpose,
)
from hamilton.graph import FunctionGraph
from hamilton.lifecycle import base as lifecycle_base
from hamilton.node import NodeType
from tests.resources.dynamic_parallelism import no_parallel, parallel_complex, parallel_linear_basic
def test_group_individually():
    """One-node-per-group strategy: group count equals node count."""
    graph_ = FunctionGraph.from_modules(no_parallel, config={})
    all_nodes = list(graph_.nodes.values())
    groups = GroupNodesIndividually().group_nodes(all_nodes)
    assert len(groups) == len(all_nodes)
def test_group_all_as_one():
    """Everything-in-one-group strategy: exactly one group comes back."""
    graph_ = FunctionGraph.from_modules(no_parallel, config={})
    groups = GroupNodesAllAsOne().group_nodes(list(graph_.nodes.values()))
    assert len(groups) == 1
def test_group_by_level():
    """Level-based grouping on the no_parallel graph yields six groups.

    Grouping by level is not something you'd do in practice; it just
    exercises the grouping machinery. Two nodes share a level, hence
    six groups rather than one per node.
    """
    graph_ = FunctionGraph.from_modules(no_parallel, config={})
    groups = GroupNodesByLevel().group_nodes(list(graph_.nodes.values()))
    assert len(groups) == 6
def test_group_nodes_by_repeatable_blocks():
    """Repeatable-block grouping on the basic linear parallel graph.

    Checks group sizes, the purpose assigned to each group (single /
    block / gather), and that the expand task spawns the block and
    gather tasks.
    """
    graph_ = FunctionGraph.from_modules(parallel_linear_basic, config={})
    groups = GroupByRepeatableBlocks().group_nodes(list(graph_.nodes.values()))
    by_id = {group.base_id: group for group in groups}

    expected_sizes = {
        "collect-steps": 1,
        "expand-steps": 1,
        "block-steps": 3,
        "number_of_steps": 1,
        "final": 1,
    }
    for base_id, size in expected_sizes.items():
        assert len(by_id[base_id].nodes) == size

    expected_purposes = {
        "final": NodeGroupPurpose.EXECUTE_SINGLE,
        "number_of_steps": NodeGroupPurpose.EXECUTE_SINGLE,
        "block-steps": NodeGroupPurpose.EXECUTE_BLOCK,
        "collect-steps": NodeGroupPurpose.GATHER,
    }
    for base_id, purpose in expected_purposes.items():
        assert by_id[base_id].purpose == purpose

    # Both the block and the gather are spawned by the expand task.
    assert by_id["block-steps"].spawning_task_base_id == "expand-steps"
    assert by_id["collect-steps"].spawning_task_base_id == "expand-steps"
def test_group_nodes_by_repeatable_blocks_complex():
    """Repeatable-block grouping on the more complex parallel graph.

    Same shape of checks as the basic case, but the repeated block now
    carries five nodes. See comments in parallel_complex.py for what
    sits between the start/end of the parallelizable block.
    """
    graph_ = FunctionGraph.from_modules(parallel_complex, config={})
    groups = GroupByRepeatableBlocks().group_nodes(list(graph_.nodes.values()))
    by_id = {group.base_id: group for group in groups}

    assert len(by_id["collect-steps"].nodes) == 1
    assert len(by_id["expand-steps"].nodes) == 1
    assert len(by_id["block-steps"].nodes) == 5

    expected_purposes = {
        "number_of_steps": NodeGroupPurpose.EXECUTE_SINGLE,
        "block-steps": NodeGroupPurpose.EXECUTE_BLOCK,
        "collect-steps": NodeGroupPurpose.GATHER,
    }
    for base_id, purpose in expected_purposes.items():
        assert by_id[base_id].purpose == purpose

    # Both the block and the gather are spawned by the expand task.
    assert by_id["block-steps"].spawning_task_base_id == "expand-steps"
    assert by_id["collect-steps"].spawning_task_base_id == "expand-steps"
def test_create_task_plan():
    """create_task_plan wires task dependencies from the grouped nodes.

    Requests only "final" as an output and verifies the resulting
    five-task plan forms the expected dependency chain.
    """
    graph_ = FunctionGraph.from_modules(parallel_linear_basic, config={})
    groups = GroupByRepeatableBlocks().group_nodes(list(graph_.nodes.values()))
    plan = grouping.create_task_plan(
        groups, ["final"], {}, lifecycle_base.LifecycleAdapterSet()
    )
    assert len(plan) == 5

    actual_deps = {task.base_id: task.base_dependencies for task in plan}
    assert actual_deps == {
        "number_of_steps": [],
        "expand-steps": ["number_of_steps"],
        "block-steps": ["expand-steps"],
        "collect-steps": ["block-steps"],
        "final": ["collect-steps"],
    }
def test_task_get_input_vars_not_user_defined():
    """A task around a function-backed node reports its upstream
    dependency ("foo") as required, with no optional inputs."""

    def bar(foo: int) -> int:
        return foo + 1

    # FunctionGraph is built from modules (a messy, coarse-grained API),
    # so wrap the single function in a temporary module rather than
    # constructing nodes directly. Ideally graph construction would be
    # decomposed into function/node-level utilities.
    temp_module = ad_hoc_utils.create_temporary_module(bar)
    fn_graph = graph.FunctionGraph.from_modules(temp_module, config={})

    task = grouping.TaskSpec(
        base_id="bar",
        nodes=[fn_graph.nodes["bar"]],
        purpose=NodeGroupPurpose.EXECUTE_SINGLE,
        outputs_to_compute=["bar"],
        overrides={},
        adapter=lifecycle_base.LifecycleAdapterSet(),
        base_dependencies=[],
        spawning_task_base_id=None,
    )
    required, optional = task.get_input_vars()
    assert required == ["foo"]
    assert optional == []
def test_task_get_input_vars_with_optional():
    """Parameters with defaults ("baz") are reported as optional input
    vars, separate from the required ones ("foo")."""

    def bar(foo: int, baz: int = 1) -> int:
        return foo + 1

    # FunctionGraph is built from modules (a messy, coarse-grained API),
    # so wrap the single function in a temporary module rather than
    # constructing nodes directly. Ideally graph construction would be
    # decomposed into function/node-level utilities.
    temp_module = ad_hoc_utils.create_temporary_module(bar)
    fn_graph = graph.FunctionGraph.from_modules(temp_module, config={})

    task = grouping.TaskSpec(
        base_id="bar",
        nodes=[fn_graph.nodes["bar"]],
        purpose=NodeGroupPurpose.EXECUTE_SINGLE,
        outputs_to_compute=["bar"],
        overrides={},
        adapter=lifecycle_base.LifecycleAdapterSet(),
        base_dependencies=[],
        spawning_task_base_id=None,
    )
    required, optional = task.get_input_vars()
    assert required == ["foo"]
    assert optional == ["baz"]
def test_task_get_input_vars_user_defined():
    """An EXTERNAL (user-supplied) node is itself a required input var."""
    external_node = node.Node(name="foo", typ=int, node_source=NodeType.EXTERNAL)
    task = grouping.TaskSpec(
        base_id="foo",
        nodes=[external_node],
        purpose=NodeGroupPurpose.EXECUTE_SINGLE,
        outputs_to_compute=["foo"],
        overrides={},
        adapter=lifecycle_base.LifecycleAdapterSet(),
        base_dependencies=[],
        spawning_task_base_id=None,
    )
    required, optional = task.get_input_vars()
    assert required == ["foo"]
    assert optional == []