Skip to content

Commit cd96dc2

Browse files
add find_cycles and use it
1 parent e9d1f57 commit cd96dc2

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

pytools/graph.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
.. autoexception:: CycleError
4242
.. autofunction:: compute_topological_order
4343
.. autofunction:: compute_transitive_closure
44+
.. autofunction:: find_cycles
4445
.. autofunction:: contains_cycle
4546
.. autofunction:: compute_induced_subgraph
4647
.. autofunction:: validate_graph
@@ -240,6 +241,42 @@ def __init__(self, node: NodeT) -> None:
240241
self.node = node
241242

242243

244+
def find_cycles(graph: GraphT) -> List[List[NodeT]]:
245+
"""
246+
Find all cycles in *graph* using DFS.
247+
248+
:returns: A :class:`list` in which each element represents another :class:`list`
249+
of nodes that form a cycle.
250+
"""
251+
def DFS(node: NodeT, path: List[NodeT]) -> List[NodeT]:
252+
# Cycle detected
253+
if visited[node] == 1:
254+
return path
255+
256+
# Visit this node, explore its children
257+
visited[node] = 1
258+
path.append(node)
259+
for child in graph[node]:
260+
if visited[child] != 2 and DFS(child, path):
261+
return path
262+
263+
# Done visiting node
264+
visited[node] = 2
265+
return []
266+
267+
visited = {node: 0 for node in graph.keys()}
268+
269+
res = []
270+
271+
for node in graph:
272+
if not visited[node]:
273+
cycle = DFS(node, [])
274+
if cycle:
275+
res.append(cycle)
276+
277+
return res
278+
279+
243280
class HeapEntry:
244281
"""
245282
Helper class to compare associated keys while comparing the elements in
@@ -257,8 +294,8 @@ def __lt__(self, other: "HeapEntry") -> bool:
257294

258295

259296
def compute_topological_order(graph: GraphT,
260-
key: Optional[Callable[[T], Any]] = None,
261-
verbose_cycle: bool = True) -> List[T]:
297+
key: Optional[Callable[[NodeT], Any]] = None,
298+
verbose_cycle: bool = True) -> List[NodeT]:
262299
"""Compute a topological order of nodes in a directed graph.
263300
264301
:arg key: A custom key function may be supplied to determine the order in
@@ -323,24 +360,12 @@ def compute_topological_order(graph: GraphT,
323360
raise CycleError(None)
324361

325362
try:
326-
validate_graph(graph)
327-
except ValueError:
328-
# Graph is invalid, we can't compute SCCs or return a meaningful node
329-
# that is part of a cycle
363+
cycles = find_cycles(graph)
364+
except KeyError:
365+
# Graph is invalid
330366
raise CycleError(None)
331-
332-
sccs = compute_sccs(graph)
333-
cycles = [scc for scc in sccs if len(scc) > 1]
334-
335-
if cycles:
336-
# Cycles that are not self-loops
337-
node = cycles[0][0]
338367
else:
339-
# Self-loop SCCs also have a length of 1
340-
node = next(iter(n for n, num_preds in
341-
nodes_to_num_predecessors.items() if num_preds != 0))
342-
343-
raise CycleError(node)
368+
raise CycleError(cycles[0][0])
344369

345370
return order
346371

test/test_graph_tools.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,24 +395,33 @@ def test_is_connected():
395395
assert is_connected({})
396396

397397

398-
def test_cycle_detection():
399-
from pytools.graph import compute_topological_order, CycleError
398+
def test_find_cycles():
399+
from pytools.graph import compute_topological_order, CycleError, find_cycles
400400

401401
# Non-Self Loop
402402
graph = {1: {}, 5: {1, 8}, 8: {5}}
403+
assert find_cycles(graph) == [[5, 8]]
403404
with pytest.raises(CycleError, match="5|8"):
404405
compute_topological_order(graph)
405406

406407
# Self-Loop
407408
graph = {1: {1}, 5: {8}, 8: {}}
409+
assert find_cycles(graph) == [[1]]
408410
with pytest.raises(CycleError, match="1"):
409411
compute_topological_order(graph)
410412

411413
# Invalid graph with loop
412414
graph = {1: {42}, 5: {8}, 8: {5}}
415+
# Can't run find_cycles on this graph since it is invalid
413416
with pytest.raises(CycleError, match="None"):
414417
compute_topological_order(graph)
415418

419+
# Multiple loops
420+
graph = {1: {1}, 5: {8}, 8: {5}}
421+
assert find_cycles(graph) == [[1], [5, 8]]
422+
with pytest.raises(CycleError, match="1"):
423+
compute_topological_order(graph)
424+
416425

417426
if __name__ == "__main__":
418427
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)