|
1 | 1 | """Implement some capabilities to deal with the DAG.""" |
2 | 2 | import itertools |
| 3 | +import pprint |
| 4 | +from typing import Dict |
| 5 | +from typing import Generator |
| 6 | +from typing import Iterable |
| 7 | +from typing import List |
3 | 8 |
|
| 9 | +import attr |
4 | 10 | import networkx as nx |
| 11 | +from _pytask.mark import get_specific_markers_from_task |
| 12 | +from _pytask.nodes import MetaTask |
5 | 13 |
|
6 | 14 |
|
7 | | -def sort_tasks_topologically(dag): |
8 | | - """Sort tasks in topological ordering.""" |
9 | | - for node in nx.topological_sort(dag): |
10 | | - if "task" in dag.nodes[node]: |
11 | | - yield node |
12 | | - |
13 | | - |
def descending_tasks(task_name: str, dag: nx.DiGraph) -> Generator[str, None, None]:
    """Yield the names of all tasks which descend from ``task_name``.

    All descendants of ``task_name`` in ``dag`` are inspected, but only nodes whose
    attribute dictionary contains a ``"task"`` entry are yielded. Other nodes are
    skipped.

    """
    yield from (
        name for name in nx.descendants(dag, task_name) if "task" in dag.nodes[name]
    )
19 | 20 |
|
20 | 21 |
|
def task_and_descending_tasks(
    task_name: str, dag: nx.DiGraph
) -> Generator[str, None, None]:
    """Yield ``task_name`` itself followed by all of its descending tasks."""
    yield from itertools.chain([task_name], descending_tasks(task_name, dag))
25 | 28 |
|
26 | 29 |
|
def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
    """Yield node and neighbors which are first degree predecessors and successors.

    We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a
    DAG.

    The result chains the node itself, its predecessors, and its successors, in that
    order. It is an :class:`itertools.chain` object and not a generator, which is why
    the return value is annotated as a plain iterable of node names.

    """
    return itertools.chain([node], dag.predecessors(node), dag.successors(node))
