@@ -87,111 +87,81 @@ def add_task(
8787 """Register a task and its dependencies within the workflow.
8888
8989 Args:
90- name: A unique identifier for the task.
91- func: The callable (synchronous or asynchronous) to execute.
92- dependencies: An iterable of task names that this task depends on.
93- timeout: An optional timeout in seconds. If the task execution exceeds
94- this duration, it is cancelled and recorded as a `TaskError`.
90+ name: The unique identifier for the task.
91+ func: The callable to execute for this task.
92+ dependencies: An iterable of task names that must complete before this task runs.
93+ timeout: Optional timeout in seconds for the task execution.
9594
9695 Raises:
97- ValueError: If any dependency references a task that has not yet been
98- registered with the engine.
96+ ValueError: If a task with the given name is already registered.
9997 """
100- if dependencies is not None :
101- # Convert dependencies to a list to prevent exhausting iterators/generators
102- if isinstance (dependencies , str ):
103- dependencies = [dependencies ]
104- else :
105- dependencies = list (dependencies )
106-
107- # Validate dependencies exist before adding
108- if dependencies :
109- missing = [dep for dep in dependencies if dep not in self .tasks ]
110- if missing :
111- raise ValueError (
112- f"Task { name !r} depends on unregistered tasks: { missing } "
113- )
114-
98+ if name in self .tasks :
99+ raise ValueError (f"Task '{ name } ' is already registered." )
100+
115101 self .tasks [name ] = func
116102 self ._timeouts [name ] = timeout
117- self ._is_async [name ] = inspect .iscoroutinefunction (func )
103+ self ._is_async [name ] = asyncio .iscoroutinefunction (func )
118104 self ._predecessors [name ] = list (dependencies ) if dependencies else []
119105 self ._cached_topo_order = None
120106
121- def _get_topological_order (self ) -> list [str ]:
122- """Compute and cache the topological execution order of tasks.
123-
124- Returns:
125- A list of task names ordered such that all dependencies of a task
126- appear before the task itself.
127-
128- Raises:
129- graphlib.CycleError: If the registered task dependencies contain a cycle.
130- """
131- if self ._cached_topo_order is None :
132- graph = {name : tuple (deps ) for name , deps in self ._predecessors .items ()}
133- sorter = graphlib .TopologicalSorter (graph )
134- self ._cached_topo_order = list (sorter .static_order ())
135- return self ._cached_topo_order
136-
137- async def run (self , ** kwargs : Any ) -> dict [str , Any ]:
138- """Execute the registered workflow and return results for all tasks.
139-
140- Tasks are executed in topological order. If a task fails or times out,
141- it is recorded as a `TaskError`, and all dependent tasks are automatically
142- skipped and also recorded as `TaskError`s to prevent cascading failures.
107+ async def run (self , inputs : dict [str , Any ] | None = None ) -> dict [str , Any ]:
108+ """Execute the registered tasks in topological order.
143109
144110 Args:
145- **kwargs: Initial keyword arguments to pass to root tasks (tasks with
146- no dependencies).
111+ inputs: Optional dictionary of initial inputs for tasks.
147112
148113 Returns:
149- A dictionary mapping each task name to its execution result, or a
150- `TaskError` instance if the task failed, timed out, or was skipped
151- due to a dependency failure.
114+ A dictionary mapping task names to their results or `TaskError` instances.
115+
116+ Raises:
117+ ValueError: If a cyclic dependency is detected in the workflow.
152118 """
119+ if inputs is None :
120+ inputs = {}
121+
153122 results : dict [str , Any ] = {}
154- topo_order = self ._get_topological_order ()
123+
124+ try :
125+ topo_sorter = graphlib .TopologicalSorter (self ._predecessors )
126+ topo_order = list (topo_sorter .static_order ())
127+ except graphlib .CycleError as e :
128+ raise ValueError (f"Cyclic dependency detected in workflow: { e } " ) from e
155129
156130 for task_name in topo_order :
157- func = self .tasks [task_name ]
158131 deps = self ._predecessors [task_name ]
159-
160- # Check if any dependency resulted in a TaskError
161- if any (isinstance (results .get (dep ), TaskError ) for dep in deps ):
132+
133+ # Check if any dependency failed
134+ failed_dep = None
135+ for dep in deps :
136+ if dep in results and isinstance (results [dep ], TaskError ):
137+ failed_dep = dep
138+ break
139+
140+ if failed_dep is not None :
162141 results [task_name ] = TaskError (
163- task_name ,
164- RuntimeError ("Skipped due to dependency failure " )
142+ task_name ,
143+ RuntimeError (f"Dependency ' { failed_dep } ' failed " )
165144 )
166145 continue
167146
168- # Prepare arguments for the task
169- task_kwargs = {dep : results [dep ] for dep in deps }
170- if not deps :
171- task_kwargs .update (kwargs )
147+ func = self .tasks [task_name ]
148+ kwargs = {dep : results [dep ] for dep in deps if dep in results }
149+
150+ # Merge with initial inputs if provided
151+ if task_name in inputs :
152+ kwargs .update (inputs [task_name ])
172153
173154 try :
174- timeout = self ._timeouts [task_name ]
175155 if self ._is_async [task_name ]:
176- if timeout is not None :
156+ if self . _timeouts [ task_name ] is not None :
177157 results [task_name ] = await asyncio .wait_for (
178- func (** task_kwargs ), timeout = timeout
158+ func (** kwargs ), timeout = self . _timeouts [ task_name ]
179159 )
180160 else :
181- results [task_name ] = await func (** task_kwargs )
161+ results [task_name ] = await func (** kwargs )
182162 else :
183- if timeout is not None :
184- results [task_name ] = await asyncio .wait_for (
185- asyncio .to_thread (func , ** task_kwargs ), timeout = timeout
186- )
187- else :
188- # Run synchronous functions in a thread to avoid blocking the event loop
189- results [task_name ] = await asyncio .to_thread (func , ** task_kwargs )
190- except asyncio .TimeoutError as e :
191- logger .warning ("Task %r timed out after %s seconds" , task_name , timeout )
192- results [task_name ] = TaskError (task_name , e )
163+ results [task_name ] = func (** kwargs )
193164 except Exception as e :
194- logger .exception ("Task %r failed with exception" , task_name )
195165 results [task_name ] = TaskError (task_name , e )
196166
197167 return results
0 commit comments