Skip to content

Union-Find (Disjoint Set Union)

Generated on 2025-07-10 01:57:07 Algorithm Cheatsheet for Technical Interviews


Union-Find (Disjoint Set Union) Cheatsheet

Section titled “Union-Find (Disjoint Set Union) Cheatsheet”
  • What is it? A data structure that tracks a set of elements partitioned into a number of disjoint (non-overlapping) subsets. It provides two primary operations:

    • Find(x): Determines which subset a particular element x is in. Returns a “representative” element for that subset. If Find(x) == Find(y), x and y are in the same set.
    • Union(x, y): Merges the subsets containing elements x and y into a single subset.
  • When to use it?

    • Connectivity problems (e.g., determining if two nodes are connected in a graph).
    • Detecting cycles in undirected graphs.
    • Clustering problems.
    • Equivalence-class problems (grouping items linked by a transitive “same as” relation).
    • Finding connected components.
| Operation | Amortized Time (with Path Compression & Union by Rank/Size) | Worst-Case Time per Operation (without optimizations) | Space Complexity |
|---|---|---|---|
| Find | O(α(n)) (effectively constant) | O(n) | O(n) total for the structure |
| Union | O(α(n)) (effectively constant) | O(n) | O(n) total for the structure |
| Initialize | O(n) | O(n) | O(n) |
  • α(n): Inverse Ackermann function, which grows extremely slowly. For all practical purposes, it can be considered a constant (≤ 5 for any realistically sized input).
class UnionFind:
    """Disjoint-set union (union-find) over the integers 0..n-1.

    ``find`` uses path compression and ``union`` uses union by rank. Together
    they give amortized O(alpha(n)) per operation.
    """

    def __init__(self, n):
        """Create n singleton sets {0}, {1}, ..., {n-1}.

        Raises:
            ValueError: if n is negative. Without this check, the arrays would
                be empty while ``count`` went negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        self.parent = list(range(n))  # Initialize each element as its own parent
        self.rank = [0] * n  # Upper bound on tree height, used for Union by Rank
        self.count = n  # Number of connected components

    def find(self, x):
        """Finds the root/representative of the set containing x with path compression.

        Union by rank keeps every tree's height at most log2(n), which bounds
        the recursion depth.
        """
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        """Merges the sets containing x and y using Union by Rank.

        Returns:
            True if a merge happened. False if x and y were already in the
            same set; for an undirected edge (x, y), False means the edge
            closes a cycle.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1  # Increment rank if ranks are equal
        self.count -= 1
        return True

    def connected(self, x, y):
        """Checks if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Returns the number of connected components."""
        return self.count
# Example usage: join {0,1} and {2,3}, then bridge them through the pair (1, 2).
uf = UnionFind(5)
for a, b in ((0, 1), (2, 3)):
    uf.union(a, b)
print(uf.connected(0, 1))  # True
print(uf.connected(0, 2))  # False
uf.union(1, 2)
print(uf.connected(0, 2))  # True
print(uf.get_count())  # 2
/**
 * Disjoint-set union (union-find) over the integers 0..n-1.
 * {@code find} uses path compression and {@code union} uses union by rank.
 */
class UnionFind {
    private final int[] parent; // parent[i] == i exactly when i is a root
    private final int[] rank;   // upper bound on the height of a root's tree
    private int count;          // number of disjoint sets

    /**
     * Creates n singleton sets {0}, {1}, ..., {n-1}.
     *
     * @param n number of elements
     * @throws NegativeArraySizeException if n is negative
     */
    public UnionFind(int n) {
        parent = new int[n];
        rank = new int[n]; // Java zero-initialises int arrays, so every rank starts at 0
        count = n;
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    /** Returns the root of x's set and points every node on the path directly at it. */
    public int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]); // Path compression
        }
        return parent[x];
    }

    /**
     * Merges the sets containing x and y using union by rank.
     *
     * @return true if a merge happened; false if x and y were already in the same
     *         set (for an undirected edge (x, y), false means the edge closes a cycle)
     */
    public boolean union(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        if (rank[rootX] < rank[rootY]) {
            parent[rootX] = rootY;
        } else if (rank[rootX] > rank[rootY]) {
            parent[rootY] = rootX;
        } else {
            parent[rootY] = rootX;
            rank[rootX]++;
        }
        count--;
        return true;
    }

    /** Returns true if x and y are in the same set. */
    public boolean connected(int x, int y) {
        return find(x) == find(y);
    }

    /** Returns the number of disjoint sets. */
    public int getCount() {
        return count;
    }

    public static void main(String[] args) {
        UnionFind uf = new UnionFind(5);
        uf.union(0, 1);
        uf.union(2, 3);
        System.out.println(uf.connected(0, 1)); // true
        System.out.println(uf.connected(0, 2)); // false
        uf.union(1, 2);
        System.out.println(uf.connected(0, 2)); // true
        System.out.println(uf.getCount()); // 2
    }
}
#include <iostream>
#include <vector>
// Disjoint-set union (union-find) over the integers 0..n-1.
// find() applies path compression; unionSet() applies union by rank.
class UnionFind {
private:
    std::vector<int> parent;  // parent[i] == i exactly when i is a root
    std::vector<int> rank;    // upper bound on the height of a root's tree
    int count;                // number of disjoint sets

public:
    // Creates n singleton sets {0}, {1}, ..., {n-1}; n is expected to be
    // non-negative. explicit: an int must not silently convert to a UnionFind.
    explicit UnionFind(int n) : parent(n), rank(n, 0), count(n) {
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    // Returns the root of x's set and re-points every node on the path at it.
    // Union by rank keeps tree height <= log2(n), bounding the recursion depth.
    int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);  // Path compression
        }
        return parent[x];
    }

    // Merges the sets containing x and y. Returns true if a merge happened and
    // false if x and y were already in the same set. For an undirected edge
    // (x, y), false means the edge closes a cycle.
    bool unionSet(int x, int y) {
        const int rootX = find(x);
        const int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        if (rank[rootX] < rank[rootY]) {
            parent[rootX] = rootY;
        } else if (rank[rootX] > rank[rootY]) {
            parent[rootY] = rootX;
        } else {
            parent[rootY] = rootX;
            ++rank[rootX];
        }
        --count;
        return true;
    }

    // True if x and y are in the same set. Not const because find() compresses paths.
    bool connected(int x, int y) {
        return find(x) == find(y);
    }

    // Number of disjoint sets.
    int getCount() const {
        return count;
    }
};
// Demo: merges {0,1} and {2,3}, then bridges them via (1,2), leaving the
// components {0,1,2,3} and {4}. Booleans print as 1/0.
// '\n' instead of std::endl: no per-line flush; output is flushed at exit.
int main() {
    UnionFind uf(5);
    uf.unionSet(0, 1);
    uf.unionSet(2, 3);
    std::cout << uf.connected(0, 1) << '\n'; // 1 (True)
    std::cout << uf.connected(0, 2) << '\n'; // 0 (False)
    uf.unionSet(1, 2);
    std::cout << uf.connected(0, 2) << '\n'; // 1 (True)
    std::cout << uf.getCount() << '\n'; // 2
    return 0;
}
  • Connecting Nodes: The most basic usage. Given pairs of nodes (edges), determine if two nodes are connected.
  • Finding Connected Components: Count the number of disjoint sets. Useful in graph analysis.
  • Cycle Detection (Undirected Graph): If Union(u, v) is called and Find(u) == Find(v) before the union, then adding the edge (u, v) would create a cycle.
  • Kruskal’s Algorithm (Minimum Spanning Tree): Union-Find is a crucial component. Sort edges by weight and add them to the MST if they connect disjoint components.
  • Dynamic Connectivity: Handling connectivity queries in a dynamic environment where edges can be added over time.
  • Path Compression: Crucial for performance. During Find(x), update the parent of x and all its ancestors to point directly to the root. This flattens the tree.
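  • Union by Rank/Size: Also crucial. Attach the root of the shorter/smaller tree under the root of the taller/larger one, which keeps trees shallow. “Rank” is an upper bound on a tree’s height. Path compression does not update it, so it becomes only an estimate. “Size” is the number of nodes in the tree. When combined with path compression, both rules give the same O(α(n)) amortized bound. Union by size has the added benefit of tracking each component’s size, which many problems need.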
  • Initialization: Ensure that each element is initially in its own disjoint set.
  • Don’t Mix Up Find and Union: Find locates the root. Union merges two sets.
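  • Avoid Redundant Unions: If Find(x) == Find(y), then Union(x, y) has nothing to merge. The implementations above already detect this by comparing the two roots. In that case they skip the merge and leave the component count unchanged, so calling Find separately beforehand only repeats work.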
  • Space Optimization: If the values are in a very large range, consider using a HashMap/Dictionary instead of an array for parent and rank to store only the elements that are actually used.
  1. Number of Islands (LeetCode 200): Given a 2D grid map of ‘1’s (land) and ‘0’s (water), count the number of islands. Use Union-Find to merge adjacent land cells.
  2. Redundant Connection (LeetCode 684): In this problem, a tree is an undirected graph that is connected and has no cycles. You are given a graph that started as a tree with n nodes labeled from 1 to n, with one additional edge added. The added edge was chosen randomly between two different vertices that were not directly connected. Find the edge that can be removed so that the resulting graph is a tree. Use Union-Find to detect the cycle.
  3. Accounts Merge (LeetCode 721): Given a list of accounts where each element accounts[i] is a list of strings, where the first element accounts[i][0] is the name of the account, and the rest of the elements are emails representing emails of the account. We would like to merge these accounts. Two accounts definitely belong to the same person if there is some email that is common to both accounts. Note that even if two accounts have the same name, they may belong to different people as people might have the same name. A person can have any number of accounts initially, but all of their accounts definitely have the same name. After merging the accounts, return the accounts in the following format: the first element of each account is the name of the account, and the rest of the elements are emails in sorted order. The accounts themselves can be returned in any order.
#include <iostream>
#include <vector>
// Disjoint-set union (union-find) over the integers 0..n-1.
// find() applies path compression; unionSet() applies union by rank.
class UnionFind {
private:
    std::vector<int> parent;  // parent[i] == i exactly when i is a root
    std::vector<int> rank;    // upper bound on the height of a root's tree
    int count;                // number of disjoint sets

public:
    // Creates n singleton sets {0}, {1}, ..., {n-1}; n is expected to be
    // non-negative. explicit: an int must not silently convert to a UnionFind.
    explicit UnionFind(int n) : parent(n), rank(n, 0), count(n) {
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    // Returns the root of x's set and re-points every node on the path at it.
    // Union by rank keeps tree height <= log2(n), bounding the recursion depth.
    int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);  // Path compression
        }
        return parent[x];
    }

    // Merges the sets containing x and y. Returns true if a merge happened and
    // false if x and y were already in the same set. For an undirected edge
    // (x, y), false means the edge closes a cycle.
    bool unionSet(int x, int y) {
        const int rootX = find(x);
        const int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        if (rank[rootX] < rank[rootY]) {
            parent[rootX] = rootY;
        } else if (rank[rootX] > rank[rootY]) {
            parent[rootY] = rootX;
        } else {
            parent[rootY] = rootX;
            ++rank[rootX];
        }
        --count;
        return true;
    }

    // True if x and y are in the same set. Not const because find() compresses paths.
    bool connected(int x, int y) {
        return find(x) == find(y);
    }

    // Number of disjoint sets.
    int getCount() const {
        return count;
    }
};
// Example Usage:
// UnionFind uf(n); // Initialize with n elements
// uf.unionSet(a, b); // Union elements a and b
// bool areConnected = uf.connected(x, y); // Check if x and y are connected
// int numComponents = uf.getCount(); // Get the number of connected components
class UnionFind:
    """Disjoint-set union (union-find) over the integers 0..n-1.

    ``find`` uses path compression and ``union`` uses union by rank. Together
    they give amortized O(alpha(n)) per operation.
    """

    def __init__(self, n):
        """Create n singleton sets {0}, {1}, ..., {n-1}.

        Raises:
            ValueError: if n is negative. Without this check, the arrays would
                be empty while ``count`` went negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")
        self.parent = list(range(n))  # each element starts as its own root
        self.rank = [0] * n  # upper bound on tree height, used for union by rank
        self.count = n  # number of disjoint sets

    def find(self, x):
        """Return the root of x's set, pointing every node on the path at it.

        Union by rank keeps every tree's height at most log2(n), which bounds
        the recursion depth.
        """
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        """Merge the sets containing x and y using union by rank.

        Returns:
            True if a merge happened. False if x and y were already in the
            same set; for an undirected edge (x, y), False means the edge
            closes a cycle.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1  # equal ranks: the new root's tree grows by one level
        self.count -= 1
        return True

    def connected(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Return the number of disjoint sets."""
        return self.count
# Example Usage:
# uf = UnionFind(n) # Initialize with n elements
# uf.union(a, b) # Union elements a and b
# are_connected = uf.connected(x, y) # Check if x and y are connected
# num_components = uf.get_count() # Get the number of connected components
/**
 * Disjoint-set union (union-find) over the integers 0..n-1.
 * {@code find} uses path compression and {@code union} uses union by rank.
 */
class UnionFind {
    private final int[] parent; // parent[i] == i exactly when i is a root
    private final int[] rank;   // upper bound on the height of a root's tree
    private int count;          // number of disjoint sets

    /**
     * Creates n singleton sets {0}, {1}, ..., {n-1}.
     *
     * @param n number of elements
     * @throws NegativeArraySizeException if n is negative
     */
    public UnionFind(int n) {
        parent = new int[n];
        rank = new int[n]; // Java zero-initialises int arrays, so every rank starts at 0
        count = n;
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    /** Returns the root of x's set and points every node on the path directly at it. */
    public int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]); // Path compression
        }
        return parent[x];
    }

    /**
     * Merges the sets containing x and y using union by rank.
     *
     * @return true if a merge happened; false if x and y were already in the same
     *         set (for an undirected edge (x, y), false means the edge closes a cycle)
     */
    public boolean union(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) {
            return false;
        }
        if (rank[rootX] < rank[rootY]) {
            parent[rootX] = rootY;
        } else if (rank[rootX] > rank[rootY]) {
            parent[rootY] = rootX;
        } else {
            parent[rootY] = rootX;
            rank[rootX]++;
        }
        count--;
        return true;
    }

    /** Returns true if x and y are in the same set. */
    public boolean connected(int x, int y) {
        return find(x) == find(y);
    }

    /** Returns the number of disjoint sets. */
    public int getCount() {
        return count;
    }
}
// Example Usage:
// UnionFind uf = new UnionFind(n); // Initialize with n elements
// uf.union(a, b); // Union elements a and b
// boolean areConnected = uf.connected(x, y); // Check if x and y are connected
// int numComponents = uf.getCount(); // Get the number of connected components