Skip to content

Commit 83d462e

Browse files
committed
Implements a routine for topologically sorting a DAG with dynamic keying function
1 parent e7e4766 commit 83d462e

1 file changed

Lines changed: 79 additions & 10 deletions

File tree

pytools/graph.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
.. autofunction:: compute_sccs
3434
.. autoclass:: CycleError
3535
.. autofunction:: compute_topological_order
36+
.. autofunction:: compute_topological_order_with_dynamic_key
3637
.. autofunction:: compute_transitive_closure
3738
.. autofunction:: contains_cycle
3839
.. autofunction:: compute_induced_subgraph
40+
.. autoclass:: TopologicalOrderState
3941
4042
Type Variables Used
4143
-------------------
@@ -46,7 +48,8 @@
4648
"""
4749

4850
from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable,
49-
Set, MutableSet, Dict, Iterator, Tuple)
51+
Set, MutableSet, Dict, Iterator, Tuple, Generic)
52+
from dataclasses import dataclass
5053

5154

5255
T = TypeVar("T")
@@ -207,18 +210,44 @@ def __lt__(self, other):
207210
return self.key < other.key
208211

209212

210-
def compute_topological_order(graph: Mapping[T, Iterable[T]],
211-
key: Optional[Callable[[T], Any]] = None) -> List[T]:
212-
"""Compute a topological order of nodes in a directed graph.
213+
@dataclass(frozen=True)
214+
class TopologicalOrderState(Generic[T]):
215+
"""
216+
.. attribute:: scheduled_nodes
217+
218+
A :class:`list` of nodes that have been scheduled.
219+
220+
.. warning::
221+
222+
- Mutable updates to :attr:`scheduled_nodes`
223+
results in an undefined behavior.
224+
"""
225+
scheduled_nodes: List[T]
226+
227+
228+
def compute_topological_order_with_dynamic_key(
229+
graph: Mapping[T, Iterable[T]],
230+
trigger_key_update: Callable[[TopologicalOrderState[T]], bool],
231+
get_key: Callable[[TopologicalOrderState[T]], Callable[[T], Any]]
232+
) -> List[T]:
233+
"""
234+
Computes a topological order of nodes in a directed graph with support for
235+
a dynamic keying function.
213236
214237
:arg graph: A :class:`collections.abc.Mapping` representing a directed
215238
graph. The dictionary contains one key representing each node in the
216239
graph, and this key maps to a :class:`collections.abc.Iterable` of its
217240
successor nodes.
218241
219-
:arg key: A custom key function may be supplied to determine the order in
220-
break-even cases. Expects a function of one argument that is used to
221-
extract a comparison key from each node of the *graph*.
242+
:arg trigger_key_update: A function called after scheduling a node in
243+
*graph* that takes in an instance of :class:`TopologicalOrderState`
244+
corresponding to the scheduling state at that point and returns whether
245+
the comparison keys corresponding to the nodes be updated.
246+
247+
:arg get_key: A callable called when *trigger_key_update*
248+
returns *True*. Takes in an instance of :class:`TopologicalOrderState`
249+
and returns another callable that accepts node as an argument and returns the
250+
comparison key corresponding to the node.
222251
223252
:returns: A :class:`list` representing a valid topological ordering of the
224253
nodes in the directed graph.
@@ -228,10 +257,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
228257
* Requires the keys of the mapping *graph* to be hashable.
229258
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.
230259
231-
.. versionadded:: 2020.2
260+
.. versionadded:: 2022.2
232261
"""
233-
# all nodes have the same keys when not provided
234-
keyfunc = key if key is not None else (lambda x: 0)
235262

236263
from heapq import heapify, heappop, heappush
237264

@@ -248,6 +275,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
248275

249276
# }}}
250277

278+
keyfunc = get_key(TopologicalOrderState(scheduled_nodes=[]))
279+
251280
total_num_nodes = len(nodes_to_num_predecessors)
252281

253282
# heap: list of instances of HeapEntry(n) where 'n' is a node in
@@ -263,6 +292,14 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
263292
node_to_be_scheduled = heappop(heap).node
264293
order.append(node_to_be_scheduled)
265294

295+
state = TopologicalOrderState(scheduled_nodes=order)
296+
297+
if trigger_key_update(state):
298+
keyfunc = get_key(state)
299+
heap = [HeapEntry(entry.node, keyfunc(entry.node))
300+
for entry in heap]
301+
heapify(heap)
302+
266303
# discard 'node_to_be_scheduled' from the predecessors of its
267304
# successors since it's been scheduled
268305
for child in graph.get(node_to_be_scheduled, ()):
@@ -277,6 +314,38 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
277314

278315
return order
279316

317+
318+
def compute_topological_order(graph: Mapping[T, Iterable[T]],
319+
key: Optional[Callable[[T], Any]] = None) -> List[T]:
320+
"""Compute a topological order of nodes in a directed graph.
321+
322+
:arg graph: A :class:`collections.abc.Mapping` representing a directed
323+
graph. The dictionary contains one key representing each node in the
324+
graph, and this key maps to a :class:`collections.abc.Iterable` of its
325+
successor nodes.
326+
327+
:arg key: A custom key function may be supplied to determine the order in
328+
break-even cases. Expects a function of one argument that is used to
329+
extract a comparison key from each node of the *graph*.
330+
331+
:returns: A :class:`list` representing a valid topological ordering of the
332+
nodes in the directed graph.
333+
334+
.. note::
335+
336+
* Requires the keys of the mapping *graph* to be hashable.
337+
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.
338+
339+
.. versionadded:: 2020.2
340+
"""
341+
# all nodes have the same keys when not provided
342+
keyfunc = key if key is not None else (lambda x: 0)
343+
344+
return compute_topological_order_with_dynamic_key(
345+
graph,
346+
trigger_key_update=lambda _: False,
347+
get_key=lambda _: keyfunc)
348+
280349
# }}}
281350

282351

0 commit comments

Comments
 (0)