Skip to content

Commit 3050c12

Browse files
ArnavBalyanskrawcz
authored andcommitted
update
1 parent bb54dc5 commit 3050c12

3 files changed

Lines changed: 22 additions & 5 deletions

File tree

hamilton/function_modifiers/adapters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -872,13 +872,13 @@ def save_json_data(data: pd.DataFrame, json_path: str = "data/my_saved_data.json
872872

873873
def validate(self, fn: Callable):
874874
"""Validates that the function output is a dict type."""
875-
return_annotation = inspect.signature(fn).return_annotation
876-
if return_annotation is inspect.Signature.empty:
875+
return_annotation = typing.get_type_hints(fn).get("return")
876+
if return_annotation is None:
877877
raise InvalidDecoratorException(
878878
f"Function: {fn.__qualname__} must have a return annotation."
879879
)
880-
# check that the return type is a dict
881-
if return_annotation not in (dict, dict):
880+
origin = typing.get_origin(return_annotation)
881+
if return_annotation is not dict and origin is not dict:
882882
raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a dict.")
883883

884884
def generate_nodes(self, fn: Callable, config) -> list[node.Node]:

tests/function_modifiers/test_adapters.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,17 @@ def test_dataloader_future_annotations():
835835
assert custom_subclass_check(fg["sample_dataloader"].type, list)
836836

837837

838+
def test_datasaver_future_annotations():
839+
from tests.resources import nodes_with_future_annotation
840+
841+
fn_to_collect = nodes_with_future_annotation.sample_datasaver
842+
fg = graph.create_function_graph(
843+
ad_hoc_utils.create_temporary_module(fn_to_collect),
844+
config={},
845+
)
846+
assert "sample_datasaver" in fg
847+
848+
838849
def test_datasaver():
839850
annotation = datasaver()
840851
(node1,) = annotation.generate_nodes(correct_ds_function, {})

tests/resources/nodes_with_future_annotation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from hamilton.function_modifiers import dataloader
20+
from hamilton.function_modifiers import dataloader, datasaver
2121
from hamilton.htypes import Collect, Parallelizable
2222

2323
"""Tests future annotations with common node types"""
@@ -41,3 +41,9 @@ def collected(standard: Collect[int]) -> int:
4141
def sample_dataloader() -> tuple[list[str], dict]:
4242
"""Grouping here as the rest test annotations"""
4343
return ["a", "b", "c"], {}
44+
45+
46+
@datasaver()
47+
def sample_datasaver(standard: int) -> dict:
48+
"""Grouping here as the rest test annotations"""
49+
return {"saved": standard}

0 commit comments

Comments
 (0)