Skip to content

Commit 964fb8d

Browse files
committed
Split out union find into separate module
1 parent 6eee49e commit 964fb8d

4 files changed

Lines changed: 211 additions & 110 deletions

File tree

src/boruvkas_algorithm/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Boruvka's algorithm for finding minimum spanning trees."""
2+
3+
from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm
4+
from boruvkas_algorithm.union_find import UnionFind
5+
6+
__all__: list[str] = ["Graph", "UnionFind", "find_mst_with_boruvkas_algorithm"]

src/boruvkas_algorithm/boruvka.py

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import matplotlib.pyplot as plt
66
import networkx as nx
77

8+
from boruvkas_algorithm.union_find import UnionFind
9+
810

911
class Graph:
1012
"""A graph that contains nodes and edges."""
@@ -76,84 +78,6 @@ def draw_mst(self, mst_edges: list[tuple[int, int, int]]) -> None:
7678
plt.show()
7779

7880

79-
class UnionFind:
80-
"""
81-
Union-find (disjoint set union) data structure for tracking connected
82-
components with path compression and union by size.
83-
"""
84-
85-
def __init__(self, size: int) -> None:
86-
"""
87-
Initialises the Union-Find structure.
88-
89-
Args:
90-
size: The number of elements in the structure.
91-
"""
92-
# Each node is its own parent initially.
93-
self.parent: list[int] = list(range(size))
94-
# Each tree has size 1 (itself) initially.
95-
self.rank: list[int] = [1] * size
96-
97-
def find(self, node: int) -> int:
98-
"""
99-
Finds the root parent of the node using path compression.
100-
101-
Args:
102-
node: The node to find the root parent of.
103-
104-
Returns:
105-
The root parent of the node.
106-
"""
107-
cur_parent = self.parent[node]
108-
while cur_parent != self.parent[cur_parent]:
109-
# Compress the links as we go up the chain of parents to make
110-
# it faster to traverse in the future - amortised O(a(n)) time,
111-
# where a(n) is the inverse Ackermann function.
112-
self.parent[cur_parent] = self.parent[self.parent[cur_parent]]
113-
cur_parent = self.parent[cur_parent]
114-
return cur_parent
115-
116-
def union(self, node1: int, node2: int) -> bool:
117-
"""
118-
Combines the two nodes into the larger segment.
119-
120-
Args:
121-
node1: The first node to combine.
122-
node2: The second node to combine.
123-
124-
Returns:
125-
True if the nodes were combined, False if they were already in the
126-
same segment.
127-
"""
128-
root1 = self.find(node1)
129-
root2 = self.find(node2)
130-
# If they have the same root parent, they're already connected.
131-
if root1 == root2:
132-
return False
133-
134-
# Combine the two nodes into the larger segment based on the rank.
135-
if self.rank[root1] > self.rank[root2]:
136-
self.parent[root2] = root1
137-
self.rank[root1] += self.rank[root2]
138-
else:
139-
self.parent[root1] = root2
140-
self.rank[root2] += self.rank[root1]
141-
return True
142-
143-
def is_connected(self, node1: int, node2: int) -> bool:
144-
"""
145-
Checks if two nodes are in the same component.
146-
147-
Args:
148-
node1: The first node.
149-
node2: The second node.
150-
151-
Returns:
152-
True if the nodes are connected, False otherwise.
153-
"""
154-
return self.find(node1) == self.find(node2)
155-
156-
15781
def find_mst_with_boruvkas_algorithm(
15882
graph: Graph,
15983
union_find: UnionFind | None = None,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
class UnionFind:
2+
"""
3+
Union-find (disjoint set union) data structure for tracking connected
4+
components with path compression and union by size.
5+
"""
6+
7+
def __init__(self, size: int) -> None:
8+
"""
9+
Initialises the Union-Find structure.
10+
11+
Args:
12+
size: The number of elements in the structure.
13+
"""
14+
# Each node is its own parent initially.
15+
self.parent: list[int] = list(range(size))
16+
# Each tree has size 1 (itself) initially.
17+
self.rank: list[int] = [1] * size
18+
19+
def find(self, node: int) -> int:
20+
"""
21+
Finds the root parent of the node using path compression.
22+
23+
Args:
24+
node: The node to find the root parent of.
25+
26+
Returns:
27+
The root parent of the node.
28+
"""
29+
cur_parent = self.parent[node]
30+
while cur_parent != self.parent[cur_parent]:
31+
# Compress the links as we go up the chain of parents to make
32+
# it faster to traverse in the future - amortised O(a(n)) time,
33+
# where a(n) is the inverse Ackermann function.
34+
self.parent[cur_parent] = self.parent[self.parent[cur_parent]]
35+
cur_parent = self.parent[cur_parent]
36+
return cur_parent
37+
38+
def union(self, node1: int, node2: int) -> bool:
39+
"""
40+
Combines the two nodes into the larger segment.
41+
42+
Args:
43+
node1: The first node to combine.
44+
node2: The second node to combine.
45+
46+
Returns:
47+
True if the nodes were combined, False if they were already in the
48+
same segment.
49+
"""
50+
root1 = self.find(node1)
51+
root2 = self.find(node2)
52+
# If they have the same root parent, they're already connected.
53+
if root1 == root2:
54+
return False
55+
56+
# Combine the two nodes into the larger segment based on the rank.
57+
if self.rank[root1] > self.rank[root2]:
58+
self.parent[root2] = root1
59+
self.rank[root1] += self.rank[root2]
60+
else:
61+
self.parent[root1] = root2
62+
self.rank[root2] += self.rank[root1]
63+
return True
64+
65+
def is_connected(self, node1: int, node2: int) -> bool:
66+
"""
67+
Checks if two nodes are in the same component.
68+
69+
Args:
70+
node1: The first node.
71+
node2: The second node.
72+
73+
Returns:
74+
True if the nodes are connected, False otherwise.
75+
"""
76+
return self.find(node1) == self.find(node2)

tests/test_boruvka.py

Lines changed: 127 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import pytest
22

3-
from boruvkas_algorithm.boruvka import (
4-
Graph,
5-
UnionFind,
6-
find_mst_with_boruvkas_algorithm,
7-
)
3+
from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm
4+
from boruvkas_algorithm.union_find import UnionFind
85

96

107
@pytest.fixture
@@ -52,6 +49,131 @@ def test_add_edge_invalid_vertices(setup_graph: Graph):
5249
graph.add_edge(10, 11, 5)
5350

5451

52+
# =============================================================================
53+
# UnionFind Tests
54+
# =============================================================================
55+
56+
57+
def test_union_find_initialization():
58+
"""Tests that UnionFind initialises with correct parent and rank arrays."""
59+
uf = UnionFind(5)
60+
assert uf.parent == [0, 1, 2, 3, 4], "Each node should be its own parent"
61+
assert uf.rank == [1, 1, 1, 1, 1], "Each node should have rank 1"
62+
63+
64+
def test_union_find_find_single_node():
65+
"""Tests that find returns the node itself when it's its own parent."""
66+
uf = UnionFind(5)
67+
assert uf.find(0) == 0
68+
assert uf.find(4) == 4
69+
70+
71+
def test_union_find_union_two_nodes():
72+
"""Tests that union correctly combines two nodes."""
73+
uf = UnionFind(5)
74+
result = uf.union(0, 1)
75+
assert result is True, "Union should return True when nodes are combined"
76+
assert uf.find(0) == uf.find(1), "Nodes should have the same root after union"
77+
78+
79+
def test_union_find_union_already_connected():
80+
"""Tests that union returns False when nodes are already connected."""
81+
uf = UnionFind(5)
82+
uf.union(0, 1)
83+
result = uf.union(0, 1)
84+
assert result is False, "Union should return False when already connected"
85+
86+
87+
def test_union_find_union_by_size():
88+
"""Tests that smaller trees are merged into larger trees."""
89+
uf = UnionFind(5)
90+
# Create a larger tree: 0 <- 1, 0 <- 2
91+
uf.union(0, 1)
92+
uf.union(0, 2)
93+
# Now union with node 3 - node 3 should be merged into the larger tree.
94+
uf.union(3, 0)
95+
# The root of the larger tree should remain the root.
96+
root = uf.find(0)
97+
assert uf.find(3) == root, "Smaller tree should be merged into larger tree"
98+
99+
100+
def test_union_find_path_compression():
101+
"""Tests that path compression flattens the tree structure."""
102+
uf = UnionFind(5)
103+
# Create a chain: 0 <- 1 <- 2 <- 3
104+
uf.parent = [0, 0, 1, 2, 4]
105+
uf.rank = [4, 1, 1, 1, 1]
106+
# Find on node 3 should compress the path.
107+
root = uf.find(3)
108+
assert root == 0, "Root should be 0"
109+
# After path compression, intermediate nodes should point closer to root.
110+
assert uf.parent[2] in (0, 1), "Path compression should shorten the path"
111+
112+
113+
def test_union_find_multiple_components():
114+
"""Tests UnionFind with multiple separate components."""
115+
uf = UnionFind(6)
116+
# Create two components: {0, 1, 2} and {3, 4, 5}
117+
uf.union(0, 1)
118+
uf.union(1, 2)
119+
uf.union(3, 4)
120+
uf.union(4, 5)
121+
122+
# Check components are separate.
123+
assert uf.find(0) == uf.find(1) == uf.find(2)
124+
assert uf.find(3) == uf.find(4) == uf.find(5)
125+
assert uf.find(0) != uf.find(3), "Components should be separate"
126+
127+
# Merge the two components.
128+
uf.union(2, 3)
129+
assert uf.find(0) == uf.find(5), "Components should be merged"
130+
131+
132+
def test_union_find_is_connected():
133+
"""Tests the is_connected convenience method."""
134+
uf = UnionFind(5)
135+
assert not uf.is_connected(0, 1), "Nodes should not be connected initially"
136+
137+
uf.union(0, 1)
138+
assert uf.is_connected(0, 1), "Nodes should be connected after union"
139+
assert not uf.is_connected(0, 2), "Unconnected nodes should return False"
140+
141+
uf.union(1, 2)
142+
assert uf.is_connected(0, 2), "Transitively connected nodes should return True"
143+
144+
145+
# =============================================================================
146+
# MST Algorithm Tests
147+
# =============================================================================
148+
149+
150+
def test_mst_with_injected_union_find(setup_graph: Graph):
151+
"""Tests that the algorithm works with an injected UnionFind instance."""
152+
graph = setup_graph
153+
graph.add_edge(0, 1, 4)
154+
graph.add_edge(0, 6, 7)
155+
graph.add_edge(1, 6, 11)
156+
graph.add_edge(1, 7, 20)
157+
graph.add_edge(1, 2, 9)
158+
graph.add_edge(2, 3, 6)
159+
graph.add_edge(2, 4, 2)
160+
graph.add_edge(3, 4, 10)
161+
graph.add_edge(3, 5, 5)
162+
graph.add_edge(4, 5, 15)
163+
graph.add_edge(4, 7, 1)
164+
graph.add_edge(4, 8, 5)
165+
graph.add_edge(5, 8, 12)
166+
graph.add_edge(6, 7, 1)
167+
graph.add_edge(7, 8, 3)
168+
169+
# Inject a custom UnionFind instance.
170+
union_find = UnionFind(len(graph.vertices))
171+
mst_weight, mst_edges = find_mst_with_boruvkas_algorithm(graph, union_find)
172+
173+
assert mst_weight == 29, "MST weight should be 29"
174+
assert len(mst_edges) == 8, "MST should have 8 edges for 9 vertices"
175+
176+
55177
def test_mst(setup_graph: Graph):
56178
"""
57179
Tests that the MST has the correct total weight and structure by comparing
@@ -94,33 +216,6 @@ def test_mst(setup_graph: Graph):
94216
)
95217

96218

97-
def test_mst_with_injected_union_find(setup_graph: Graph):
98-
"""Tests that the algorithm works with an injected UnionFind instance."""
99-
graph = setup_graph
100-
graph.add_edge(0, 1, 4)
101-
graph.add_edge(0, 6, 7)
102-
graph.add_edge(1, 6, 11)
103-
graph.add_edge(1, 7, 20)
104-
graph.add_edge(1, 2, 9)
105-
graph.add_edge(2, 3, 6)
106-
graph.add_edge(2, 4, 2)
107-
graph.add_edge(3, 4, 10)
108-
graph.add_edge(3, 5, 5)
109-
graph.add_edge(4, 5, 15)
110-
graph.add_edge(4, 7, 1)
111-
graph.add_edge(4, 8, 5)
112-
graph.add_edge(5, 8, 12)
113-
graph.add_edge(6, 7, 1)
114-
graph.add_edge(7, 8, 3)
115-
116-
# Inject a custom UnionFind instance.
117-
union_find = UnionFind(len(graph.vertices))
118-
mst_weight, mst_edges = find_mst_with_boruvkas_algorithm(graph, union_find)
119-
120-
assert mst_weight == 29, "MST weight should be 29"
121-
assert len(mst_edges) == 8, "MST should have 8 edges for 9 vertices"
122-
123-
124219
def test_mst_simple_triangle():
125220
"""Tests MST on a simple triangle graph."""
126221
graph = Graph(3)

0 commit comments

Comments
 (0)