Skip to content

Commit 49cee82

Browse files
committed
Allow pytrees in return annotations and function returns.
1 parent b3fbce0 commit 49cee82

10 files changed

Lines changed: 119 additions & 27 deletions

src/_pytask/collect_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,10 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, Node]
309309
args_with_node_annotation = {}
310310
for name, meta in metas.items():
311311
annot = [
312-
i for i in meta if not isinstance(i, ProductType) and isinstance(i, Node)
312+
i
313+
for i in meta
314+
if not isinstance(i, ProductType)
315+
and all(isinstance(x, Node) for x in tree_leaves(i))
313316
]
314317
if len(annot) >= 2: # noqa: PLR2004
315318
raise ValueError(
@@ -418,7 +421,17 @@ def parse_products_from_task_function(
418421

419422
if "return" in parameters_with_node_annot:
420423
has_return = True
421-
out = {"return": parameters_with_node_annot["return"]}
424+
collected_products = tree_map_with_path(
425+
lambda p, x: _collect_product(
426+
session,
427+
path,
428+
name,
429+
NodeInfo("return", p, x),
430+
is_string_allowed=False,
431+
),
432+
parameters_with_node_annot["return"],
433+
)
434+
out = {"return": collected_products}
422435

423436
if (
424437
sum((has_produces_decorator, has_produces_argument, has_annotation, has_return))
@@ -576,5 +589,5 @@ def _evolve_instance(x: Any, instance_from_annot: Node | None) -> Any:
576589
if not instance_from_annot:
577590
return x
578591

579-
instance_from_annot.value = x
592+
instance_from_annot.set_value(x)
580593
return instance_from_annot

src/_pytask/execute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def pytask_execute_task(session: Session, task: Task) -> bool:
147147

148148
kwargs = {}
149149
for name, value in task.depends_on.items():
150-
kwargs[name] = tree_map(lambda x: x.value, value)
150+
kwargs[name] = tree_map(lambda x: x.load(), value)
151151

152152
for name, value in task.produces.items():
153153
if name in parameters:
154-
kwargs[name] = tree_map(lambda x: x.value, value)
154+
kwargs[name] = tree_map(lambda x: x.load(), value)
155155

156156
task.execute(**kwargs)
157157
return True

src/_pytask/node_protocols.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ class Node(MetaNode, Protocol):
2525

2626
value: Any
2727

28+
def load(self) -> Any:
29+
...
30+
31+
def save(self, value: Any) -> Any:
32+
...
33+
34+
def set_value(self, value: Any) -> Any:
35+
...
36+
2837

2938
@runtime_checkable
3039
class PPathNode(Node, Protocol):

src/_pytask/nodes.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from pathlib import Path
77
from typing import Any
88
from typing import Callable
9+
from typing import NoReturn
910
from typing import TYPE_CHECKING
1011

1112
from _pytask.node_protocols import MetaNode
1213
from _pytask.node_protocols import Node
1314
from _pytask.tree_util import PyTree
15+
from _pytask.tree_util import tree_leaves
16+
from _pytask.tree_util import tree_structure
1417
from attrs import define
1518
from attrs import field
1619

@@ -76,7 +79,18 @@ def execute(self, **kwargs: Any) -> None:
7679
out = self.function(**kwargs)
7780

7881
if "return" in self.produces:
79-
self.produces["return"].save(out)
82+
structure_out = tree_structure(out)
83+
structure_return = tree_structure(self.produces["return"])
84+
if not structure_out == structure_return:
85+
raise ValueError(
86+
"The structure of the function return does not match the structure "
87+
"of the return annotation."
88+
)
89+
90+
for out_, return_ in zip(
91+
tree_leaves(out), tree_leaves(self.produces["return"])
92+
):
93+
return_.save(out_)
8094

8195
def add_report_section(self, when: str, key: str, content: str) -> None:
8296
"""Add sections which will be displayed in report like stdout or stderr."""
@@ -90,25 +104,20 @@ class PathNode(Node):
90104

91105
name: str = ""
92106
"""Name of the node which makes it identifiable in the DAG."""
93-
_value: Path | None = None
107+
value: Path | None = None
94108
"""Value passed to the decorator which can be requested inside the function."""
95109

96110
@property
97111
def path(self) -> Path:
98112
return self.value
99113

100-
@property
101-
def value(self) -> Path:
102-
return self._value
103-
104-
@value.setter
105-
def value(self, value: Path) -> None:
114+
def set_value(self, value: Path) -> None:
106115
"""Set path and if other attributes are not set, set sensible defaults."""
107116
if not isinstance(value, Path):
108117
raise TypeError("'value' must be a 'pathlib.Path'.")
109118
if not self.name:
110119
self.name = value.as_posix()
111-
self._value = value
120+
self.value = value
112121

113122
@classmethod
114123
@functools.lru_cache
@@ -153,11 +162,20 @@ class PythonNode(Node):
153162

154163
name: str = ""
155164
"""Name of the node."""
156-
value: Any | None = None
165+
value: Any = None
157166
"""Value of the node."""
158167
hash: bool = False # noqa: A003
159168
"""Whether the value should be hashed to determine the state."""
160169

170+
def load(self) -> Any:
171+
return self.value
172+
173+
def save(self, value: Any) -> NoReturn:
174+
raise NotImplementedError
175+
176+
def set_value(self, value: Any) -> None:
177+
self.value = value
178+
161179
def state(self) -> str | None:
162180
"""Calculate state of the node.
163181

src/_pytask/tree_util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from optree import tree_leaves as _optree_tree_leaves
1010
from optree import tree_map as _optree_tree_map
1111
from optree import tree_map_with_path as _optree_tree_map_with_path
12+
from optree import tree_structure as _optree_tree_structure
1213

1314

1415
__all__ = [
1516
"tree_leaves",
1617
"tree_map",
1718
"tree_map_with_path",
19+
"tree_structure",
1820
"PyTree",
1921
"TREE_UTIL_LIB_DIRECTORY",
2022
]
@@ -29,3 +31,6 @@
2931
tree_map_with_path = functools.partial(
3032
_optree_tree_map_with_path, none_is_leaf=True, namespace="pytask"
3133
)
34+
tree_structure = functools.partial(
35+
_optree_tree_structure, none_is_leaf=True, namespace="pytask"
36+
)

tests/test_collect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_pytask_collect_node(session, path, node_info, expected):
172172
if result is None:
173173
assert result is expected
174174
else:
175-
assert str(result.value) == str(expected)
175+
assert str(result.load()) == str(expected)
176176

177177

178178
@pytest.mark.unit()

tests/test_collect_command.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ class CustomNode:
529529
def state(self):
530530
return self.value
531531
532+
def load(self): ...
533+
def save(self, value): ...
534+
def set_value(self, value): ...
532535
533536
def task_example(
534537
data = CustomNode("custom", "text"),
@@ -555,21 +558,27 @@ def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path):
555558
class PickleFile:
556559
name: str
557560
path: Path
561+
value: Path
558562
559-
@property
560-
def value(self):
563+
def state(self):
564+
return str(self.path.stat().st_mtime)
565+
566+
def load(self):
561567
with self.path.open("rb") as f:
562568
out = pickle.load(f)
563569
return out
564570
565-
def state(self):
566-
return str(self.path.stat().st_mtime)
571+
def save(self, value):
572+
with self.path.open("wb") as f:
573+
pickle.dump(value, f)
574+
575+
def set_value(self, value): ...
567576
568577
569578
_PATH = Path(__file__).parent.joinpath("in.pkl")
570579
571580
def task_example(
572-
data = PickleFile(_PATH.as_posix(), _PATH),
581+
data = PickleFile(_PATH.as_posix(), _PATH, _PATH),
573582
out: Annotated[Path, Product] = Path("out.txt"),
574583
) -> None:
575584
out.write_text(data)

tests/test_execute.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def test_error_with_multiple_different_dep_annotations(runner, tmp_path):
532532
from pathlib import Path
533533
from typing_extensions import Annotated
534534
from pytask import Product, PythonNode, PathNode
535+
from typing import Any
535536
536537
def task_example(
537538
dependency: Annotated[Any, PythonNode(), PathNode()] = "hello",
@@ -591,6 +592,8 @@ def load(self) -> Any:
591592
def save(self, value: Any) -> None:
592593
self.path.write_bytes(pickle.dumps(value))
593594
595+
def set_value(self, value: Any) -> None: ...
596+
594597
node = PickleNode("pickled_data", Path(__file__).parent.joinpath("data.pkl"))
595598
596599
def task_example() -> Annotated[int, node]:
@@ -602,3 +605,24 @@ def task_example() -> Annotated[int, node]:
602605

603606
data = pickle.loads(tmp_path.joinpath("data.pkl").read_bytes()) # noqa: S301
604607
assert data == 1
608+
609+
610+
@pytest.mark.end_to_end()
611+
def test_return_with_tuple_pathnode_annotation_as_return(runner, tmp_path):
612+
source = """
613+
from pathlib import Path
614+
from typing import Any
615+
from typing_extensions import Annotated
616+
from pytask import PathNode
617+
618+
node1 = PathNode.from_path(Path(__file__).parent.joinpath("file1.txt"))
619+
node2 = PathNode.from_path(Path(__file__).parent.joinpath("file2.txt"))
620+
621+
def task_example() -> Annotated[str, (node1, node2)]:
622+
return "Hello,", "World!"
623+
"""
624+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
625+
result = runner.invoke(cli, [tmp_path.as_posix()])
626+
assert result.exit_code == ExitCode.OK
627+
assert tmp_path.joinpath("file1.txt").read_text() == "Hello,"
628+
assert tmp_path.joinpath("file2.txt").read_text() == "World!"

tests/test_node_protocols.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class CustomNode:
2222
def state(self):
2323
return self.value
2424
25+
def load(self):
26+
return self.value
27+
28+
def save(self, value):
29+
self.value = value
30+
31+
def set_value(self, value): ...
32+
2533
2634
def task_example(
2735
data = CustomNode("custom", "text"),
@@ -48,21 +56,27 @@ def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path):
4856
class PickleFile:
4957
name: str
5058
path: Path
59+
value: Path
60+
61+
def state(self):
62+
return str(self.path.stat().st_mtime)
5163
52-
@property
53-
def value(self):
64+
def load(self):
5465
with self.path.open("rb") as f:
5566
out = pickle.load(f)
5667
return out
5768
58-
def state(self):
59-
return str(self.path.stat().st_mtime)
69+
def save(self, value):
70+
with self.path.open("wb") as f:
71+
pickle.dump(value, f)
72+
73+
def set_value(self, value): ...
6074
6175
6276
_PATH = Path(__file__).parent.joinpath("in.pkl")
6377
6478
def task_example(
65-
data = PickleFile(_PATH.as_posix(), _PATH),
79+
data = PickleFile(_PATH.as_posix(), _PATH, _PATH),
6680
out: Annotated[Path, Product] = Path("out.txt"),
6781
) -> None:
6882
out.write_text(data)

tests/test_tree_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def task_example():
4343

4444
assert session.exit_code == exit_code
4545

46-
products = tree_map(lambda x: x.value, getattr(session.tasks[0], decorator_name))
46+
products = tree_map(lambda x: x.load(), getattr(session.tasks[0], decorator_name))
4747
expected = {
4848
0: tmp_path / "out.txt",
4949
1: {0: tmp_path / "tuple_out.txt"},

0 commit comments

Comments
 (0)