Skip to content

Commit 5b34651

Browse files
committed
Split Graph class to be separate to Boruvka's MST function
1 parent f0f850d commit 5b34651

1 file changed

Lines changed: 122 additions & 190 deletions

File tree

src/boruvkas_algorithm/boruvka.py

Lines changed: 122 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,23 @@
77

88

99
class Graph:
10-
def __init__(self, num_vertices: int):
10+
"""A graph that contains nodes and edges."""
11+
12+
def __init__(self, num_vertices: int) -> None:
1113
"""
14+
Initialises the graph with a given number of vertices.
15+
1216
Args:
17+
num_nodes: The number of nodes to generate in the graph.
1318
num_vertices: The number of vertices to generate in the graph.
1419
"""
1520
self.vertices: list[int] = list(range(num_vertices))
1621
# [(node1, node2, weight)]
1722
self.edges: list[tuple[int, int, int]] = []
18-
# Each node is its own parent initially.
19-
self.parent: list[int] = list(range(num_vertices))
20-
# Each tree has size 1 (itself) initially.
21-
self.rank: list[int] = [1] * num_vertices
2223

2324
def add_edge(self, node1: int, node2: int, weight: int) -> None:
2425
"""
25-
Add an edge to the graph.
26+
Adds an edge to the graph.
2627
2728
Args:
2829
node1: The first node of the edge.
@@ -37,15 +38,60 @@ def add_edge(self, node1: int, node2: int, weight: int) -> None:
3738
self.edges.append((node1, node2, weight))
3839

3940
def print_graph_info(self) -> None:
40-
"""
41-
Print the graph's vertices and edges.
42-
"""
41+
"""Print the graph's vertices and edges."""
4342
print(f"Vertices: {self.vertices}")
4443
print("Edges (node1, node2, weight):")
4544
for edge in sorted(self.edges):
4645
print(f" {edge}")
4746

48-
def find(self, node: int) -> int:
47+
def draw_mst(self, mst_edges: list[tuple[int, int, int]]) -> None:
48+
"""
49+
Draw the graph with the minimum spanning tree highlighted using
50+
networkx.
51+
52+
Args:
53+
mst_edges: A list of edges in the minimum spanning tree.
54+
"""
55+
G = nx.Graph()
56+
# Add nodes to the graph.
57+
G.add_nodes_from(self.vertices)
58+
# Add all edges to the graph with weights.
59+
for edge in self.edges:
60+
node1, node2, weight = edge
61+
G.add_edge(node1, node2, weight=weight)
62+
pos = nx.spring_layout(G)
63+
# Draw the graph edges and highlight the edges in the MST in red.
64+
nx.draw_networkx_edges(
65+
G, pos, edgelist=self.edges, edge_color="gray", alpha=0.5
66+
)
67+
nx.draw_networkx_edges(G, pos, edgelist=mst_edges, edge_color="red", width=2)
68+
# Draw the graph nodes and labels.
69+
nx.draw_networkx_nodes(G, pos, node_size=700, node_color="lightblue")
70+
nx.draw_networkx_labels(G, pos)
71+
nx.draw_networkx_edge_labels(
72+
G, pos, edge_labels={(u, v): d["weight"] for u, v, d in G.edges(data=True)}
73+
)
74+
75+
plt.title("Graph with Minimum Spanning Tree Highlighted")
76+
plt.axis("off")
77+
plt.show()
78+
79+
80+
def find_mst_with_boruvkas_algorithm(
81+
graph: Graph,
82+
) -> tuple[int, list[tuple[int, int, int]]]:
83+
"""
84+
Finds the minimum spanning tree (MST) of a graph using Boruvka's algorithm.
85+
86+
Args:
87+
graph: The graph to find the MST of.
88+
89+
Returns:
90+
A tuple containing the total weight of the MST and a list of the
91+
edges in the MST.
92+
"""
93+
94+
def find(node: int) -> int:
4995
"""
5096
Finds the root parent of the node using path compression.
5197
@@ -55,17 +101,16 @@ def find(self, node: int) -> int:
55101
Returns:
56102
The root parent of the node.
57103
"""
58-
cur_parent = self.parent[node]
59-
while cur_parent != self.parent[cur_parent]:
104+
cur_parent = parent[node]
105+
while cur_parent != parent[cur_parent]:
60106
# Compress the links as we go up the chain of parents to make
61107
# it faster to traverse in the future - amortised O(a(n)) time,
62108
# where a(n) is the inverse Ackermann function.
63-
self.parent[cur_parent] = self.parent[self.parent[cur_parent]]
64-
cur_parent = self.parent[cur_parent]
65-
109+
parent[cur_parent] = parent[parent[cur_parent]]
110+
cur_parent = parent[cur_parent]
66111
return cur_parent
67112

