1111from typing import Iterable
1212from typing import TYPE_CHECKING
1313
14+ from _pytask ._inspect import get_annotations
1415from _pytask .exceptions import NodeNotCollectedError
16+ from _pytask .mark_utils import has_mark
1517from _pytask .mark_utils import remove_marks
18+ from _pytask .nodes import ProductType
1619from _pytask .nodes import PythonNode
1720from _pytask .shared import find_duplicates
1821from _pytask .task_utils import parse_keyword_arguments_from_signature_defaults
1922from attrs import define
2023from attrs import field
2124from pybaum .tree_util import tree_map
25+ from typing_extensions import Annotated
26+ from typing_extensions import get_origin
2227
2328
2429if TYPE_CHECKING :
@@ -211,8 +216,12 @@ def parse_dependencies_from_task_function(
211216 kwargs = {** signature_defaults , ** task_kwargs }
212217 kwargs .pop ("produces" , None )
213218
219+ parameters_with_product_annot = _find_args_with_product_annotation (obj )
220+
214221 dependencies = {}
215222 for name , value in kwargs .items ():
223+ if name in parameters_with_product_annot :
224+ continue
216225 parsed_value = tree_map (
217226 lambda x : _collect_dependencies (session , path , name , x ), value # noqa: B023
218227 )
@@ -223,29 +232,108 @@ def parse_dependencies_from_task_function(
223232 return dependencies
224233
225234
235+ _ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
236+ "The task uses multiple ways to define products. Products should be defined with "
237+ "either\n \n - 'typing.Annotated[Path(...), Product]' (recommended)\n "
238+ "- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n "
239+ "- as a default argument for 'produces': 'produces = Path(...)'\n "
240+ "- '@pytask.mark.produces(Path(...))' (deprecated).\n \n "
241+ "Read more about products in the documentation: https://tinyurl.com/yrezszr4."
242+ )
243+
244+
226245def parse_products_from_task_function (
227246 session : Session , path : Path , name : str , obj : Any
228247) -> dict [str , Any ]:
229- """Parse dependencies from task function."""
248+ """Parse products from task function.
249+
250+ Raises
251+ ------
252+ NodeNotCollectedError
253+ If multiple ways were used to specify products.
254+
255+ """
256+ has_produces_decorator = False
257+ has_task_decorator = False
258+ has_signature_default = False
259+ has_annotation = False
260+ out = {}
261+
262+ if has_mark (obj , "produces" ):
263+ has_produces_decorator = True
264+ nodes = parse_nodes (session , path , name , obj , produces )
265+ out = {"produces" : nodes }
266+
230267 task_kwargs = obj .pytask_meta .kwargs if hasattr (obj , "pytask_meta" ) else {}
231268 if "produces" in task_kwargs :
232- return tree_map (
269+ collected_products = tree_map (
233270 lambda x : _collect_product (session , path , name , x , is_string_allowed = True ),
234271 task_kwargs ["produces" ],
235272 )
273+ out = {"produces" : collected_products }
236274
237275 parameters = inspect .signature (obj ).parameters
238- if "produces" in parameters :
276+
277+ if not has_mark (obj , "task" ) and "produces" in parameters :
239278 parameter = parameters ["produces" ]
240279 if parameter .default is not parameter .empty :
280+ has_signature_default = True
241281 # Use _collect_new_node to not collect strings.
242- return tree_map (
282+ collected_products = tree_map (
243283 lambda x : _collect_product (
244284 session , path , name , x , is_string_allowed = False
245285 ),
246286 parameter .default ,
247287 )
248- return {}
288+ out = {"produces" : collected_products }
289+
290+ parameters_with_product_annot = _find_args_with_product_annotation (obj )
291+ if parameters_with_product_annot :
292+ has_annotation = True
293+ for parameter_name in parameters_with_product_annot :
294+ parameter = parameters [parameter_name ]
295+ if parameter .default is not parameter .empty :
296+ # Use _collect_new_node to not collect strings.
297+ collected_products = tree_map (
298+ lambda x : _collect_product (
299+ session , path , name , x , is_string_allowed = False
300+ ),
301+ parameter .default ,
302+ )
303+ out = {parameter_name : collected_products }
304+
305+ if (
306+ sum (
307+ (
308+ has_produces_decorator ,
309+ has_task_decorator ,
310+ has_signature_default ,
311+ has_annotation ,
312+ )
313+ )
314+ >= 2 # noqa: PLR2004
315+ ):
316+ raise NodeNotCollectedError (_ERROR_MULTIPLE_PRODUCT_DEFINITIONS )
317+
318+ return out
319+
320+
321+ def _find_args_with_product_annotation (func : Callable [..., Any ]) -> list [str ]:
322+ """Find args with product annotation."""
323+ annotations = get_annotations (func , eval_str = True )
324+ metas = {
325+ name : annotation .__metadata__
326+ for name , annotation in annotations .items ()
327+ if get_origin (annotation ) is Annotated
328+ }
329+
330+ args_with_product_annot = []
331+ for name , meta in metas .items ():
332+ has_product_annot = any (isinstance (i , ProductType ) for i in meta )
333+ if has_product_annot :
334+ args_with_product_annot .append (name )
335+
336+ return args_with_product_annot
249337
250338
251339def _collect_old_dependencies (
@@ -331,9 +419,10 @@ def _collect_product(
331419 # The parameter defaults only support Path objects.
332420 if not isinstance (node , Path ) and not is_string_allowed :
333421 raise ValueError (
334- "If you use 'produces' as an argument of a task, it can only accept values "
335- "of type 'pathlib.Path' or the same value nested in "
336- f"tuples, lists, and dictionaries. Here, { node } has type { type (node )} ."
422+ "If you use 'produces' as a function argument of a task and pass values as "
423+ "function defaults, it can only accept values of type 'pathlib.Path' or "
424+ "the same value nested in tuples, lists, and dictionaries. Here, "
425+ f"{ node !r} has type { type (node )} ."
337426 )
338427
339428 if isinstance (node , str ):
0 commit comments