-
-
Notifications
You must be signed in to change notification settings - Fork 96
Expand file tree
/
Copy pathDisjointSets.hs
More file actions
62 lines (53 loc) · 1.99 KB
/
DisjointSets.hs
File metadata and controls
62 lines (53 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
{-|
Created by: Ramy-Badr-Ahmed (https://github.com/Ramy-Badr-Ahmed) in Pull Request: #54
https://github.com/TheAlgorithms/Haskell/pull/54
Please mention me (@Ramy-Badr-Ahmed) in any issue or pull request addressing bugs/corrections to this file.
Thank you!
-}
module DataStructures.DisjointSets where
import Data.Array.ST
import Control.Monad.ST
import Data.STRef
-- | A set element, identified by its index into the backing arrays.
type Node = Int

-- | Mutable union-find state used in the 'ST' monad, as a pair of arrays
-- indexed by 'Node':
--
--   * the first array stores each node's parent; a node that is its own
--     parent is the root (representative) of its set;
--   * the second array stores each node's rank, consulted by union-by-rank.
type DisjointSet s =
  ( STArray s Node Node -- parent of each node
  , STArray s Node Int  -- rank of each node
  )
-- | Build @n@ singleton sets over the nodes @0 .. n-1@.
--
-- Every node starts as its own parent, and therefore as the root of its own
-- set, with rank 0. For @n <= 0@ both arrays are empty.
makeSet :: Int -> ST s (DisjointSet s)
makeSet n = (,) <$> newListArray ixBounds [0 .. n - 1] <*> newArray ixBounds 0
  where
    -- Both arrays share the same index range.
    ixBounds = (0, n - 1)
-- | Return the root (representative) of the set containing the given node.
--
-- The parent chain is followed until a node that is its own parent is reached.
-- On the way back, every non-root node on that chain is re-pointed directly at
-- the root (path compression), so later lookups are shorter. The rank array
-- is neither read nor written.
findSet :: DisjointSet s -> Node -> ST s Node
findSet (parents, _) start = go start
  where
    -- Walk towards the root, compressing the path after the recursive call.
    go node = do
      p <- readArray parents node
      if p == node then pure node else compress node =<< go p
    -- Point @node@ straight at @root@ and pass the root along.
    compress node root = root <$ writeArray parents node root
-- | Merge the sets containing @x@ and @y@ using union by rank.
--
-- If both nodes already share a root, nothing changes. Otherwise the root with
-- the smaller rank is attached under the root with the larger rank. When the
-- ranks are equal, @y@'s root is attached under @x@'s root and the rank of the
-- surviving root (@x@'s) is incremented.
--
-- The increment must go to the root that remains a root. Bumping the rank of
-- the root that was just attached (as the previous version did) leaves root
-- ranks stuck at 0, which defeats union by rank and lets trees grow tall.
unionSet :: DisjointSet s -> Node -> Node -> ST s ()
unionSet ds@(parents, ranks) x y = do
  rootX <- findSet ds x
  rootY <- findSet ds y
  if rootX == rootY
    then pure ()
    else do
      rankX <- readArray ranks rootX
      rankY <- readArray ranks rootY
      case compare rankX rankY of
        GT -> writeArray parents rootY rootX
        LT -> writeArray parents rootX rootY
        EQ -> do
          -- Equal ranks: rootX becomes the parent, so its rank grows by one.
          writeArray parents rootY rootX
          writeArray ranks rootX (rankX + 1)
-- | Example driver.
--
-- Builds @n@ singleton sets, applies every pair in @unions@ with 'unionSet'
-- (left to right), then returns the root of each node in @finds@ using
-- 'findSet'. Every node mentioned must lie in @0 .. n-1@.
--
-- >>> example 4 [(0, 1), (2, 3)] [0, 1, 2, 3]
-- [0,0,2,2]
example :: Int -> [(Node, Node)] -> [Node] -> [Node]
example n unions finds = runST $ do
  sets <- makeSet n
  sequence_ [unionSet sets a b | (a, b) <- unions]
  traverse (findSet sets) finds