tests: add base_with.nit
[nit.git] / lib / standard / collection / union_find.nit
1 # This file is part of NIT ( http://www.nitlanguage.org ).
2 #
3 # This file is free software, which comes along with NIT. This software is
4 # distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
5 # without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
6 # PARTICULAR PURPOSE. You can modify it is you want, provided this header
7 # is kept unaltered, and a notification of the changes is added.
8 # You are allowed to redistribute it and sell it, alone or is a part of
9 # another product.
10
11 # union–find algorithm using an efficient disjoint-set data structure
12 module union_find
13
14 import hash_collection
15
# Data structure that keeps track of elements partitioned into disjoint subsets
#
#     var s = new DisjointSet[Int]
#     s.add(1)
#     s.add(2)
#     assert not s.in_same_subset(1,2)
#     s.union(1,2)
#     assert s.in_same_subset(1,2)
#
# `in_same_subset` is transitive, reflexive and symmetric
#
#     s.add(3)
#     assert not s.in_same_subset(1,3)
#     s.union(3,2)
#     assert s.in_same_subset(1,3)
#
# Unlike theoretical disjoint-set data structures, the underlying implementation
# is opaque: the traditional `find` method is not available to clients.
# The methods `in_same_subset`, `to_partitions`, and their variations are offered instead.
class DisjointSet[E]
	super SimpleCollection[E]

	# The node associated with each element in the hierarchical structure
	private var nodes = new HashMap[E, DisjointSetNode]

	# The number of subsets in the partition
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5])
	#     assert s.number_of_subsets == 5
	#     s.union_all([1,4,5])
	#     assert s.number_of_subsets == 3
	#     s.union(4,5)
	#     assert s.number_of_subsets == 3
	var number_of_subsets: Int = 0

	# Get the root node of the subset that contains `e`
	#
	# The root is cached back into `nodes` so that the next lookup of `e`
	# reaches it directly.
	#
	# REQUIRE: `has(e)`
	private fun find(e:E): DisjointSetNode
	do
		assert nodes.has_key(e)
		var node = nodes[e]
		if node.parent != node then
			# Not a root: resolve the real root and remember it for `e`
			node = nfind(node)
			nodes[e] = node
		end
		return node
	end

	# Get the root node of the tree that contains `ne`
	#
	# Uses *path compression*: every node visited on the way up
	# is re-attached directly to the root.
	#
	# ENSURE: `result.parent == result`
	private fun nfind(ne: DisjointSetNode): DisjointSetNode
	do
		# First pass: climb to the root
		var root = ne
		while root.parent != root do
			root = root.parent
		end

		# Second pass: flatten the path from `ne` to the root
		var cur = ne
		while cur != root do
			var up = cur.parent
			cur.parent = root
			cur = up
		end
		return root
	end

	# Is the element `e` in the structure?
	#
	#     var s = new DisjointSet[Int]
	#     assert not s.has(1)
	#     s.add(1)
	#     assert s.has(1)
	#     assert not s.has(2)
	redef fun has(e: E): Bool do return nodes.has_key(e)

	redef fun iterator
	do
		return nodes.keys.iterator
	end

	# Add a new element to the structure.
	#
	# A new element starts in its own singleton subset.
	# Adding an element that is already present does nothing.
	#
	# ENSURE: `has(e)`
	redef fun add(e:E)
	do
		if not nodes.has_key(e) then
			nodes[e] = new DisjointSetNode
			number_of_subsets += 1
		end
	end

	# Are the two elements `e` and `f` in the same subset?
	fun in_same_subset(e,f:E): Bool
	do
		if e == f then return true
		return find(e) == find(f)
	end

	# Are all the elements of `es` in the same subset?
	#
	# An empty collection is trivially in a single subset.
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union_all([1,2,3])
	#     assert not s.all_in_same_subset([2,3,4])
	#     s.union_all([1,4,5])
	#     assert s.all_in_same_subset([2,3,4])
	fun all_in_same_subset(es: Collection[E]): Bool
	do
		if es.is_empty then return true
		var root = find(es.first)
		for e in es do
			if find(e) != root then return false
		end
		return true
	end

	# Build the current partitioning of all the elements
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union(1,2)
	#     s.union(1,3)
	#     s.union(4,5)
	#     var p = s.to_partitions
	#     assert p.length == 3
	fun to_partitions: Collection[Set[E]] do return to_subpartition(self)

	# Build a partitioning restricted to `es`, a subset of the elements
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union(1,2)
	#     s.union(1,3)
	#     s.union(4,5)
	#     var p = s.to_subpartition([1,2,4])
	#     assert p.length == 2
	fun to_subpartition(es: Collection[E]): Collection[Set[E]]
	do
		# Elements are grouped by the root node of their subset
		var groups = new HashMap[DisjointSetNode, Set[E]]
		for e in es do
			var root = find(e)
			if not groups.has_key(root) then groups[root] = new HashSet[E]
			groups[root].add(e)
		end
		return groups.values
	end

	# Merge the subset of `e` with the subset of `f`
	#
	# ENSURE: `in_same_subset(e,f)`
	fun union(e,f:E)
	do
		var ne = find(e)
		var nf = find(f)
		if ne == nf then return

		# Merge with *union by rank*:
		# the root of lower rank is attached under the root of higher rank.
		if ne.rank < nf.rank then
			ne.parent = nf
			nodes[e] = nf
		else
			# With equal ranks the surviving root gets one level deeper;
			# this is the only case where a rank grows.
			if ne.rank == nf.rank then ne.rank += 1
			nf.parent = ne
			nodes[f] = ne
		end
		number_of_subsets -= 1
	end

	# Merge the subsets of all the elements of `es`
	#
	# ENSURE: `all_in_same_subset(es)`
	fun union_all(es:Collection[E])
	do
		if es.is_empty then return
		var pivot = es.first
		for e in es do
			union(e, pivot)
		end
	end
end
198
# A node of the forest that represents the subsets
private class DisjointSetNode
	# The node this one is attached to.
	# A node that is its own parent is a root.
	var parent: DisjointSetNode = self

	# The rank, used to keep the structure balanced when merging.
	# It is called rank rather than depth because path compression
	# (see `DisjointSet::nfind`) can flatten a tree without updating it.
	var rank: Int = 0
end