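22. Union-Find (Disjoint Set)
Union-find is a tree-based data structure. As the name suggests, it handles merging and querying of disjoint sets. It supports two operations:
Union: merge two disjoint sets into one.
Find: determine whether two elements are in the same set.
Union-find does not support splitting a set, although a modified version can support deleting a single element from a set.
The key idea of union-find is to let one element of each set (the root node) represent the whole set.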
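22.1. Basic version
Initialization
Suppose there are \(n\) elements. Use an array parent[] to store the parent node of each element; initially, every element is its own parent.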
int parent[MAXN];
inline void init(const int n)
{
    for(int i = 0; i <= n; ++i) parent[i] = i;
}
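Find
Look up the representative element recursively: follow parent pointers upward, level by level, until reaching the root, which is recognizable because it is its own parent. To decide whether two elements belong to the same set, compare their roots.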
inline int find(const int x)
{
    if(parent[x] == x) return x;
    else return find(parent[x]);
}
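Union
Find the representatives of the two sets, then set the parent of the first representative to the second (the opposite direction works equally well).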
// `union` is a reserved keyword in C and C++, so the function is named unite.
inline void unite(const int x, const int y)
{
    parent[find(x)] = find(y);
}
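As a usage illustration, the three routines above can be sketched in Python. The names `make_sets`, `find`, and `unite`, and the six-element example, are assumptions made for this sketch and do not come from the original listing.

```python
def make_sets(n):
    """Return a parent list in which every element 0..n is its own root."""
    return list(range(n + 1))


def find(parent, x):
    """Walk up the parent pointers until reaching a root (no compression)."""
    while parent[x] != x:
        x = parent[x]
    return x


def unite(parent, x, y):
    """Attach the root of x's set to the root of y's set."""
    parent[find(parent, x)] = find(parent, y)


parent = make_sets(5)
unite(parent, 1, 2)
unite(parent, 3, 4)
unite(parent, 2, 4)

# 1 and 3 ended up in the same set; 0 and 5 were never merged.
same_1_3 = find(parent, 1) == find(parent, 3)
same_0_5 = find(parent, 0) == find(parent, 5)
```

Here `same_1_3` is true and `same_0_5` is false, because membership is decided purely by comparing roots.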
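22.2. Path compression
The basic version is inefficient: repeated unions can make the trees deeper and deeper, so walking from a node at the bottom to the root becomes increasingly costly.
Since we only care about which root an element maps to, we want each element's path to the root to be as short as possible (ideally a single step). To achieve this, during a find we set the parent of every node along the path to the root, so that later finds on those nodes are fast.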
inline int find(const int x)
{
    if(parent[x] == x) return x;
    else
    {
        parent[x] = find(parent[x]);
        return parent[x];
    }
}
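A recursive find can exhaust the call stack on a very long chain. One possible alternative is a two-pass iterative find, sketched here in Python; the name `find_compress` and the five-node chain are assumptions for illustration only.

```python
def find_compress(parent, x):
    """Return the root of x and point every node on the path at that root."""
    # First pass: locate the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: rewire each node on the path directly to the root.
    while parent[x] != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root


parent = [0, 0, 1, 2, 3]          # a chain: 4 -> 3 -> 2 -> 1 -> 0
root = find_compress(parent, 4)   # afterwards every node points at 0
```

After the call, every node of the chain is a direct child of the root, which is the same effect the recursive version produces.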
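22.3. Heuristic merging
A union can increase the depth of a tree (the length of its longest chain): every element of the tree that gets attached ends up farther from the root, so later root lookups take longer. Path compression helps, but compression itself costs time.
The heuristic is to merge the simpler tree into the more complex one, because then fewer nodes end up farther from the root.
Use an array rank[] to record the depth of the tree under each root (for a non-root node, the rank corresponds to the depth of the subtree rooted at it). Initially every element has rank 1; on a union, the tree with the smaller rank is attached to the tree with the larger rank.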
int rank[MAXN]; // tree depth, meaningful only at roots

inline void init(const int n)
{
    for(int i = 0; i <= n; ++i)
    {
        parent[i] = i;
        rank[i] = 1; // every element starts as a tree of depth 1
    }
}

// `union` is a reserved keyword in C and C++, so the function is named unite.
inline void unite(const int x, const int y)
{
    const int rx = find(x);
    const int ry = find(y);
    if(rank[rx] <= rank[ry]) parent[rx] = ry;
    else parent[ry] = rx;
    if(rank[rx] == rank[ry] && rx != ry) rank[ry]++; // equal depths and distinct roots: the new root's depth grows by 1
}
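Because every find with path compression restructures the tree, turning the queried node and all of its ancestors into direct children of the root, rank becomes inaccurate: it no longer reflects the true height of the tree and is only an upper bound on it. Another heuristic is to merge the tree with fewer nodes into the tree with more nodes.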
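The size-based heuristic can be sketched as follows. Unlike rank, the size stored at a root stays exact under path compression. The class name `SizedUnionFind` and its method names are assumptions for this sketch, and `find` uses path halving (each visited node is pointed at its grandparent), a variant of path compression rather than the recursive version shown earlier.

```python
class SizedUnionFind:
    """Union by size, with path halving in find."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # meaningful only at roots

    def find(self, x):
        while self.parent[x] != x:
            # Path halving: skip one level on the way up.
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def unite(self, x, y):
        """Merge the sets of x and y; return False if already merged."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx          # make rx the root of the larger tree
        self.parent[ry] = rx         # attach the smaller tree under it
        self.size[rx] += self.size[ry]
        return True


uf = SizedUnionFind(4)
uf.unite(0, 1)
uf.unite(2, 3)
uf.unite(0, 2)
merged_size = uf.size[uf.find(3)]   # size of the single remaining set
again = uf.unite(1, 3)              # already in the same set
```

Returning a boolean from `unite` is a design choice of this sketch; it lets callers such as Kruskal's algorithm know whether an edge actually joined two components.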
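22.4. Complexity
Briefly, for a union-find over \(n\) elements the space complexity is \(\mathcal{O}(n)\). When path compression is combined with heuristic merging, \(m\) union and find operations take \(\mathcal{O}(m \log^* n)\) amortized time in total, where \(\log^*\) is the iterated logarithm.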
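To get a feel for how slowly the iterated logarithm grows, it can be computed directly. The helper below is a small sketch; the name `log_star` and the choice of base 2 are assumptions made here.

```python
import math


def log_star(n):
    """Count how many times log2 must be applied until the value is <= 1."""
    count = 0
    while n > 1:
        n = math.log2(n)
        count += 1
    return count


small = log_star(16)          # 16 -> 4 -> 2 -> 1
medium = log_star(65536)      # 65536 -> 16 -> 4 -> 2 -> 1
huge = log_star(2 ** 65536)   # one more step than 65536
```

Even for an input as large as 2^65536 the result is only 5, which is why the bound behaves like a small constant factor for any practical \(n\).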
22.5. Python reference code
\(\color{darkgreen}{Code}\)
"""
A union-find disjoint set data structure.
"""

# 2to3 sanity
from __future__ import (
    absolute_import, division, print_function, unicode_literals,
)

# Third-party libraries
import numpy as np


class UnionFind(object):
    r"""Union-find disjoint sets data structure.
    Union-find is a data structure that maintains disjoint set
    (called connected components or components in short) membership,
    and makes it easier to merge (union) two components, and to find
    if two elements are connected (i.e., belong to the same
    component).
    This implements the "weighted-quick-union-with-path-compression"
    union-find algorithm. Only works if elements are immutable
    objects.
    Worst case for union and find: :math:`(N + M \log^* N)`, with
    :math:`N` elements and :math:`M` unions. The function
    :math:`\log^*` is the number of times needed to take :math:`\log`
    of a number until reaching 1. In practice, the amortized cost of
    each operation is nearly constant [1]_.
    Terms
    -----
    Component
        Elements belonging to the same disjoint set
    Connected
        Two elements are connected if they belong to the same component.
    Union
        The operation where two components are merged into one.
    Root
        An internal representative of a disjoint set.
    Find
        The operation to find the root of a disjoint set.
    Parameters
    ----------
    elements : NoneType or container, optional, default: None
        The initial list of elements.
    Attributes
    ----------
    n_elts : int
        Number of elements.
    n_comps : int
        Number of disjoint sets or components.
    Implements
    ----------
    __len__
        Calling ``len(uf)`` (where ``uf`` is an instance of ``UnionFind``)
        returns the number of elements.
    __contains__
        For ``uf`` an instance of ``UnionFind`` and ``x`` an immutable object,
        ``x in uf`` returns ``True`` if ``x`` is an element in ``uf``.
    __getitem__
        For ``uf`` an instance of ``UnionFind`` and ``i`` an integer,
        ``res = uf[i]`` returns the element stored in the ``i``-th index.
        If ``i`` is not a valid index an ``IndexError`` is raised.
    __setitem__
        For ``uf`` an instance of ``UnionFind``, ``i`` an integer and ``x``
        an immutable object, ``uf[i] = x`` changes the element stored at the
        ``i``-th index. If ``i`` is not a valid index an ``IndexError`` is
        raised.
    .. [1] http://algs4.cs.princeton.edu/lectures/
    """

    def __init__(self, elements=None):
        self.n_elts = 0  # current num of elements
        self.n_comps = 0  # the number of disjoint sets or components
        self._next = 0  # next available id
        self._elts = []  # the elements
        self._indx = {}  # dict mapping elt -> index in _elts
        self._par = []  # parent: for the internal tree structure
        self._siz = []  # size of the component - correct only for roots

        if elements is None:
            elements = []
        for elt in elements:
            self.add(elt)

    def __repr__(self):
        return (
            '<UnionFind:\n\telts={},\n\tsiz={},\n\tpar={},\nn_elts={},n_comps={}>'
            .format(
                self._elts,
                self._siz,
                self._par,
                self.n_elts,
                self.n_comps,
            ))

    def __len__(self):
        return self.n_elts

    def __contains__(self, x):
        return x in self._indx

    def __getitem__(self, index):
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        return self._elts[index]

    def __setitem__(self, index, x):
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        self._elts[index] = x

    def add(self, x):
        """Add a single disjoint element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        None
        """
        if x in self:
            return
        self._elts.append(x)
        self._indx[x] = self._next
        self._par.append(self._next)
        self._siz.append(1)
        self._next += 1
        self.n_elts += 1
        self.n_comps += 1

    def find(self, x):
        """Find the root of the disjoint set containing the given element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        int
            The (index of the) root.
        Raises
        ------
        ValueError
            If the given element is not found.
        """
        if x not in self._indx:
            raise ValueError('{} is not an element'.format(x))

        p = self._indx[x]
        while p != self._par[p]:
            # path compression (path halving: point p at its grandparent)
            q = self._par[p]
            self._par[p] = self._par[q]
            p = q
        return p

    def connected(self, x, y):
        """Return whether the two given elements belong to the same component.
        Parameters
        ----------
        x : immutable object
        y : immutable object
        Returns
        -------
        bool
            True if x and y are connected, false otherwise.
        """
        return self.find(x) == self.find(y)

    def union(self, x, y):
        """Merge the components of the two given elements into one.
        Parameters
        ----------
        x : immutable object
        y : immutable object
        Returns
        -------
        None
        """
        # Initialize if they are not already in the collection
        for elt in [x, y]:
            if elt not in self:
                self.add(elt)

        xroot = self.find(x)
        yroot = self.find(y)
        if xroot == yroot:
            return
        # Union by size: attach the smaller tree under the larger one
        if self._siz[xroot] < self._siz[yroot]:
            self._par[xroot] = yroot
            self._siz[yroot] += self._siz[xroot]
        else:
            self._par[yroot] = xroot
            self._siz[xroot] += self._siz[yroot]
        self.n_comps -= 1

    def component(self, x):
        """Find the connected component containing the given element.
        Parameters
        ----------
        x : immutable object
        Returns
        -------
        set
        Raises
        ------
        ValueError
            If the given element is not found.
        """
        if x not in self:
            raise ValueError('{} is not an element'.format(x))
        elts = np.array(self._elts)
        vfind = np.vectorize(self.find)
        roots = vfind(elts)
        return set(elts[roots == self.find(x)])

    def components(self):
        """Return the list of connected components.
        Returns
        -------
        list
            A list of sets.
        """
        elts = np.array(self._elts)
        vfind = np.vectorize(self.find)
        roots = vfind(elts)
        distinct_roots = set(roots)
        return [set(elts[roots == root]) for root in distinct_roots]
        # comps = []
        # for root in distinct_roots:
        #     mask = (roots == root)
        #     comp = set(elts[mask])
        #     comps.append(comp)
        # return comps

    def component_mapping(self):
        """Return a dict mapping elements to their components.
        The returned dict has the following semantics:
            `elt -> component containing elt`
        If x, y belong to the same component, the comp(x) and comp(y)
        are the same objects (i.e., share the same reference). Changing
        comp(x) will reflect in comp(y). This is done to reduce
        memory.
        But this behaviour should not be relied on. There may be
        inconsistency arising from such assumptions or lack thereof.
        If you want to do any operation on these sets, use caution.
        For example, instead of
        ::
            s = uf.component_mapping()[item]
            s.add(stuff)
            # This will have side effect in other sets
        do
        ::
            s = set(uf.component_mapping()[item]) # or
            s = uf.component_mapping()[item].copy()
            s.add(stuff)
        or
        ::
            s = uf.component_mapping()[item]
            s = s | {stuff}  # Now s is different
        Returns
        -------
        dict
            A dict with the semantics: `elt -> component containing elt`.
        """
        elts = np.array(self._elts)
        vfind = np.vectorize(self.find)
        roots = vfind(elts)
        distinct_roots = set(roots)
        comps = {}
        for root in distinct_roots:
            mask = (roots == root)
            comp = set(elts[mask])
            comps.update({x: comp for x in comp})
            # Change ^this^, if you want a different behaviour:
            # If you don't want to share the same set to different keys:
            # comps.update({x: set(comp) for x in comp})
        return comps
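22.6. Kruskal's algorithm
Kruskal's minimum spanning tree algorithm is built on union-find. First, put all edges into a priority queue in which edges with smaller weights are closer to the front (a min-heap). Then pop the edges one by one: if the two endpoints of an edge lie in different sets, merge the two sets and add the edge's weight to the running total. Once a merge produces a set that contains every vertex, a minimum spanning tree has been found.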
\(\color{darkgreen}{Code}\)
// NC159 Minimum spanning tree
// https://www.nowcoder.com/practice/735a34ff4672498b95660f43b7fcd628?tpId=117&&tqId=37869&rp=1&ru=/ta/job-code-high&qru=/ta/job-code-high/question-ranking

#include <numeric> // iota
#include <queue>   // priority_queue
#include <vector>
using namespace std;

struct comparator
{
    bool operator()(const vector<int>& a, const vector<int>& b) const
    {
        return a[2] > b[2]; // min-heap on edge weight
    }
};
class Solution
{
public:
    /**
     * Return the minimum cost to connect all n households.
     * @param n int number of households in the village
     * @param cost vector<vector<int>> each row holds 3 values: the two endpoints and the cost of connecting them
     * @return int
     */
    int miniSpanningTree(int n, vector<vector<int> >& cost)
    {
        if(n <= 1) return 0;
        vector<int> parents(n+1, 0);
        iota(parents.begin(), parents.end(), 0); // every vertex starts as its own root
        vector<int> capacity(n+1, 1);            // number of vertices in each set
        priority_queue<vector<int>, vector<vector<int> >, comparator> edges;
        for(auto& edge: cost) edges.push(edge);
        int c = 0; // accumulated cost
        int v = 0; // size of the most recently merged set
        while(!edges.empty())
        {
            auto edge = edges.top();
            edges.pop();
            bool u = union_(parents, capacity, edge[0], edge[1], v);
            if(u) c += edge[2];
            if(v == n) break; // the minimum spanning tree is complete
        }
        return c;
    }
private:
    // Find with path compression.
    int find_(vector<int>& parents, int x)
    {
        if(x == parents[x]) return x;
        else
        {
            parents[x] = find_(parents, parents[x]);
            return parents[x];
        }
    }
    // Union by size; returns true if two distinct sets were merged.
    bool union_(vector<int>& parents, vector<int>& capacity, int x, int y, int& v)
    {
        x = find_(parents, x);
        y = find_(parents, y);
        if(x != y)
        {
            if(capacity[x] >= capacity[y])
            {
                parents[y] = x;
                capacity[x] += capacity[y];
                v = capacity[x];
            }
            else
            {
                parents[x] = y;
                capacity[y] += capacity[x];
                v = capacity[y];
            }
            return true;
        }
        return false;
    }
};
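For comparison, the same procedure can be sketched compactly in Python. This is a simplified variant and not a translation of the listing above: it sorts the edges instead of using a heap, stops after accepting n - 1 edges instead of tracking set sizes, and omits union by size for brevity. The function name `kruskal`, the 1-based vertex numbering, and the sample graph are assumptions, and the graph is assumed to be connected.

```python
def kruskal(n, edges):
    """Return the total weight of a minimum spanning tree.

    Vertices are numbered 1..n; each edge is a (u, v, weight) triple.
    """
    parent = list(range(n + 1))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    total = 0
    used = 0
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru != rv:              # the edge joins two different sets
            parent[ru] = rv
            total += w
            used += 1
            if used == n - 1:     # a spanning tree has n - 1 edges
                break
    return total


sample = [(1, 2, 1), (2, 3, 2), (3, 4, 3), (1, 4, 4), (1, 3, 5)]
cost = kruskal(4, sample)   # picks the edges of weight 1, 2 and 3
```

Edges whose endpoints already share a root are skipped, which is exactly the cycle check that union-find provides.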