| 38 | + |
| 39 | + |
@attr.s
class TopologicalSorter:
    """The topological sorter class.

    This class allows performing a topological sort of tasks which also respects
    task priorities. The intended usage is to call :meth:`prepare` once, fetch
    tasks with :meth:`get_ready`, and report them with :meth:`done` until
    :meth:`is_active` returns ``False``. :meth:`static_order` wraps this loop.

    """

    # Graph of the tasks which still need to be processed. Finished nodes are removed.
    dag = attr.ib(converter=nx.DiGraph)
    # Mapping from node name to a numeric priority. Missing names count as 0.
    priorities = attr.ib(factory=dict)
    # Untouched copy of ``dag`` which is restored by ``reset``.
    _dag_backup = attr.ib(default=None)
    # Whether ``prepare`` ran successfully since construction or the last reset.
    _is_prepared = attr.ib(default=False, type=bool)
    # Nodes returned by ``get_ready`` which have not been marked as done yet.
    _nodes_out = attr.ib(factory=set)

    def __attrs_post_init__(self) -> None:
        """Snapshot the graph if no backup was passed so ``reset`` always works."""
        if self._dag_backup is None:
            self._dag_backup = self.dag.copy()

    @classmethod
    def from_dag(cls, dag: nx.DiGraph) -> "TopologicalSorter":
        """Create a sorter from a DAG whose task nodes carry a ``"task"`` attribute.

        Only nodes with a ``"task"`` entry are kept. The resulting graph has an edge
        from every task to each of its descending tasks.

        Raises
        ------
        ValueError
            If ``dag`` is not a directed graph.

        """
        if not dag.is_directed():
            raise ValueError("Only directed graphs have a topological order.")

        tasks = [
            dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
        ]
        priorities = _extract_priorities_from_tasks(tasks)

        task_names = {task.name for task in tasks}
        # Maps each task to the tasks among its ancestors, i.e. edges point from a
        # task to what it depends on. Reversing yields the execution direction.
        task_dict = {name: nx.ancestors(dag, name) & task_names for name in task_names}
        task_dag = nx.DiGraph(task_dict).reverse()

        return cls(task_dag, priorities, task_dag.copy())

    def prepare(self) -> None:
        """Perform some checks before creating a topological ordering.

        Raises
        ------
        ValueError
            If the graph contains a cycle.

        """
        try:
            nx.algorithms.cycles.find_cycle(self.dag)
        except nx.NetworkXNoCycle:
            # ``find_cycle`` signals the absence of cycles with this exception.
            pass
        else:
            raise ValueError("The DAG contains cycles.")

        self._is_prepared = True

    def get_ready(self, n: int = 1) -> List[str]:
        """Get up to ``n`` tasks which are ready.

        A task is ready if it has an in-degree of zero in the remaining graph and has
        not been handed out before. The returned tasks are sorted by ascending
        priority and are remembered as handed out. The order among tasks with the
        same priority is not specified because candidates are collected in a set.

        Raises
        ------
        ValueError
            If the sorter has not been prepared or if ``n`` is not an integer
            greater or equal than 1.

        """
        if not self._is_prepared:
            raise ValueError("The TopologicalSorter needs to be prepared.")
        if not isinstance(n, int) or n < 1:
            raise ValueError("'n' must be an integer greater or equal than 1.")

        ready_nodes = {v for v, d in self.dag.in_degree() if d == 0} - self._nodes_out
        # Highest priorities sort to the end, so the slice keeps the top ``n`` nodes.
        prioritized_nodes = sorted(
            ready_nodes, key=lambda x: self.priorities.get(x, 0)
        )[-n:]

        self._nodes_out.update(prioritized_nodes)

        return prioritized_nodes

    def is_active(self) -> bool:
        """Indicate whether there are still tasks left."""
        return bool(self.dag.nodes)

    def done(self, *nodes: str) -> None:
        """Mark some tasks as done.

        The nodes are dropped from the set of handed-out tasks and removed from the
        graph.

        """
        self._nodes_out = self._nodes_out - set(nodes)
        self.dag.remove_nodes_from(nodes)

    def reset(self) -> None:
        """Reset an exhausted topological sorter.

        The graph is restored from the backup, and the sorter needs to be prepared
        again before tasks can be requested.

        """
        self.dag = self._dag_backup.copy()
        self._is_prepared = False
        self._nodes_out = set()

    def static_order(self) -> Generator[str, None, None]:
        """Return a topological order of tasks as an iterable.

        The sorter is prepared first. Then, one ready task at a time is yielded and
        marked as done until no tasks are left.

        """
        self.prepare()
        while self.is_active():
            new_task = self.get_ready()[0]
            yield new_task
            self.done(new_task)
| 119 | + |
| 120 | + |
def _extract_priorities_from_tasks(tasks: List[MetaTask]) -> Dict[str, int]:
    """Translate the priority markers of tasks into numeric priorities.

    A task marked with ``pytask.mark.try_first`` receives ``1``, a task marked with
    ``pytask.mark.try_last`` receives ``-1``, and a task without either marker
    receives ``0``. Tasks to try first thus end up rightmost when sorting by
    priority, which allows consuming them cheaply from the end of a list.

    Raises
    ------
    ValueError
        If a task carries both markers at the same time.

    """
    flags_by_name = {}
    for task in tasks:
        name = task.name
        has_first = bool(get_specific_markers_from_task(task, "try_first"))
        has_last = bool(get_specific_markers_from_task(task, "try_last"))
        flags_by_name[name] = (has_first, has_last)

    conflicting = [name for name, flags in flags_by_name.items() if all(flags)]
    if conflicting:
        raise ValueError(
            "'try_first' and 'try_last' cannot be applied on the same task. See the "
            f"following tasks for errors:\n\n{pprint.pformat(conflicting)}"
        )

    # After the check above at most one flag is set, so the difference is 1, 0 or -1.
    return {
        name: int(has_first) - int(has_last)
        for name, (has_first, has_last) in flags_by_name.items()
    }
0 commit comments