Skip to content

Prim's Minimum Spanning Tree

Generated on 2025-07-10 01:58:30 Algorithm Cheatsheet for Technical Interviews


What is it? Prim’s algorithm finds the Minimum Spanning Tree (MST) for a weighted, undirected, connected graph. The MST is a subset of the edges that connects all the vertices together, without any cycles and with the minimum possible total edge weight.

When to use it? Use Prim’s when you need to find the cheapest way to connect all vertices in a graph, such as:

  • Network design (connecting computers, cities, etc.)
  • Clustering
  • Approximation algorithms for NP-hard problems
| Feature | Complexity | Notes |
|---|---|---|
| Time | O(E log V) (using a binary heap) | E is the number of edges, V is the number of vertices. Using a Fibonacci heap can improve this to O(E + V log V), but it is rarely used in practice due to higher constant factors. In a dense graph (E ≈ V^2), O(V^2) is achievable with an adjacency matrix and no heap. |
| Space | O(V + E) (adjacency list) or O(V^2) (adjacency matrix), plus O(V) for the visited set and up to O(E) for a lazy priority queue | The space complexity depends on the graph representation and the data structures used. Adjacency lists use less space for sparse graphs. The visited set stores one entry per vertex (O(V)). With lazy deletion (as in the examples below) the priority queue can hold up to O(E) entries; an indexed heap with decrease-key keeps it at O(V). |

Here are examples in Python, Java and C++

Python:

import heapq
def prims_mst(graph):
    """
    Finds the Minimum Spanning Tree (MST) of a graph using Prim's algorithm.

    Uses the "lazy" variant: stale heap entries (pointing at already-visited
    vertices) are skipped when popped instead of being removed.

    Args:
        graph: A dictionary representing an undirected graph, where keys are vertices
            and values are lists of (neighbor, weight) tuples.
            Example: {'A': [('B', 2), ('C', 3)], 'B': [('A', 2), ('C', 4)], ...}
            A vertex that appears only as a neighbor (not as a key) is treated as
            having no adjacency list of its own.

    Returns:
        A list of edges in the MST, represented as tuples (vertex1, vertex2, weight),
        in the order they were added.
        Returns an empty list if the graph is empty or disconnected.
    """
    import itertools

    if not graph:
        return []
    # Count every vertex mentioned anywhere (keys and neighbors) for the connectivity check.
    all_vertices = set(graph)
    for edges in graph.values():
        all_vertices.update(neighbor for neighbor, _ in edges)

    start_node = next(iter(graph))  # Arbitrary starting node
    visited = set()
    mst = []
    # Heap entries: (weight, tie_breaker, current_node, parent_node). The increasing
    # tie_breaker stops heapq from comparing vertices when weights are equal, so the
    # vertices only need to be hashable, not mutually orderable.
    tie = itertools.count()
    # The start node enters with weight 0 and itself as parent (a sentinel self-loop).
    pq = [(0, next(tie), start_node, start_node)]
    while pq and len(visited) < len(all_vertices):
        weight, _, current_node, parent_node = heapq.heappop(pq)
        if current_node in visited:
            continue  # stale entry (lazy deletion)
        visited.add(current_node)
        if current_node != parent_node:  # skip the sentinel self-loop of the start node
            mst.append((parent_node, current_node, weight))
        for neighbor, w in graph.get(current_node, ()):
            if neighbor not in visited:
                heapq.heappush(pq, (w, next(tie), neighbor, current_node))
    if len(visited) != len(all_vertices):
        return []  # Graph is disconnected; no spanning tree exists
    return mst
# Example usage: a small undirected graph given as symmetric adjacency lists.
graph = {
    'A': [('B', 2), ('C', 3)],
    'B': [('A', 2), ('C', 4), ('D', 7)],
    'C': [('A', 3), ('B', 4), ('D', 5)],
    'D': [('B', 7), ('C', 5)],
}
mst = prims_mst(graph)
print("MST:", mst)
# Each MST edge is (vertex1, vertex2, weight); the weight is the last field.
total_weight = sum(edge[-1] for edge in mst)
print("Total MST weight:", total_weight)
# Complexity Analysis:
# Time: O(E log V) - every adjacency entry is pushed/popped at most once, each heap operation is O(log E) = O(log V).
# Space: O(V + E) - O(V) for the visited set; with lazy deletion the heap can hold up to O(E) entries.

Java:

import java.util.*;
public class PrimsMST {
    /**
     * Computes a minimum spanning tree of an undirected, weighted graph using the lazy variant
     * of Prim's algorithm (stale heap entries are skipped when polled).
     *
     * @param graph adjacency lists keyed by vertex; each {@link Edge} in {@code graph.get(v)} is
     *     expected to have {@code from == v}. A vertex may appear only as an edge target (with no
     *     key of its own); it is then treated as having no outgoing edges.
     * @return the MST edges in the order they were added, or an empty list if {@code graph} is
     *     {@code null}, empty, or disconnected
     */
    public static List<Edge> primsMST(Map<Character, List<Edge>> graph) {
        if (graph == null || graph.isEmpty()) {
            return new ArrayList<>();
        }
        // Connectivity must be judged against every vertex mentioned, not just the map keys,
        // because vertices without their own adjacency list are allowed.
        Set<Character> allVertices = collectVertices(graph);
        char startNode = graph.keySet().iterator().next();
        Set<Character> visited = new HashSet<>();
        List<Edge> mst = new ArrayList<>();
        PriorityQueue<Edge> pq = new PriorityQueue<>(Comparator.comparingInt(e -> e.weight)); // Min-heap
        // Sentinel self-loop seeds the heap with the start vertex at cost 0; never added to the MST.
        pq.add(new Edge(startNode, startNode, 0));
        while (!pq.isEmpty() && visited.size() < allVertices.size()) {
            Edge edge = pq.poll();
            char current = edge.to;
            if (!visited.add(current)) {
                continue; // stale entry (lazy deletion)
            }
            if (edge.from != edge.to) { // Avoid adding initial self-loop
                mst.add(edge);
            }
            List<Edge> adjacent = graph.get(current);
            if (adjacent == null) {
                continue; // vertex has no adjacency list of its own
            }
            for (Edge neighborEdge : adjacent) {
                if (!visited.contains(neighborEdge.to)) {
                    pq.add(neighborEdge);
                }
            }
        }
        if (visited.size() != allVertices.size()) {
            return new ArrayList<>(); // Graph disconnected
        }
        return mst;
    }

    /** Returns every vertex that appears either as a key or as an edge target in {@code graph}. */
    private static Set<Character> collectVertices(Map<Character, List<Edge>> graph) {
        Set<Character> vertices = new HashSet<>(graph.keySet());
        for (List<Edge> edges : graph.values()) {
            if (edges == null) {
                continue;
            }
            for (Edge e : edges) {
                vertices.add(e.to);
            }
        }
        return vertices;
    }

    /** One directed half of an undirected edge: {@code from -> to} with the given weight. */
    public static class Edge {
        char from;
        char to;
        int weight;

        public Edge(char from, char to, int weight) {
            this.from = from;
            this.to = to;
            this.weight = weight;
        }

        @Override
        public String toString() {
            return "(" + from + ", " + to + ", " + weight + ")";
        }
    }

    public static void main(String[] args) {
        Map<Character, List<Edge>> graph = new HashMap<>();
        graph.put('A', Arrays.asList(new Edge('A', 'B', 2), new Edge('A', 'C', 3)));
        graph.put('B', Arrays.asList(new Edge('B', 'A', 2), new Edge('B', 'C', 4), new Edge('B', 'D', 7)));
        graph.put('C', Arrays.asList(new Edge('C', 'A', 3), new Edge('C', 'B', 4), new Edge('C', 'D', 5)));
        graph.put('D', Arrays.asList(new Edge('D', 'B', 7), new Edge('D', 'C', 5)));
        List<Edge> mst = primsMST(graph);
        System.out.println("MST: " + mst);
        int totalWeight = mst.stream().mapToInt(e -> e.weight).sum();
        System.out.println("Total MST weight: " + totalWeight);
    }
}
// Complexity Analysis:
// Time: O(E log V) - every adjacency entry is offered/polled at most once; each heap operation is O(log E) = O(log V).
// Space: O(V + E) - O(V) for the visited set; with lazy deletion the priority queue can hold up to O(E) entries.

C++:

#include <iostream>
#include <vector>
#include <queue>
#include <map>
#include <set>
using namespace std;
// Weighted adjacency entry: target vertex plus edge weight.
// primsMST below keeps its heap entries as pair<int, pair<char, char>> instead of this type.
struct Edge {
    char to;
    int weight;
    Edge(char target, int w) : to{target}, weight{w} {}
    // Reversed ordering so that priority_queue<Edge, vector<Edge>, greater<Edge>> acts as a min-heap.
    bool operator>(const Edge& rhs) const { return rhs.weight < weight; }
};
// Computes a minimum spanning tree with the lazy variant of Prim's algorithm.
// graph: adjacency lists {vertex: [(neighbor, weight), ...]}. A vertex may appear only as a
//        neighbor (without its own key). The map is never modified (no operator[] inserts).
// Returns the MST edges as (parent, child) pairs in the order they were added, or an empty
// vector if the graph is empty or disconnected.
vector<pair<char, char>> primsMST(const map<char, vector<pair<char, int>>>& graph) {
    if (graph.empty()) return {};

    // Every vertex mentioned anywhere, so neighbor-only vertices count toward connectivity.
    set<char> allVertices;
    for (const auto& entry : graph) {
        allVertices.insert(entry.first);
        for (const auto& nb : entry.second) allVertices.insert(nb.first);
    }

    const char startNode = graph.begin()->first;
    set<char> visited;
    vector<pair<char, char>> mst; // Store MST edges (parent, child)
    typedef pair<int, pair<char, char>> Entry; // (weight, (current, parent))
    priority_queue<Entry, vector<Entry>, greater<Entry>> pq; // Min-heap
    pq.push({0, {startNode, startNode}}); // Start with weight 0 and a sentinel self-loop
    while (!pq.empty() && visited.size() < allVertices.size()) {
        const char current = pq.top().second.first;
        const char parent = pq.top().second.second;
        pq.pop();
        if (!visited.insert(current).second) continue; // stale entry (lazy deletion)
        if (current != parent) { // Skip the initial sentinel self-loop
            mst.push_back({parent, current});
        }
        auto it = graph.find(current);
        if (it == graph.end()) continue; // vertex has no adjacency list of its own
        for (const auto& nb : it->second) {
            if (!visited.count(nb.first)) {
                pq.push({nb.second, {nb.first, current}});
            }
        }
    }
    if (visited.size() != allVertices.size()) {
        return {}; // Graph disconnected
    }
    return mst;
}
int main() {
    map<char, vector<pair<char, int>>> graph = {
        {'A', {{'B', 2}, {'C', 3}}},
        {'B', {{'A', 2}, {'C', 4}, {'D', 7}}},
        {'C', {{'A', 3}, {'B', 4}, {'D', 5}}},
        {'D', {{'B', 7}, {'C', 5}}}
    };
    vector<pair<char, char>> mst = primsMST(graph);
    cout << "MST Edges: ";
    int totalWeight = 0;
    for (const auto& edge : mst) {
        cout << "(" << edge.first << ", " << edge.second << ") ";
        // primsMST returns only endpoints, so recover the weight from the parent's adjacency
        // list. With parallel edges Prim's always takes the cheapest one, so use the minimum.
        // find() is used instead of operator[] so the graph is not modified.
        bool found = false;
        int weight = 0;
        auto it = graph.find(edge.first);
        if (it != graph.end()) {
            for (const auto& neighbor : it->second) {
                if (neighbor.first == edge.second && (!found || neighbor.second < weight)) {
                    weight = neighbor.second;
                    found = true;
                }
            }
        }
        totalWeight += weight;
    }
    cout << endl;
    cout << "Total MST weight: " << totalWeight << endl;
    return 0;
}
// Complexity Analysis:
// Time: O(E log V) - every adjacency entry is pushed/popped at most once; each heap operation is O(log E) = O(log V).
// Space: O(V + E) - O(V) for the visited set; with lazy deletion the priority queue can hold up to O(E) entries.
  • Adjacency List vs. Adjacency Matrix: Adjacency lists are generally preferred for sparse graphs (fewer edges) because they consume less memory. Adjacency matrices are simpler to implement but use O(V^2) space regardless of the number of edges.
  • Priority Queue (Min-Heap): The efficiency of Prim’s algorithm hinges on the priority queue. It allows you to quickly retrieve the edge with the smallest weight connecting the visited and unvisited vertices.
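  • Disconnected Graphs: Prim’s algorithm, in its basic form, only works for connected graphs. If the graph is disconnected, the algorithm only finds the MST of the connected component containing the starting vertex; the implementations above detect this case and return an empty result (or -1). To handle disconnected graphs, restart Prim’s from an unvisited vertex for each remaining component. The result is a minimum spanning forest rather than a single tree.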
  • Edge Representation: The representation of edges in the MST can vary. The examples above use tuples or custom Edge classes. Choose a representation that’s suitable for your specific problem.
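  • Choosing a Starting Vertex: The choice of the starting vertex doesn’t affect the final MST’s total weight, but it can affect the order in which edges are added to the MST. When several edges share the same weight, it can also change which of the equally minimal trees is produced.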
  • Lazy Deletion: In some implementations of Prim’s, you might encounter “stale” entries in the priority queue (edges that point to vertices that have already been visited). Instead of removing these entries (which can be expensive), you can simply ignore them when they are dequeued from the priority queue. This is called lazy deletion.
  • Connectivity Check: After running Prim’s, verify that the number of edges in the MST is equal to V-1, where V is the number of vertices. If not, the graph might be disconnected, or there might be an error in your implementation.
  • Dense Graphs: For very dense graphs (E ≈ V^2), using an adjacency matrix and a simple linear search to find the minimum-weight edge connecting the visited and unvisited vertices can be more efficient than using a priority queue. This results in an O(V^2) time complexity.
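  1. LeetCode 1584: Min Cost to Connect All Points - This is a direct application of Prim’s algorithm. The points are the vertices, and the Manhattan distances between pairs of points are the edge weights. Because every pair of points is connected, the graph is complete (E = V(V−1)/2), so the O(V^2) array-based variant described under “Dense Graphs” is a good fit.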

    import heapq
    def minCostConnectPoints(points):
        """
        LeetCode 1584: minimum total Manhattan-distance cost to connect all points.

        The implicit graph is complete (every pair of points is an edge), so this uses
        the O(n^2) array-based form of Prim's algorithm. It does not materialise all
        n(n-1)/2 edges plus a heap of the same size: O(n^2) time, O(n) extra space.

        Args:
            points: list of [x, y] integer coordinates.
        Returns:
            The MST weight; 0 when there are zero or one points.
        """
        n = len(points)
        if n == 0:
            return 0
        in_tree = [False] * n
        # best[v]: cheapest known edge from the growing tree to point v.
        best = [float('inf')] * n
        best[0] = 0
        total = 0
        for _ in range(n):
            # A linear scan for the closest point not yet in the tree replaces the heap.
            u = min((i for i in range(n) if not in_tree[i]), key=best.__getitem__)
            in_tree[u] = True
            total += best[u]
            ux, uy = points[u][0], points[u][1]
            for v in range(n):
                if not in_tree[v]:
                    d = abs(ux - points[v][0]) + abs(uy - points[v][1])
                    if d < best[v]:
                        best[v] = d
        return total
  2. LeetCode 1135: Connecting Cities With Minimum Cost - Another straightforward MST problem. Cities are vertices, and the given connections are the edges with their costs. You need to check if all cities are connected after finding the MST.

    import heapq
    def minimumCost(n, connections):
        """
        LeetCode 1135: minimum cost to connect cities 1..n, using lazy Prim's algorithm.

        Args:
            n: number of cities, labelled 1..n.
            connections: list of [u, v, cost] undirected edges.
        Returns:
            The MST cost, or -1 if not all cities can be connected. Returns 0 when
            there are no cities.
        """
        if n <= 0:
            return 0  # nothing to connect (the original raised KeyError here)
        graph = {i: [] for i in range(1, n + 1)}
        for u, v, cost in connections:
            graph[u].append((v, cost))
            graph[v].append((u, cost))
        mst_cost = 0
        visited = set()
        pq = [(0, 1)]  # (cost, city); start from city 1 at cost 0
        # Stop once every city is in the tree; any remaining heap entries are stale.
        while pq and len(visited) < n:
            cost, node = heapq.heappop(pq)
            if node in visited:
                continue  # stale entry (lazy deletion)
            visited.add(node)
            mst_cost += cost
            for neighbor, weight in graph[node]:
                if neighbor not in visited:
                    heapq.heappush(pq, (weight, neighbor))
        if len(visited) != n:
            return -1  # Not all cities connected
        return mst_cost

These templates provide a basic skeleton for implementing Prim’s algorithm. You’ll need to adapt them to the specific problem you’re solving, including defining the graph representation and edge weights.

C++ Template:

#include <iostream>
#include <vector>
#include <queue>
#include <map>
#include <set>
using namespace std;
// Returns the total weight of a minimum spanning tree (lazy Prim's algorithm).
// graph: {node: [(neighbor, weight), ...]}; a node may appear only as a neighbor.
//        The map is never modified (find() is used instead of operator[]).
// Returns 0 for an empty graph and -1 if the graph is disconnected.
int primsMST(const map<int, vector<pair<int, int>>>& graph) {
    if (graph.empty()) return 0;

    // Every node mentioned anywhere, so neighbor-only nodes count toward connectivity.
    set<int> allNodes;
    for (const auto& entry : graph) {
        allNodes.insert(entry.first);
        for (const auto& nb : entry.second) allNodes.insert(nb.first);
    }

    set<int> visited;
    priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pq; // Min-heap (weight, node)
    pq.push({0, graph.begin()->first});
    int mstWeight = 0;
    while (!pq.empty() && visited.size() < allNodes.size()) {
        const int weight = pq.top().first;
        const int current = pq.top().second;
        pq.pop();
        if (!visited.insert(current).second) continue; // stale entry (lazy deletion)
        mstWeight += weight;
        auto it = graph.find(current);
        if (it == graph.end()) continue; // node has no adjacency list of its own
        for (const auto& nb : it->second) {
            if (!visited.count(nb.first)) {
                pq.push({nb.second, nb.first});
            }
        }
    }
    return visited.size() == allNodes.size() ? mstWeight : -1; // -1: not fully connected
}

Python Template:

import heapq
def prims_mst(graph):  # graph: {node: [(neighbor, weight), ...]}
    """
    Return the total weight of a minimum spanning tree (lazy Prim's algorithm).

    A node that appears only as a neighbor (not as a key) is treated as having no
    adjacency list of its own.

    Returns 0 for an empty graph and -1 if the graph is not fully connected.
    """
    import itertools

    if not graph:
        return 0
    # Count every node mentioned anywhere (keys and neighbors) for the connectivity check.
    all_nodes = set(graph)
    for edges in graph.values():
        all_nodes.update(neighbor for neighbor, _ in edges)
    # (weight, tie_breaker, node): the counter keeps heapq from comparing nodes on equal weights.
    tie = itertools.count()
    pq = [(0, next(tie), next(iter(graph)))]
    visited = set()
    mst_weight = 0
    while pq and len(visited) < len(all_nodes):
        weight, _, current_node = heapq.heappop(pq)
        if current_node in visited:
            continue  # stale entry (lazy deletion)
        visited.add(current_node)
        mst_weight += weight
        for neighbor, w in graph.get(current_node, ()):
            if neighbor not in visited:
                heapq.heappush(pq, (w, next(tie), neighbor))
    if len(visited) != len(all_nodes):
        return -1  # Check if graph is fully connected
    return mst_weight

Java Template:

import java.util.*;
public class PrimsTemplate {
    /**
     * Returns the total weight of a minimum spanning tree using the lazy variant of Prim's
     * algorithm (stale heap entries are skipped when polled).
     *
     * @param graph adjacency lists {node: [(neighbor, weight), ...]}; each {@link Pair} holds the
     *     neighbor in {@code node} and the edge weight in {@code weight}. A node may appear only as
     *     a neighbor (with no key of its own).
     * @return the MST weight, 0 if {@code graph} is {@code null} or empty, or -1 if the graph is
     *     not fully connected
     */
    public static int primsMST(Map<Integer, List<Pair>> graph) {
        if (graph == null || graph.isEmpty()) {
            return 0;
        }
        // Connectivity must be judged against every node mentioned, not just the map keys.
        Set<Integer> allNodes = collectNodes(graph);
        int startNode = graph.keySet().iterator().next();
        Set<Integer> visited = new HashSet<>();
        // Heap entries reuse Pair as (node to visit, cost of the edge that reaches it).
        PriorityQueue<Pair> pq = new PriorityQueue<>(Comparator.comparingInt(p -> p.weight)); // Min-heap
        pq.add(new Pair(startNode, 0));
        int mstWeight = 0;
        while (!pq.isEmpty() && visited.size() < allNodes.size()) {
            Pair current = pq.poll();
            if (!visited.add(current.node)) {
                continue; // stale entry (lazy deletion)
            }
            mstWeight += current.weight;
            List<Pair> neighbors = graph.get(current.node);
            if (neighbors == null) {
                continue; // node has no adjacency list of its own
            }
            for (Pair neighbor : neighbors) {
                if (!visited.contains(neighbor.node)) {
                    pq.add(neighbor);
                }
            }
        }
        return visited.size() == allNodes.size() ? mstWeight : -1; // -1: not fully connected
    }

    /** Returns every node that appears either as a key or as a neighbor in {@code graph}. */
    private static Set<Integer> collectNodes(Map<Integer, List<Pair>> graph) {
        Set<Integer> nodes = new HashSet<>(graph.keySet());
        for (List<Pair> neighbors : graph.values()) {
            if (neighbors == null) {
                continue;
            }
            for (Pair p : neighbors) {
                nodes.add(p.node);
            }
        }
        return nodes;
    }

    /** An adjacency entry (neighbor node, edge weight); also used as a heap entry. */
    static class Pair {
        int node;
        int weight;

        public Pair(int node, int weight) {
            this.node = node;
            this.weight = weight;
        }
    }
}