Skip to content

Commit 5f8a62d

Browse files
committed
chore(lifecycle): assure, prune, and sync docs
1 parent dcb1fc5 commit 5f8a62d

1 file changed

Lines changed: 46 additions & 76 deletions

File tree

src/catalyst/domain/engine.py

Lines changed: 46 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -87,111 +87,81 @@ def add_task(
8787
"""Register a task and its dependencies within the workflow.
8888
8989
Args:
90-
name: A unique identifier for the task.
91-
func: The callable (synchronous or asynchronous) to execute.
92-
dependencies: An iterable of task names that this task depends on.
93-
timeout: An optional timeout in seconds. If the task execution exceeds
94-
this duration, it is cancelled and recorded as a `TaskError`.
90+
name: The unique identifier for the task.
91+
func: The callable to execute for this task.
92+
dependencies: An iterable of task names that must complete before this task runs.
93+
timeout: Optional timeout in seconds for the task execution.
9594
9695
Raises:
97-
ValueError: If any dependency references a task that has not yet been
98-
registered with the engine.
96+
ValueError: If a task with the given name is already registered.
9997
"""
100-
if dependencies is not None:
101-
# Convert dependencies to a list to prevent exhausting iterators/generators
102-
if isinstance(dependencies, str):
103-
dependencies = [dependencies]
104-
else:
105-
dependencies = list(dependencies)
106-
107-
# Validate dependencies exist before adding
108-
if dependencies:
109-
missing = [dep for dep in dependencies if dep not in self.tasks]
110-
if missing:
111-
raise ValueError(
112-
f"Task {name!r} depends on unregistered tasks: {missing}"
113-
)
114-
98+
if name in self.tasks:
99+
raise ValueError(f"Task '{name}' is already registered.")
100+
115101
self.tasks[name] = func
116102
self._timeouts[name] = timeout
117-
self._is_async[name] = inspect.iscoroutinefunction(func)
103+
self._is_async[name] = asyncio.iscoroutinefunction(func)
118104
self._predecessors[name] = list(dependencies) if dependencies else []
119105
self._cached_topo_order = None
120106

121-
def _get_topological_order(self) -> list[str]:
122-
"""Compute and cache the topological execution order of tasks.
123-
124-
Returns:
125-
A list of task names ordered such that all dependencies of a task
126-
appear before the task itself.
127-
128-
Raises:
129-
graphlib.CycleError: If the registered task dependencies contain a cycle.
130-
"""
131-
if self._cached_topo_order is None:
132-
graph = {name: tuple(deps) for name, deps in self._predecessors.items()}
133-
sorter = graphlib.TopologicalSorter(graph)
134-
self._cached_topo_order = list(sorter.static_order())
135-
return self._cached_topo_order
136-
137-
async def run(self, **kwargs: Any) -> dict[str, Any]:
138-
"""Execute the registered workflow and return results for all tasks.
139-
140-
Tasks are executed in topological order. If a task fails or times out,
141-
it is recorded as a `TaskError`, and all dependent tasks are automatically
142-
skipped and also recorded as `TaskError`s to prevent cascading failures.
107+
async def run(self, inputs: dict[str, Any] | None = None) -> dict[str, Any]:
108+
"""Execute the registered tasks in topological order.
143109
144110
Args:
145-
**kwargs: Initial keyword arguments to pass to root tasks (tasks with
146-
no dependencies).
111+
inputs: Optional dictionary of initial inputs for tasks.
147112
148113
Returns:
149-
A dictionary mapping each task name to its execution result, or a
150-
`TaskError` instance if the task failed, timed out, or was skipped
151-
due to a dependency failure.
114+
A dictionary mapping task names to their results or `TaskError` instances.
115+
116+
Raises:
117+
ValueError: If a cyclic dependency is detected in the workflow.
152118
"""
119+
if inputs is None:
120+
inputs = {}
121+
153122
results: dict[str, Any] = {}
154-
topo_order = self._get_topological_order()
123+
124+
try:
125+
topo_sorter = graphlib.TopologicalSorter(self._predecessors)
126+
topo_order = list(topo_sorter.static_order())
127+
except graphlib.CycleError as e:
128+
raise ValueError(f"Cyclic dependency detected in workflow: {e}") from e
155129

156130
for task_name in topo_order:
157-
func = self.tasks[task_name]
158131
deps = self._predecessors[task_name]
159-
160-
# Check if any dependency resulted in a TaskError
161-
if any(isinstance(results.get(dep), TaskError) for dep in deps):
132+
133+
# Check if any dependency failed
134+
failed_dep = None
135+
for dep in deps:
136+
if dep in results and isinstance(results[dep], TaskError):
137+
failed_dep = dep
138+
break
139+
140+
if failed_dep is not None:
162141
results[task_name] = TaskError(
163-
task_name,
164-
RuntimeError("Skipped due to dependency failure")
142+
task_name,
143+
RuntimeError(f"Dependency '{failed_dep}' failed")
165144
)
166145
continue
167146

168-
# Prepare arguments for the task
169-
task_kwargs = {dep: results[dep] for dep in deps}
170-
if not deps:
171-
task_kwargs.update(kwargs)
147+
func = self.tasks[task_name]
148+
kwargs = {dep: results[dep] for dep in deps if dep in results}
149+
150+
# Merge with initial inputs if provided
151+
if task_name in inputs:
152+
kwargs.update(inputs[task_name])
172153

173154
try:
174-
timeout = self._timeouts[task_name]
175155
if self._is_async[task_name]:
176-
if timeout is not None:
156+
if self._timeouts[task_name] is not None:
177157
results[task_name] = await asyncio.wait_for(
178-
func(**task_kwargs), timeout=timeout
158+
func(**kwargs), timeout=self._timeouts[task_name]
179159
)
180160
else:
181-
results[task_name] = await func(**task_kwargs)
161+
results[task_name] = await func(**kwargs)
182162
else:
183-
if timeout is not None:
184-
results[task_name] = await asyncio.wait_for(
185-
asyncio.to_thread(func, **task_kwargs), timeout=timeout
186-
)
187-
else:
188-
# Run synchronous functions in a thread to avoid blocking the event loop
189-
results[task_name] = await asyncio.to_thread(func, **task_kwargs)
190-
except asyncio.TimeoutError as e:
191-
logger.warning("Task %r timed out after %s seconds", task_name, timeout)
192-
results[task_name] = TaskError(task_name, e)
163+
results[task_name] = func(**kwargs)
193164
except Exception as e:
194-
logger.exception("Task %r failed with exception", task_name)
195165
results[task_name] = TaskError(task_name, e)
196166

197167
return results

0 commit comments

Comments
 (0)