Skip to content

Commit 6f2376d

Browse files
authored
Handles partials in source code hash (#1116)
* Handles partials in source code hash This is a stop gap measure to handle partials for the CacheAdapter. I put the change here rather than in the source hash function, since for now it appears that this behavior is specific to the cache adapter.. * Adds unit tests * Check for partial explicitly * Update test doc strings
1 parent d90212f commit 6f2376d

2 files changed

Lines changed: 86 additions & 9 deletions

File tree

hamilton/lifecycle/default.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import shelve
1010
import time
11+
from functools import partial
1112
from typing import Any, Callable, Dict, List, Optional, Type, Union
1213

1314
from hamilton import graph_types, htypes
@@ -359,7 +360,7 @@ def __init__(
359360
def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs):
360361
"""Set `cache_vars` to all nodes if received None during `__init__`"""
361362
self.cache = shelve.open(self.cache_path)
362-
if self.cache_vars == []:
363+
if len(self.cache_vars) == 0:
363364
self.cache_vars = [n.name for n in graph.nodes]
364365

365366
def run_to_execute_node(
@@ -376,7 +377,10 @@ def run_to_execute_node(
376377
if node_name not in self.cache_vars:
377378
return node_callable(**node_kwargs)
378379

379-
node_hash = graph_types.hash_source_code(node_callable, strip=True)
380+
source_of_node_callable = node_callable
381+
while isinstance(source_of_node_callable, partial): # handle partials
382+
source_of_node_callable = source_of_node_callable.func
383+
node_hash = graph_types.hash_source_code(source_of_node_callable, strip=True)
380384
cache_key = CacheAdapter.create_key(node_hash, node_kwargs)
381385

382386
from_cache = self.cache.get(cache_key, None)

tests/lifecycle/test_cache_adapter.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import inspect
1+
import functools
22
import pathlib
33
import shelve
44

@@ -8,12 +8,8 @@
88
from hamilton.lifecycle.default import CacheAdapter
99

1010

11-
def _callable_to_node(callable) -> node.Node:
12-
return node.Node(
13-
name=callable.__name__,
14-
typ=inspect.signature(callable).return_annotation,
15-
callabl=callable,
16-
)
11+
def _callable_to_node(callable, name=None) -> node.Node:
12+
return node.Node.from_fn(callable, name)
1713

1814

1915
@pytest.fixture()
@@ -52,6 +48,37 @@ def A(external_input: int) -> int:
5248
return _callable_to_node(A)
5349

5450

51+
@pytest.fixture()
52+
def node_a_partial():
53+
"""The function A() is a partial"""
54+
55+
def A(external_input: int, remainder: int) -> int:
56+
return external_input % remainder
57+
58+
base_node: node.Node = _callable_to_node(A)
59+
60+
A = functools.partial(A, remainder=7)
61+
base_node._callable = A
62+
del base_node.input_types["remainder"]
63+
return base_node
64+
65+
66+
@pytest.fixture()
67+
def node_a_nested_partial():
68+
"""The function A() is a partial"""
69+
70+
def A(external_input: int, remainder: int, extra: int) -> int:
71+
return external_input % remainder
72+
73+
base_node: node.Node = _callable_to_node(A)
74+
A = functools.partial(A, remainder=7)
75+
A = functools.partial(A, extra=7)
76+
base_node._callable = A
77+
del base_node.input_types["remainder"]
78+
del base_node.input_types["extra"]
79+
return base_node
80+
81+
5582
def test_set_result(hook: CacheAdapter, node_a: node.Node):
5683
"""Hook sets value and assert value in cache"""
5784
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
@@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter):
138165
# need to reopen the hook cache
139166
with shelve.open(hook.cache_path) as cache:
140167
assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history
168+
169+
170+
def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node):
171+
"""Tests partial functions are handled properly"""
172+
hook.cache_vars = [node_a_partial.name]
173+
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
174+
node_kwargs = dict(external_input=7)
175+
result = hook.run_to_execute_node(
176+
node_name=node_a_partial.name,
177+
node_kwargs=node_kwargs,
178+
node_callable=node_a_partial.callable,
179+
)
180+
hook.run_after_node_execution(
181+
node_name=node_a_partial.name,
182+
node_kwargs=node_kwargs,
183+
result=result,
184+
)
185+
result2 = hook.run_to_execute_node(
186+
node_name=node_a_partial.name,
187+
node_kwargs=node_kwargs,
188+
node_callable=node_a_partial.callable,
189+
)
190+
assert result2 == result
191+
192+
193+
def test_nested_partial_handling(hook: CacheAdapter, node_a_nested_partial: node.Node):
194+
"""Tests nested partial functions are handled properly"""
195+
hook.cache_vars = [node_a_nested_partial.name]
196+
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
197+
node_kwargs = dict(external_input=7)
198+
result = hook.run_to_execute_node(
199+
node_name=node_a_nested_partial.name,
200+
node_kwargs=node_kwargs,
201+
node_callable=node_a_nested_partial.callable,
202+
)
203+
hook.run_after_node_execution(
204+
node_name=node_a_nested_partial.name,
205+
node_kwargs=node_kwargs,
206+
result=result,
207+
)
208+
result2 = hook.run_to_execute_node(
209+
node_name=node_a_nested_partial.name,
210+
node_kwargs=node_kwargs,
211+
node_callable=node_a_nested_partial.callable,
212+
)
213+
assert result2 == result

0 commit comments

Comments
 (0)