tests: update and improve tests for nitunit
[nit.git] / lib / standard / collection / union_find.nit
1 # This file is part of NIT ( http://www.nitlanguage.org ).
2 #
3 # This file is free software, which comes along with NIT. This software is
4 # distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
5 # without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
6 # PARTICULAR PURPOSE. You can modify it is you want, provided this header
7 # is kept unaltered, and a notification of the changes is added.
8 # You are allowed to redistribute it and sell it, alone or is a part of
9 # another product.
10
11 # union–find algorithm using an efficient disjoint-set data structure
12 module union_find
13
14 import hash_collection
15
# Data structure to keep track of elements partitioned into disjoint subsets
#
#     var s = new DisjointSet[Int]
#     s.add(1)
#     s.add(2)
#     assert not s.in_same_subset(1,2)
#     s.union(1,2)
#     assert s.in_same_subset(1,2)
#
# `in_same_subset` is transitive, reflexive and symmetric
#
#     s.add(3)
#     assert not s.in_same_subset(1,3)
#     s.union(3,2)
#     assert s.in_same_subset(1,3)
#
# Unlike theoretical disjoint-set data structures, the underlying implementation is opaque,
# which makes the traditional `find` method unavailable to clients.
# The methods `in_same_subset`, `to_partitions`, and their variations are offered instead.
class DisjointSet[E: Object]
	super SimpleCollection[E]

	# The node in the hierarchical structure for each element
	#
	# Several elements may end up sharing the same node: `find` and `union`
	# rebind an element directly to the current root of its subset as a shortcut.
	private var nodes = new HashMap[E, DisjointSetNode]

	# Get the root node of an element
	#
	# As a shortcut, `e` is rebound directly to its root node in `nodes`.
	#
	# require: `has(e)`
	private fun find(e: E): DisjointSetNode
	do
		assert nodes.has_key(e)
		var ne = nodes[e]
		if ne.parent == ne then return ne
		var res = nfind(ne)
		nodes[e] = res
		return res
	end

	# Get the root node of a node
	#
	# Use *path compression* to flatten the structure: every node visited
	# on the way up is reattached directly to the root.
	#
	# ENSURE: `result.parent == result`
	private fun nfind(ne: DisjointSetNode): DisjointSetNode
	do
		var nf = ne.parent
		if nf == ne then return ne
		var ng = nfind(nf)
		ne.parent = ng
		return ng
	end

	# Is the element in the structure?
	#
	#     var s = new DisjointSet[Int]
	#     assert not s.has(1)
	#     s.add(1)
	#     assert s.has(1)
	#     assert not s.has(2)
	redef fun has(e: E): Bool
	do
		return nodes.has_key(e)
	end

	# Iterate over the elements of the structure
	redef fun iterator do return nodes.keys.iterator

	# Number of elements in the structure
	#
	# Redefined to read the size of the underlying map instead of
	# counting the elements one by one through `iterator`.
	#
	#     var s = new DisjointSet[Int]
	#     assert s.length == 0
	#     s.add_all([1,2,2,3])
	#     assert s.length == 3
	redef fun length do return nodes.length

	# Add a new element in the structure.
	# Initially it is in its own disjoint subset.
	# Adding an element that is already in the structure has no effect.
	#
	# ENSURE: `has(e)`
	redef fun add(e: E)
	do
		if nodes.has_key(e) then return
		var ne = new DisjointSetNode
		nodes[e] = ne
	end

	# Are two elements in the same subset?
	#
	# Two equal elements are always reported in the same subset.
	#
	# require: `has(e) and has(f)` (unless `e == f`)
	fun in_same_subset(e, f: E): Bool
	do
		return e == f or find(e) == find(f)
	end

	# Are all elements of `es` in the same subset?
	#
	# An empty collection trivially satisfies the condition.
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union_all([1,2,3])
	#     assert not s.all_in_same_subset([2,3,4])
	#     s.union_all([1,4,5])
	#     assert s.all_in_same_subset([2,3,4])
	#
	# require: every element of `es` is in the structure
	fun all_in_same_subset(es: Collection[E]): Bool
	do
		if es.is_empty then return true
		var nf = find(es.first)
		for e in es do
			var ne = find(e)
			if ne != nf then return false
		end
		return true
	end

	# Construct the current partitioning
	#
	# Each element of the structure appears in exactly one of the returned sets.
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union(1,2)
	#     s.union(1,3)
	#     s.union(4,5)
	#     var p = s.to_partitions
	#     assert p.length == 3
	fun to_partitions: Collection[Set[E]]
	do
		# NOTE(review): this iterates over `self` (the keys of `nodes`) while
		# `find` may rebind existing keys of `nodes`; presumably safe because no
		# key is added or removed — confirm against `HashMap` iteration rules.
		return to_subpartition(self)
	end

	# Construct a partitioning on `es`, a subset of elements
	#
	# Elements of `es` that belong to the same subset are grouped in the same returned set.
	#
	#     var s = new DisjointSet[Int]
	#     s.add_all([1,2,3,4,5,6])
	#     s.union(1,2)
	#     s.union(1,3)
	#     s.union(4,5)
	#     var p = s.to_subpartition([1,2,4])
	#     assert p.length == 2
	#
	# require: every element of `es` is in the structure
	fun to_subpartition(es: Collection[E]): Collection[Set[E]]
	do
		# Group the elements by the root node of their subset
		var map = new HashMap[DisjointSetNode, Set[E]]
		for e in es do
			var ne = find(e)
			var set = map.get_or_null(ne)
			if set == null then
				set = new HashSet[E]
				map[ne] = set
			end
			set.add(e)
		end
		return map.values
	end

	# Combine the subsets of `e` and `f`
	#
	# require: `has(e) and has(f)`
	# ENSURE: `in_same_subset(e,f)`
	fun union(e, f: E)
	do
		var ne = find(e)
		var nf = find(f)
		if ne == nf then return

		# Merge them using *union by rank*:
		# attach the root of lower rank under the root of higher rank
		var er = ne.rank
		var fr = nf.rank
		if er < fr then
			ne.parent = nf
			nodes[e] = nf
		else
			nf.parent = ne
			nodes[f] = ne
			if er == fr then
				# The only case where the rank increases is when both ranks are equal
				ne.rank = er + 1
			end
		end
	end

	# Combine the subsets of all elements of `es`
	#
	# Does nothing if `es` is empty.
	#
	# require: every element of `es` is in the structure
	# ENSURE: `all_in_same_subset(es)`
	fun union_all(es: Collection[E])
	do
		if es.is_empty then return
		var f = es.first
		for e in es do union(e, f)
	end
end
185
# A node in the hierarchical representation of subsets
#
# Nodes only point upward, toward their parent; the root of a subset
# is the node whose `parent` is itself.
private class DisjointSetNode
	# Rank used by *union by rank* to keep the trees balanced.
	#
	# It is called a rank rather than a depth because path compression
	# (see `DisjointSet::nfind`) shortens paths without updating it,
	# so it may overestimate the actual depth.
	var rank: Int = 0

	# Parent of this node; when `parent == self` the node is a root
	var parent: DisjointSetNode = self
end