68-
def union(self, node1: int, node2: int) -> bool:
113+
def union(node1: int, node2: int) -> bool:
69114
"""
70115
Combines the two nodes into the larger segment.
71116
@@ -77,196 +122,82 @@ def union(self, node1: int, node2: int) -> bool:
77122
True if the nodes were combined, False if they were already in the
78123
same segment.
79124
"""
80-
root1 = self.find(node1)
81-
root2 = self.find(node2)
82-
# If they have the same root parent, a cycle exists.
125+
root1 = find(node1)
126+
root2 = find(node2)
127+
# If they have the same root parent, they're already connected.
83128
if root1 == root2:
84129
return False
85130

86131
# Combine the two nodes into the larger segment based on the rank.
87-
if self.rank[root1] > self.rank[root2]:
88-
self.parent[root2] = root1
89-
self.rank[root1] += self.rank[root2]
132+
if rank[root1] > rank[root2]:
133+
parent[root2] = root1
134+
rank[root1] += rank[root2]
90135
else:
91-
self.parent[root1] = root2
92-
self.rank[root2] += self.rank[root1]
93-
136+
parent[root1] = root2
137+
rank[root2] += rank[root1]
94138
return True
95139

96-
def update_min_edge_per_component(self, min_connecting_edge_per_component: list):
97-
"""
98-
Check each edge and update the shortest edge for each node if it
99-
connects two components together.
140+
num_vertices = len(graph.vertices)
141+
# Each node is its own parent initially.
142+
parent: list[int] = list(range(num_vertices))
143+
# Each tree has size 1 (itself) initially.
144+
rank: list[int] = [1] * num_vertices
145+
146+
print("\nFinding MST with Boruvka's algorithm:")
147+
graph.print_graph_info()
148+
149+
mst_weight = 0
150+
mst_edges: list[tuple[int, int, int]] = []
151+
num_components = num_vertices
152+
num_iterations = 0
153+
154+
# Keep connecting components until only one component remains.
155+
while num_components > 1:
156+
num_iterations += 1
157+
print(
158+
f"\nIteration {num_iterations}:\nCurrent MST edges: {mst_edges}\n"
159+
f"Current MST Weight: {mst_weight}"
160+
)
100161

101-
Args:
102-
min_connecting_edge_per_component: A list with the shortest edge
103-
for each node that connects to
104-
a new component.
105-
"""
106-
for edge in self.edges:
162+
# Find the minimum connecting edge for each component.
163+
min_edge_per_component: list[tuple[int, int, int] | None] = [
164+
None
165+
] * num_vertices
166+
for edge in graph.edges:
107167
node1, node2, weight = edge
108-
node1_component = self.find(node1)
109-
node2_component = self.find(node2)
110-
111-
# If the vertices are in different components and the edge is
112-
# smaller than the current minimum weight edge for either
113-
# component, update them.
114-
if node1_component != node2_component:
115-
if (
116-
not min_connecting_edge_per_component[node1_component]
117-
or weight < min_connecting_edge_per_component[node1_component][2]
118-
):
119-
min_connecting_edge_per_component[node1_component] = edge
120-
121-
if (
122-
not min_connecting_edge_per_component[node2_component]
123-
or weight < min_connecting_edge_per_component[node2_component][2]
124-
):
125-
min_connecting_edge_per_component[node2_component] = edge
126-
127-
def connect_components_with_min_edges(
128-
self,
129-
min_connecting_edge_per_component: list,
130-
mst_edges: list[tuple[int, int, int]],
131-
mst_weight: int,
132-
num_components: int,
133-
) -> tuple[int, int]:
134-
"""
135-
Connect components using the minimum connecting edges.
136-
137-
Args:
138-
min_connecting_edge_per_component: List storing the shortest edge
139-
for each component.
140-
mst_edges: List of edges in the minimum spanning tree.
141-
mst_weight: Total weight of the minimum spanning tree.
142-
num_components: Total number of components in the graph.
143-
144-
Returns:
145-
Tuple containing the updated MST weight and number of components.
146-
"""
147-
for edge in min_connecting_edge_per_component:
168+
comp1, comp2 = find(node1), find(node2)
169+
170+
if comp1 != comp2:
171+
current_min1 = min_edge_per_component[comp1]
172+
if current_min1 is None or weight < current_min1[2]:
173+
min_edge_per_component[comp1] = edge
174+
current_min2 = min_edge_per_component[comp2]
175+
if current_min2 is None or weight < current_min2[2]:
176+
min_edge_per_component[comp2] = edge
177+
178+
# Connect components using the minimum connecting edges.
179+
for edge in min_edge_per_component:
148180
if edge is not None:
149181
node1, node2, weight = edge
150-
if self.find(node1) != self.find(node2):
151-
mst_edges.append((node1, node2, weight))
182+
if find(node1) != find(node2):
183+
mst_edges.append(edge)
152184
mst_weight += weight
153-
self.union(node1, node2)
185+
union(node1, node2)
154186
num_components -= 1
155187
print(f"Added edge {node1} - {node2} with weight {weight} to MST.")
156188

