|
1 | | -import inspect |
| 1 | +import functools |
2 | 2 | import pathlib |
3 | 3 | import shelve |
4 | 4 |
|
|
8 | 8 | from hamilton.lifecycle.default import CacheAdapter |
9 | 9 |
|
10 | 10 |
|
11 | | -def _callable_to_node(callable) -> node.Node: |
12 | | - return node.Node( |
13 | | - name=callable.__name__, |
14 | | - typ=inspect.signature(callable).return_annotation, |
15 | | - callabl=callable, |
16 | | - ) |
| 11 | +def _callable_to_node(callable, name=None) -> node.Node: |
| 12 | + return node.Node.from_fn(callable, name) |
17 | 13 |
|
18 | 14 |
|
19 | 15 | @pytest.fixture() |
@@ -52,6 +48,37 @@ def A(external_input: int) -> int: |
52 | 48 | return _callable_to_node(A) |
53 | 49 |
|
54 | 50 |
|
| 51 | +@pytest.fixture() |
| 52 | +def node_a_partial(): |
| 53 | + """The function A() is a partial""" |
| 54 | + |
| 55 | + def A(external_input: int, remainder: int) -> int: |
| 56 | + return external_input % remainder |
| 57 | + |
| 58 | + base_node: node.Node = _callable_to_node(A) |
| 59 | + |
| 60 | + A = functools.partial(A, remainder=7) |
| 61 | + base_node._callable = A |
| 62 | + del base_node.input_types["remainder"] |
| 63 | + return base_node |
| 64 | + |
| 65 | + |
| 66 | +@pytest.fixture() |
| 67 | +def node_a_nested_partial(): |
| 68 | + """The function A() is a partial""" |
| 69 | + |
| 70 | + def A(external_input: int, remainder: int, extra: int) -> int: |
| 71 | + return external_input % remainder |
| 72 | + |
| 73 | + base_node: node.Node = _callable_to_node(A) |
| 74 | + A = functools.partial(A, remainder=7) |
| 75 | + A = functools.partial(A, extra=7) |
| 76 | + base_node._callable = A |
| 77 | + del base_node.input_types["remainder"] |
| 78 | + del base_node.input_types["extra"] |
| 79 | + return base_node |
| 80 | + |
| 81 | + |
55 | 82 | def test_set_result(hook: CacheAdapter, node_a: node.Node): |
56 | 83 | """Hook sets value and assert value in cache""" |
57 | 84 | node_hash = graph_types.hash_source_code(node_a.callable, strip=True) |
@@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter): |
138 | 165 | # need to reopen the hook cache |
139 | 166 | with shelve.open(hook.cache_path) as cache: |
140 | 167 | assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history |
| 168 | + |
| 169 | + |
| 170 | +def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node): |
| 171 | + """Tests partial functions are handled properly""" |
| 172 | + hook.cache_vars = [node_a_partial.name] |
| 173 | + hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache |
| 174 | + node_kwargs = dict(external_input=7) |
| 175 | + result = hook.run_to_execute_node( |
| 176 | + node_name=node_a_partial.name, |
| 177 | + node_kwargs=node_kwargs, |
| 178 | + node_callable=node_a_partial.callable, |
| 179 | + ) |
| 180 | + hook.run_after_node_execution( |
| 181 | + node_name=node_a_partial.name, |
| 182 | + node_kwargs=node_kwargs, |
| 183 | + result=result, |
| 184 | + ) |
| 185 | + result2 = hook.run_to_execute_node( |
| 186 | + node_name=node_a_partial.name, |
| 187 | + node_kwargs=node_kwargs, |
| 188 | + node_callable=node_a_partial.callable, |
| 189 | + ) |
| 190 | + assert result2 == result |
| 191 | + |
| 192 | + |
| 193 | +def test_nested_partial_handling(hook: CacheAdapter, node_a_nested_partial: node.Node): |
| 194 | + """Tests nested partial functions are handled properly""" |
| 195 | + hook.cache_vars = [node_a_nested_partial.name] |
| 196 | + hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache |
| 197 | + node_kwargs = dict(external_input=7) |
| 198 | + result = hook.run_to_execute_node( |
| 199 | + node_name=node_a_nested_partial.name, |
| 200 | + node_kwargs=node_kwargs, |
| 201 | + node_callable=node_a_nested_partial.callable, |
| 202 | + ) |
| 203 | + hook.run_after_node_execution( |
| 204 | + node_name=node_a_nested_partial.name, |
| 205 | + node_kwargs=node_kwargs, |
| 206 | + result=result, |
| 207 | + ) |
| 208 | + result2 = hook.run_to_execute_node( |
| 209 | + node_name=node_a_nested_partial.name, |
| 210 | + node_kwargs=node_kwargs, |
| 211 | + node_callable=node_a_nested_partial.callable, |
| 212 | + ) |
| 213 | + assert result2 == result |
0 commit comments