Skip to content

Commit dcc5958

Browse files
committed
Add test showing how to parse custom products from path annotations while introducing new load and save methods.
1 parent 37d244e commit dcc5958

3 files changed

Lines changed: 53 additions & 2 deletions

File tree

src/_pytask/collect_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def parse_products_from_task_function(
360360
has_produces_decorator = False
361361
has_produces_argument = False
362362
has_annotation = False
363+
has_return = False
363364
out = {}
364365

365366
# Parse products from decorators.
@@ -415,8 +416,12 @@ def parse_products_from_task_function(
415416
)
416417
out = {parameter_name: collected_products}
417418

419+
if "return" in parameters_with_node_annot:
420+
has_return = True
421+
out = {"return": parameters_with_node_annot["return"]}
422+
418423
if (
419-
sum((has_produces_decorator, has_produces_argument, has_annotation))
424+
sum((has_produces_decorator, has_produces_argument, has_annotation, has_return))
420425
>= 2 # noqa: PLR2004
421426
):
422427
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)

src/_pytask/nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ def state(self, hash: bool = False) -> str | None: # noqa: A002
7272

7373
def execute(self, **kwargs: Any) -> None:
7474
"""Execute the task."""
75-
self.function(**kwargs)
75+
out = self.function(**kwargs)
76+
77+
if "return" in self.produces:
78+
self.produces["return"].save(out)
7679

7780
def add_report_section(self, when: str, key: str, content: str) -> None:
7881
"""Add sections which will be displayed in report like stdout or stderr."""
@@ -127,6 +130,9 @@ def state(self) -> str | None:
127130
return str(self.path.stat().st_mtime)
128131
return None
129132

133+
def load(self) -> Path:
134+
return self.value
135+
130136

131137
@define(kw_only=True)
132138
class PythonNode(Node):

tests/test_execute.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
import pickle
45
import re
56
import subprocess
67
import sys
@@ -544,3 +545,42 @@ def task_example(
544545
result = runner.invoke(cli, [tmp_path.as_posix()])
545546
assert result.exit_code == ExitCode.COLLECTION_FAILED
546547
assert "Parameter 'dependency'" in result.output
548+
549+
550+
@pytest.mark.end_to_end()
551+
def test_return_with_custom_type_annotation_as_return(runner, tmp_path):
552+
source = """
553+
from pathlib import Path
554+
import pickle
555+
from typing import Any
556+
from typing_extensions import Annotated
557+
import attrs
558+
559+
@attrs.define
560+
class PickleNode:
561+
name: str = ""
562+
path: Path | None = None
563+
value: None = None
564+
565+
def state(self) -> str | None:
566+
if self.path.exists():
567+
return str(self.path.stat().st_mtime)
568+
return None
569+
570+
def load(self) -> Any:
571+
return pickle.loads(self.path.read_bytes())
572+
573+
def save(self, value: Any) -> None:
574+
self.path.write_bytes(pickle.dumps(value))
575+
576+
node = PickleNode("pickled_data", Path(__file__).parent.joinpath("data.pkl"))
577+
578+
def task_example() -> Annotated[int, node]:
579+
return 1
580+
"""
581+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
582+
result = runner.invoke(cli, [tmp_path.as_posix()])
583+
assert result.exit_code == ExitCode.OK
584+
585+
data = pickle.loads(tmp_path.joinpath("data.pkl").read_bytes()) # noqa: S301
586+
assert data == 1

0 commit comments

Comments
 (0)