Skip to content

Commit a4284cf

Browse files
committed
Implements a routine for topologically sorting a DAG with dynamic keying function
1 parent dc431a4 commit a4284cf

File tree

1 file changed

+78
-10
lines changed

1 file changed

+78
-10
lines changed

pytools/graph.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
.. autofunction:: compute_sccs
3434
.. autoclass:: CycleError
3535
.. autofunction:: compute_topological_order
36+
.. autofunction:: compute_topological_order_with_dynamic_key
3637
.. autofunction:: compute_transitive_closure
3738
.. autofunction:: contains_cycle
3839
.. autofunction:: compute_induced_subgraph
40+
.. autoclass:: TopologicalOrderState
3941
4042
Type Variables Used
4143
-------------------
@@ -46,7 +48,8 @@
4648
"""
4749

4850
from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable,
49-
Set, MutableSet, Dict, Iterator, Tuple)
51+
Set, MutableSet, Dict, Iterator, Tuple, Generic)
52+
from dataclasses import dataclass
5053

5154

5255
T = TypeVar("T")
@@ -207,18 +210,43 @@ def __lt__(self, other):
207210
return self.key < other.key
208211

209212

210-
def compute_topological_order(graph: Mapping[T, Iterable[T]],
211-
key: Optional[Callable[[T], Any]] = None) -> List[T]:
212-
"""Compute a topological order of nodes in a directed graph.
213+
@dataclass(frozen=True)
214+
class TopologicalOrderState(Generic[T]):
215+
"""
216+
.. attribute:: scheduled_nodes
217+
A :class:`list` of nodes that have been scheduled.
218+
219+
.. warning::
220+
221+
- Mutable updates to :attr:`scheduled_nodes`
222+
results in an undefined behavior.
223+
"""
224+
scheduled_nodes: List[T]
225+
226+
227+
def compute_topological_order_with_dynamic_key(
228+
graph: Mapping[T, Iterable[T]],
229+
trigger_key_update: Callable[[TopologicalOrderState[T]], bool],
230+
get_key: Callable[[TopologicalOrderState[T]], Callable[[T], Any]]
231+
) -> List[T]:
232+
"""
233+
Computes a topological order of nodes in a directed graph with support for
234+
a dynamic keying function.
213235
214236
:arg graph: A :class:`collections.abc.Mapping` representing a directed
215237
graph. The dictionary contains one key representing each node in the
216238
graph, and this key maps to a :class:`collections.abc.Iterable` of its
217239
successor nodes.
218240
219-
:arg key: A custom key function may be supplied to determine the order in
220-
break-even cases. Expects a function of one argument that is used to
221-
extract a comparison key from each node of the *graph*.
241+
:arg trigger_key_update: A function called after scheduling a node in
242+
*graph* that takes in an instance of :class:`TopologicalOrderState`
243+
corresponding to the scheduling state at that point and returns whether
244+
the comparison keys corresponding to the nodes be updated.
245+
246+
:arg get_key: A callable called when *trigger_key_update*
247+
returns *True*. Takes in an instance of :class:`TopologicalOrderState`
248+
and returns another callable that accepts node as an argument and returns the
249+
comparison key corresponding to the node.
222250
223251
:returns: A :class:`list` representing a valid topological ordering of the
224252
nodes in the directed graph.
@@ -228,10 +256,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
228256
* Requires the keys of the mapping *graph* to be hashable.
229257
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.
230258
231-
.. versionadded:: 2020.2
259+
.. versionadded:: 2022.2
232260
"""
233-
# all nodes have the same keys when not provided
234-
keyfunc = key if key is not None else (lambda x: 0)
235261

236262
from heapq import heapify, heappop, heappush
237263

@@ -248,6 +274,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
248274

249275
# }}}
250276

277+
keyfunc = get_key(TopologicalOrderState(scheduled_nodes=[]))
278+
251279
total_num_nodes = len(nodes_to_num_predecessors)
252280

253281
# heap: list of instances of HeapEntry(n) where 'n' is a node in
@@ -263,6 +291,14 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
263291
node_to_be_scheduled = heappop(heap).node
264292
order.append(node_to_be_scheduled)
265293

294+
state = TopologicalOrderState(scheduled_nodes=order)
295+
296+
if trigger_key_update(state):
297+
keyfunc = get_key(state)
298+
heap = [HeapEntry(entry.node, keyfunc(entry.node))
299+
for entry in heap]
300+
heapify(heap)
301+
266302
# discard 'node_to_be_scheduled' from the predecessors of its
267303
# successors since it's been scheduled
268304
for child in graph.get(node_to_be_scheduled, ()):
@@ -277,6 +313,38 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
277313

278314
return order
279315

316+
317+
def compute_topological_order(graph: Mapping[T, Iterable[T]],
318+
key: Optional[Callable[[T], Any]] = None) -> List[T]:
319+
"""Compute a topological order of nodes in a directed graph.
320+
321+
:arg graph: A :class:`collections.abc.Mapping` representing a directed
322+
graph. The dictionary contains one key representing each node in the
323+
graph, and this key maps to a :class:`collections.abc.Iterable` of its
324+
successor nodes.
325+
326+
:arg key: A custom key function may be supplied to determine the order in
327+
break-even cases. Expects a function of one argument that is used to
328+
extract a comparison key from each node of the *graph*.
329+
330+
:returns: A :class:`list` representing a valid topological ordering of the
331+
nodes in the directed graph.
332+
333+
.. note::
334+
335+
* Requires the keys of the mapping *graph* to be hashable.
336+
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.
337+
338+
.. versionadded:: 2020.2
339+
"""
340+
# all nodes have the same keys when not provided
341+
keyfunc = key if key is not None else (lambda x: 0)
342+
343+
return compute_topological_order_with_dynamic_key(
344+
graph,
345+
trigger_key_update=lambda _: False,
346+
get_key=lambda _: keyfunc)
347+
280348
# }}}
281349

282350

0 commit comments

Comments
 (0)