3131__all__ = ["execute_action_dag" ]
3232
3333
34+ class _DagRun :
35+ """Mutable scheduling state shared by the submit / completion helpers."""
36+
37+ def __init__ (
38+ self ,
39+ nodes : list [Mapping [str , Any ]],
40+ pool : ThreadPoolExecutor ,
41+ fail_fast : bool ,
42+ ) -> None :
43+ self .graph , self .indegree = _build_graph (nodes )
44+ self .node_map = {_require_id (node ): node for node in nodes }
45+ self .results : dict [str , Any ] = {}
46+ self .lock = threading .Lock ()
47+ self .ready : deque [str ] = deque (
48+ node_id for node_id , count in self .indegree .items () if count == 0
49+ )
50+ self .in_flight : dict [Future [Any ], str ] = {}
51+ self .pool = pool
52+ self .fail_fast = fail_fast
53+
54+ def _mark_skipped (self , dependent : str , reason_id : str ) -> None :
55+ with self .lock :
56+ if dependent in self .results :
57+ return
58+ self .results [dependent ] = f"skipped: dep { reason_id !r} failed"
59+ for grandchild in self .graph .get (dependent , ()):
60+ self .indegree [grandchild ] -= 1
61+ self ._mark_skipped (grandchild , dependent )
62+
63+ def _skip_dependents (self , node_id : str ) -> None :
64+ for dependent in self .graph .get (node_id , ()):
65+ self .indegree [dependent ] -= 1
66+ self ._mark_skipped (dependent , node_id )
67+
68+ def submit (self , node_id : str ) -> None :
69+ action = self .node_map [node_id ].get ("action" )
70+ if not isinstance (action , list ):
71+ err = DagException (f"node { node_id !r} missing action list" )
72+ with self .lock :
73+ self .results [node_id ] = repr (err )
74+ if self .fail_fast :
75+ self ._skip_dependents (node_id )
76+ return
77+ future = self .pool .submit (_run_action , action )
78+ self .in_flight [future ] = node_id
79+
80+ def _complete (self , node_id : str , value : Any , failed : bool ) -> None :
81+ with self .lock :
82+ self .results [node_id ] = value
83+ for dependent in self .graph .get (node_id , ()):
84+ self .indegree [dependent ] -= 1
85+ if failed and self .fail_fast :
86+ self ._mark_skipped (dependent , node_id )
87+ elif self .indegree [dependent ] == 0 and dependent not in self .results :
88+ self .ready .append (dependent )
89+
90+ def drain_completed (self ) -> None :
91+ done , _ = wait (list (self .in_flight ), return_when = FIRST_COMPLETED )
92+ for future in done :
93+ node_id = self .in_flight .pop (future )
94+ try :
95+ value : Any = future .result ()
96+ failed = False
97+ except Exception as err : # pylint: disable=broad-except
98+ value = repr (err )
99+ failed = True
100+ self ._complete (node_id , value , failed )
101+
102+
34103def execute_action_dag (
35104 nodes : list [Mapping [str , Any ]],
36105 max_workers : int = 4 ,
@@ -46,54 +115,15 @@ def execute_action_dag(
46115 Raises :class:`DagException` for static errors detected before any action
47116 runs: duplicate ids, unknown dependencies, or cycles.
48117 """
49- graph , indegree = _build_graph (nodes )
50- node_map = {_require_id (node ): node for node in nodes }
51- results : dict [str , Any ] = {}
52- lock = threading .Lock ()
53-
54- ready : deque [str ] = deque (node_id for node_id , count in indegree .items () if count == 0 )
55-
56118 with ThreadPoolExecutor (max_workers = max_workers ) as pool :
57- in_flight : dict [Future [Any ], str ] = {}
58-
59- def submit (node_id : str ) -> None :
60- action = node_map [node_id ].get ("action" )
61- if not isinstance (action , list ):
62- err = DagException (f"node { node_id !r} missing action list" )
63- with lock :
64- results [node_id ] = repr (err )
65- if fail_fast :
66- for dependent in graph .get (node_id , ()):
67- indegree [dependent ] -= 1
68- _mark_skipped (dependent , node_id , graph , indegree , results , lock )
69- return
70- future = pool .submit (_run_action , action )
71- in_flight [future ] = node_id
72-
73- while ready or in_flight :
74- while ready :
75- submit (ready .popleft ())
76- if not in_flight :
119+ state = _DagRun (nodes , pool , fail_fast )
120+ while state .ready or state .in_flight :
121+ while state .ready :
122+ state .submit (state .ready .popleft ())
123+ if not state .in_flight :
77124 break
78- done , _ = wait (list (in_flight ), return_when = FIRST_COMPLETED )
79- for future in done :
80- node_id = in_flight .pop (future )
81- failed = False
82- try :
83- value : Any = future .result ()
84- except Exception as err : # pylint: disable=broad-except
85- value = repr (err )
86- failed = True
87- with lock :
88- results [node_id ] = value
89- for dependent in graph .get (node_id , ()):
90- indegree [dependent ] -= 1
91- if failed and fail_fast :
92- _mark_skipped (dependent , node_id , graph , indegree , results , lock )
93- elif indegree [dependent ] == 0 and dependent not in results :
94- ready .append (dependent )
95-
96- return results
125+ state .drain_completed ()
126+ return state .results
97127
98128
99129def _run_action (action : list ) -> Any :
@@ -156,20 +186,3 @@ def _detect_cycle(
156186 queue .append (dependent )
157187 if visited != len (ids ):
158188 raise DagException ("cycle detected in DAG" )
159-
160-
161- def _mark_skipped (
162- dependent : str ,
163- reason_id : str ,
164- graph : dict [str , list [str ]],
165- indegree : dict [str , int ],
166- results : dict [str , Any ],
167- lock : threading .Lock ,
168- ) -> None :
169- with lock :
170- if dependent in results :
171- return
172- results [dependent ] = f"skipped: dep { reason_id !r} failed"
173- for grandchild in graph .get (dependent , ()):
174- indegree [grandchild ] -= 1
175- _mark_skipped (grandchild , dependent , graph , indegree , results , lock )
0 commit comments