157-
return mst_weight, num_components
158-
159-
def perform_iteration(
160-
self,
161-
num_components: int,
162-
mst_edges: list[tuple[int, int, int]],
163-
mst_weight: int,
164-
):
165-
"""
166-
Perform one iteration of Boruvka's algorithm, finding the minimum
167-
connecting edge for each component and connecting components using
168-
these edges.
169-
170-
Args:
171-
num_components: Total number of components in the graph.
172-
mst_edges: List of edges in the minimum spanning tree so far.
173-
mst_weight: Total weight of the minimum spanning tree so far.
174-
175-
Returns:
176-
Tuple containing the updated MST weight and number of components.
177-
"""
178-
# Initialize list to store minimum connecting edge for each component.
179-
min_connecting_edge_per_component = [None] * len(self.vertices)
180-
# Update the minimum connecting edge for each component.
181-
self.update_min_edge_per_component(min_connecting_edge_per_component)
182-
# Connect components using the minimum connecting edges and update MST
183-
# weight and number of components.
184-
mst_weight, num_components = self.connect_components_with_min_edges(
185-
min_connecting_edge_per_component,
186-
mst_edges,
187-
mst_weight,
188-
num_components,
189-
)
190-
191-
return mst_weight, num_components
192-
193-
def draw_mst(self, mst_edges: list[tuple[int, int, int]]) -> None:
194-
"""
195-
Draw the graph with the minimum spanning tree highlighted using
196-
networkx.
197-
198-
Args:
199-
mst_edges: A list of edges in the minimum spanning tree.
200-
"""
201-
G = nx.Graph()
202-
# Add nodes to the graph.
203-
G.add_nodes_from(self.vertices)
204-
# Add all edges to the graph with weights.
205-
for edge in self.edges:
206-
node1, node2, weight = edge
207-
G.add_edge(node1, node2, weight=weight)
208-
pos = nx.spring_layout(G)
209-
# Draw the graph edges and highlight the edges in the MST in red.
210-
nx.draw_networkx_edges(
211-
G, pos, edgelist=self.edges, edge_color="gray", alpha=0.5
212-
)
213-
nx.draw_networkx_edges(G, pos, edgelist=mst_edges, edge_color="red", width=2)
214-
# Draw the graph nodes and labels.
215-
nx.draw_networkx_nodes(G, pos, node_size=700, node_color="lightblue")
216-
nx.draw_networkx_labels(G, pos)
217-
nx.draw_networkx_edge_labels(
218-
G, pos, edge_labels={(u, v): d["weight"] for u, v, d in G.edges(data=True)}
219-
)
220-
221-
plt.title("Graph with Minimum Spanning Tree Highlighted")
222-
plt.axis("off")
223-
plt.show()
189+
# Summarise the MST found.
190+
print("\nMST found with Boruvka's algorithm.")
191+
print("MST edges (node1, node2, weight):")
192+
for edge in sorted(mst_edges):
193+
print(f" {edge}")
194+
print(f"MST weight: {mst_weight}")
224195

225-
def run_boruvkas_algorithm(self):
226-
"""
227-
Find the minimum spanning tree (MST) of the graph using Boruvka's
228-
algorithm.
229-
230-
Returns:
231-
A tuple containing the total weight of the MST and a list of the
232-
edges in the MST.
233-
"""
234-
print("\nFinding MST with Boruvka's algorithm:")
235-
self.print_graph_info()
236-
mst_weight = 0
237-
mst_edges = []
238-
num_components = len(self.vertices)
239-
# Track the number of iterations.
240-
num_iterations = 0
241-
242-
# Keep connecting components until only one component remains.
243-
while num_components > 1:
244-
num_iterations += 1
245-
print(
246-
f"\nIteration {num_iterations}:\nCurrent MST edges: {mst_edges}\n"
247-
f"Current MST Weight: {mst_weight}"
248-
)
249-
# Perform one iteration of the algorithm.
250-
mst_weight, num_components = self.perform_iteration(
251-
num_components,
252-
mst_edges,
253-
mst_weight,
254-
)
255-
256-
# Summarise the MST found.
257-
print("\nMST found with Boruvka's algorithm.")
258-
print("MST edges (node1, node2, weight):")
259-
for edge in sorted(mst_edges):
260-
print(f" {edge}")
261-
print(f"MST weight: {mst_weight}")
262-
263-
return mst_weight, mst_edges
196+
return mst_weight, mst_edges
264197

265198

266-
def main():
267-
"""
268-
Run Boruvka's algorithm on an example graph.
269-
"""
199+
def run_boruvka_example():
200+
"""Runs Boruvka's algorithm on an example graph."""
270201
graph = Graph(9)
271202
graph.add_edge(0, 1, 4)
272203
graph.add_edge(0, 6, 7)
@@ -283,10 +214,11 @@ def main():
283214
graph.add_edge(5, 8, 12)
284215
graph.add_edge(6, 7, 1)
285216
graph.add_edge(7, 8, 3)
286-
_, mst_edges = graph.run_boruvkas_algorithm()
217+
218+
_, mst_edges = find_mst_with_boruvkas_algorithm(graph)
287219
# Draw the graph with the minimum spanning tree highlighted.
288220
graph.draw_mst(mst_edges)
289221

290222

291223
if __name__ == "__main__":
292-
main()
224+
run_boruvka_example()

0 commit comments

Comments
 (0)