Skip to content

Commit bc0aa0f

Browse files
Merge pull request #553 from Shubham-Khetan-2005/Prims
Prims Algorithm in Go
2 parents 80fe6e1 + 269461e commit bc0aa0f

1 file changed

Lines changed: 343 additions & 0 deletions

File tree

go/graph/Prims.go

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
// PrimMST.go
2+
//
3+
// Prim's Algorithm - Minimum Spanning Tree (numeric node IDs)
4+
//
5+
// Description:
6+
// Prim's algorithm finds a minimum spanning tree (MST) of a connected,
7+
// undirected, weighted graph. This implementation uses a min-heap to pick
8+
// the smallest-weight edge crossing the cut and grows the MST until all
9+
// vertices are included (or determines the graph is disconnected).
10+
//
11+
// Purpose / Use cases:
12+
// - Compute MST and its total weight.
13+
// - Useful for network design, clustering approximations, etc.
14+
//
15+
// Approach / Methodology:
16+
// - Represent undirected weighted graph with adjacency list map[int][]Edge.
17+
// - Use a visited set and a min-heap of crossing edges (weight, u, v).
18+
// - Start from `start` node: push its edges, repeatedly pick smallest edge
19+
// connecting into unvisited node, add to MST, push new edges.
20+
// - If all nodes become visited, return MST edges & total weight; otherwise
21+
// return nil to indicate no spanning tree (disconnected graph).
22+
//
23+
// Complexity Analysis:
24+
// - Time: O(E log E) (or O(E log V) more precisely) due to heap operations.
25+
// - Space: O(E + V) for adjacency and heap.
26+
//
27+
// File contents:
28+
// - Graph type and methods (AddEdge, AddNode).
29+
// - Prim(start) returns []MEdge (MST edges) and total weight, or nil if not found.
30+
// - Tests print MST after each test and indicate pass/fail.
31+
//
32+
// Author: (your name)
33+
// Date: (optional)
34+
35+
package main
36+
37+
import (
38+
"container/heap"
39+
"fmt"
40+
"os"
41+
"sort"
42+
"strconv"
43+
)
44+
45+
// MEdge represents an undirected weighted edge (u -- weight -- v).
46+
type MEdge struct {
47+
U, V int
48+
Weight int
49+
}
50+
51+
// Graph represents an undirected weighted graph with integer node IDs.
52+
type Graph struct {
53+
adj map[int][]MEdge // adjacency list: node -> list of edges (neighbors)
54+
}
55+
56+
// NewGraph creates and returns an empty Graph.
57+
func NewGraph() *Graph {
58+
return &Graph{adj: make(map[int][]MEdge)}
59+
}
60+
61+
// AddNode ensures a node entry exists in the adjacency map.
62+
func (g *Graph) AddNode(id int) {
63+
if _, ok := g.adj[id]; !ok {
64+
g.adj[id] = []MEdge{}
65+
}
66+
}
67+
68+
// AddEdge adds an undirected weighted edge between a and b.
69+
// If nodes don't exist yet they are created automatically.
70+
// The order of AddEdge calls may influence deterministic tie-breaking.
71+
func (g *Graph) AddEdge(a, b, w int) {
72+
g.AddNode(a)
73+
g.AddNode(b)
74+
g.adj[a] = append(g.adj[a], MEdge{U: a, V: b, Weight: w})
75+
g.adj[b] = append(g.adj[b], MEdge{U: b, V: a, Weight: w})
76+
}
77+
78+
// -------- min-heap for crossing edges --------
79+
80+
// heapItem is stored in the priority queue. It keeps (weight, from, to)
81+
// and we break ties deterministically by from then to.
82+
type heapItem struct {
83+
weight int
84+
from int
85+
to int
86+
}
87+
88+
// edgeHeap implements heap.Interface ordered by weight, then from, then to.
89+
type edgeHeap []heapItem
90+
91+
func (h edgeHeap) Len() int { return len(h) }
92+
func (h edgeHeap) Less(i, j int) bool {
93+
if h[i].weight != h[j].weight {
94+
return h[i].weight < h[j].weight
95+
}
96+
if h[i].from != h[j].from {
97+
return h[i].from < h[j].from
98+
}
99+
return h[i].to < h[j].to
100+
}
101+
func (h edgeHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
102+
func (h *edgeHeap) Push(x interface{}) {
103+
*h = append(*h, x.(heapItem))
104+
}
105+
func (h *edgeHeap) Pop() interface{} {
106+
old := *h
107+
n := len(old)
108+
it := old[n-1]
109+
*h = old[:n-1]
110+
return it
111+
}
112+
113+
// -------- Prim's algorithm --------
114+
115+
// Prim runs Prim's algorithm starting from `start`.
116+
// Returns (mstEdges, totalWeight).
117+
// If start not present, or the graph is disconnected (no spanning tree), returns (nil, 0).
118+
func (g *Graph) Prim(start int) ([]MEdge, int) {
119+
// start must exist
120+
if _, ok := g.adj[start]; !ok {
121+
return nil, 0
122+
}
123+
124+
visited := make(map[int]bool, len(g.adj))
125+
h := &edgeHeap{}
126+
heap.Init(h)
127+
128+
// helper: push all edges from node u to heap
129+
pushEdges := func(u int) {
130+
for _, e := range g.adj[u] {
131+
if !visited[e.V] {
132+
heap.Push(h, heapItem{weight: e.Weight, from: u, to: e.V})
133+
}
134+
}
135+
}
136+
137+
visited[start] = true
138+
pushEdges(start)
139+
140+
mst := make([]MEdge, 0, len(g.adj)-1)
141+
total := 0
142+
143+
for h.Len() > 0 && len(visited) < len(g.adj) {
144+
it := heap.Pop(h).(heapItem)
145+
if visited[it.to] {
146+
continue
147+
}
148+
// take this edge into MST
149+
mst = append(mst, MEdge{U: it.from, V: it.to, Weight: it.weight})
150+
total += it.weight
151+
visited[it.to] = true
152+
pushEdges(it.to)
153+
}
154+
155+
// Check if all nodes are visited (spanning tree exists)
156+
if len(visited) != len(g.adj) {
157+
return nil, 0
158+
}
159+
return mst, total
160+
}
161+
162+
// -------- helpers for tests and comparison --------
163+
164+
// normalizeEdgeKey returns a canonical string key for an undirected edge+weight
165+
// (min,max,weight) so we can compare MST edge sets ignoring order.
166+
func normalizeEdgeKey(e MEdge) string {
167+
u, v := e.U, e.V
168+
if u > v {
169+
u, v = v, u
170+
}
171+
return fmt.Sprintf("%d-%d-%d", u, v, e.Weight)
172+
}
173+
174+
// edgesEqualSet checks whether two edge slices represent the same undirected set
175+
// (order-insensitive). Nil == Nil; nil != empty slice.
176+
func edgesEqualSet(a []MEdge, b []MEdge) bool {
177+
if a == nil && b == nil {
178+
return true
179+
}
180+
if (a == nil) != (b == nil) {
181+
return false
182+
}
183+
if len(a) != len(b) {
184+
return false
185+
}
186+
m := make(map[string]int)
187+
for _, e := range a {
188+
m[normalizeEdgeKey(e)]++
189+
}
190+
for _, e := range b {
191+
k := normalizeEdgeKey(e)
192+
if m[k] == 0 {
193+
return false
194+
}
195+
m[k]--
196+
}
197+
for _, v := range m {
198+
if v != 0 {
199+
return false
200+
}
201+
}
202+
return true
203+
}
204+
205+
// sortEdgesForPrint returns a stable, human-friendly ordering for printing (u,v,w) by u,v,w.
206+
func sortEdgesForPrint(edges []MEdge) []MEdge {
207+
cp := make([]MEdge, len(edges))
208+
copy(cp, edges)
209+
sort.Slice(cp, func(i, j int) bool {
210+
ui, vi := cp[i].U, cp[i].V
211+
uj, vj := cp[j].U, cp[j].V
212+
// normalize order for comparison but keep stored U,V as-is for clarity
213+
if ui == uj {
214+
if vi == vj {
215+
return cp[i].Weight < cp[j].Weight
216+
}
217+
return vi < vj
218+
}
219+
return ui < uj
220+
})
221+
return cp
222+
}
223+
224+
// printMST prints MST edges and total weight in a readable way.
225+
func printMST(mst []MEdge, total int) {
226+
if mst == nil {
227+
fmt.Printf("MST: nil (start missing or graph disconnected)\n")
228+
return
229+
}
230+
if len(mst) == 0 {
231+
fmt.Printf("MST: (no edges) total weight = %d\n", total)
232+
return
233+
}
234+
s := sortEdgesForPrint(mst)
235+
fmt.Printf("MST edges (u --w--> v):\n")
236+
for _, e := range s {
237+
fmt.Printf(" %d --%d--> %d\n", e.U, e.Weight, e.V)
238+
}
239+
fmt.Printf("Total weight = %d\n", total)
240+
}
241+
242+
// expect checks result against expected and prints pass/fail (and MST).
243+
func expect(got []MEdge, gotTotal int, expected []MEdge, expectedTotal int, testName string) {
244+
fmt.Printf("%s - Computed MST:\n", testName)
245+
printMST(got, gotTotal)
246+
fmt.Println("Expected MST:")
247+
printMST(expected, expectedTotal)
248+
249+
pass := edgesEqualSet(got, expected) && (got == nil && expected == nil || gotTotal == expectedTotal)
250+
if pass {
251+
fmt.Printf("[PASS] %s\n\n", testName)
252+
} else {
253+
fmt.Printf("[FAIL] %s\n\n", testName)
254+
}
255+
}
256+
257+
// runTests builds small weighted graphs and runs deterministic tests.
258+
func runTests() {
259+
fmt.Println("Prim's Algorithm Tests (numeric nodes)\n")
260+
261+
// Test Graph 1:
262+
// Nodes: 1..6
263+
// edges:
264+
// 1-2:3, 1-3:1, 2-3:7, 2-4:5, 3-4:2, 3-5:4, 4-5:6, 4-6:8, 5-6:9
265+
// Known MST (one valid MST): edges
266+
// (1-3,1), (3-4,2), (1-2,3), (3-5,4), (4-6,8) total = 18
267+
g1 := NewGraph()
268+
g1.AddEdge(1, 2, 3)
269+
g1.AddEdge(1, 3, 1)
270+
g1.AddEdge(2, 3, 7)
271+
g1.AddEdge(2, 4, 5)
272+
g1.AddEdge(3, 4, 2)
273+
g1.AddEdge(3, 5, 4)
274+
g1.AddEdge(4, 5, 6)
275+
g1.AddEdge(4, 6, 8)
276+
g1.AddEdge(5, 6, 9)
277+
278+
expected1 := []MEdge{
279+
{U: 1, V: 3, Weight: 1},
280+
{U: 3, V: 4, Weight: 2},
281+
{U: 1, V: 2, Weight: 3},
282+
{U: 3, V: 5, Weight: 4},
283+
{U: 4, V: 6, Weight: 8},
284+
}
285+
got1, tot1 := g1.Prim(1)
286+
expect(got1, tot1, expected1, 18, "Test 1: sample graph, start=1")
287+
288+
// Test 2: same graph, start at 3 -> MST should be same set & total
289+
got2, tot2 := g1.Prim(3)
290+
expect(got2, tot2, expected1, 18, "Test 2: sample graph, start=3")
291+
292+
// Test 3: disconnected graph -> no spanning tree (nil expected)
293+
// component A: 1--2 (w=1)
294+
// component B: 3--4 (w=2)
295+
g2 := NewGraph()
296+
g2.AddEdge(1, 2, 1)
297+
g2.AddEdge(3, 4, 2)
298+
got3, tot3 := g2.Prim(1)
299+
expect(got3, tot3, nil, 0, "Test 3: disconnected graph (expect nil)")
300+
301+
// Test 4: single isolated node (node exists but no edges) -> MST is empty edges, total=0
302+
g3 := NewGraph()
303+
g3.AddNode(7)
304+
got4, tot4 := g3.Prim(7)
305+
expect(got4, tot4, []MEdge{}, 0, "Test 4: single isolated node => empty MST")
306+
307+
// Test 5: start missing => nil
308+
got5, tot5 := g1.Prim(99)
309+
expect(got5, tot5, nil, 0, "Test 5: start missing => nil")
310+
311+
fmt.Println("Tests completed.")
312+
}
313+
314+
func main() {
315+
// CLI: if an integer arg provided, run Prim on the sample graph and print MST.
316+
if len(os.Args) > 1 {
317+
startStr := os.Args[1]
318+
start, err := strconv.Atoi(startStr)
319+
if err != nil {
320+
fmt.Printf("Invalid start node: %q. Provide integer node id.\n", startStr)
321+
return
322+
}
323+
// build sample graph (same as Test 1)
324+
g := NewGraph()
325+
g.AddEdge(1, 2, 3)
326+
g.AddEdge(1, 3, 1)
327+
g.AddEdge(2, 3, 7)
328+
g.AddEdge(2, 4, 5)
329+
g.AddEdge(3, 4, 2)
330+
g.AddEdge(3, 5, 4)
331+
g.AddEdge(4, 5, 6)
332+
g.AddEdge(4, 6, 8)
333+
g.AddEdge(5, 6, 9)
334+
335+
mst, total := g.Prim(start)
336+
fmt.Printf("Prim's MST starting at %d:\n", start)
337+
printMST(mst, total)
338+
return
339+
}
340+
341+
// default: run tests
342+
runTests()
343+
}

0 commit comments

Comments
 (0)