Skip to content

Segment Tree

Generated on 2025-07-10 02:04:56 Algorithm Cheatsheet for Technical Interviews


A Segment Tree is a tree data structure used for efficiently performing range queries and updates on an array. It allows you to answer queries like “sum of elements in range [L, R]” or “minimum element in range [L, R]” in logarithmic time. It excels when you have frequent range queries and updates.

When to Use:

  • Frequent range queries (sum, min, max, etc.)
  • Frequent range updates (addition, multiplication, assignment)
  • Static array size (or occasional resizing)
  • Offline queries (queries known in advance) or online queries (queries arrive dynamically)
| Operation | Time Complexity | Space Complexity |
|---|---|---|
| Build | O(N) | O(N) |
| Query (Range) | O(log N) | O(1) |
| Update (Point) | O(log N) | O(1) |
| Update (Range, with lazy propagation) | O(log N) | O(1) |
  • N: Size of the input array

Here’s a basic implementation of a Segment Tree for Range Sum Queries and Point Updates in Python, Java, and C++:

Python

class SegmentTree:
    """Segment tree supporting inclusive range-sum queries and point updates.

    Nodes live in a flat list: the root is index 0 and the children of
    ``node`` are ``2 * node + 1`` and ``2 * node + 2``. All indices and
    ranges are zero-based.
    """

    def __init__(self, arr):
        """Build the tree over ``arr``.

        ``arr`` is stored by reference (not copied), so ``update`` also
        writes into the caller's list.
        """
        self.arr = arr
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # Allocate space for the tree
        # An empty input has no valid [0, n - 1] range: building would never
        # reach the ``start == end`` base case, so skip it.
        if self.n > 0:
            self.build(0, 0, self.n - 1)

    def build(self, node, start, end):
        """Recursively fill ``tree[node]`` with the sum of ``arr[start..end]``."""
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        left_child, right_child = 2 * node + 1, 2 * node + 2
        self.build(left_child, start, mid)
        self.build(right_child, mid + 1, end)
        self.tree[node] = self.tree[left_child] + self.tree[right_child]

    def update(self, index, value):
        """Set ``arr[index]`` to ``value`` and refresh every sum covering it.

        Raises:
            IndexError: if ``index`` is outside ``[0, n - 1]``. The check runs
                before anything is modified, so a bad index leaves the tree
                and the array untouched.
        """
        if not 0 <= index < self.n:
            raise IndexError("segment tree index out of range")
        self._update(0, 0, self.n - 1, index, value)
        self.arr[index] = value  # Keep the underlying array in sync

    def _update(self, node, start, end, index, value):
        """Write ``value`` into the leaf for ``index``, then re-sum on the way up."""
        if start == end:
            self.tree[node] = value
            return
        mid = (start + end) // 2
        left_child, right_child = 2 * node + 1, 2 * node + 2
        if index <= mid:
            self._update(left_child, start, mid, index, value)
        else:
            self._update(right_child, mid + 1, end, index, value)
        self.tree[node] = self.tree[left_child] + self.tree[right_child]

    def query(self, left, right):
        """Return the sum of ``arr[left..right]`` (both ends inclusive).

        Any part of the requested range that lies outside the array
        contributes 0; an empty tree always yields 0.
        """
        if self.n == 0:
            return 0
        return self._query(0, 0, self.n - 1, left, right)

    def _query(self, node, start, end, left, right):
        """Sum the overlap between ``[start, end]`` (this node) and ``[left, right]``."""
        if right < start or end < left:
            return 0  # No overlap: 0 is the identity for addition
        if left <= start and end <= right:
            return self.tree[node]  # Node segment lies fully inside the query
        mid = (start + end) // 2
        left_sum = self._query(2 * node + 1, start, mid, left, right)
        right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
        return left_sum + right_sum
# Example usage: query a range, overwrite one element, query the same range again.
values = [1, 3, 5, 7, 9, 11]
st = SegmentTree(values)
print(f"Sum of range [1, 3]: {st.query(1, 3)}")  # 3 + 5 + 7 = 15
st.update(1, 10)  # index 1 now holds 10 instead of 3
print(f"Sum of range [1, 3]: {st.query(1, 3)}")  # 10 + 5 + 7 = 22

Java

/**
 * Segment tree over an int array supporting inclusive range-sum queries and
 * point updates.
 *
 * <p>Nodes live in a flat array: node 0 is the root and node {@code i} has its
 * children at {@code 2 * i + 1} and {@code 2 * i + 2}. All indices are
 * zero-based.
 */
class SegmentTree {
    private int[] arr;   // the caller's array (stored by reference, not copied)
    private int[] tree;  // tree[node] = sum of the segment covered by node
    private int n;       // number of elements in arr

    /**
     * Builds the tree over {@code arr}. The array is not copied, so
     * {@link #update(int, int)} also writes into the caller's array.
     *
     * @param arr input values; may be empty
     */
    public SegmentTree(int[] arr) {
        this.arr = arr;
        this.n = arr.length;
        this.tree = new int[4 * n]; // Allocate space for the tree
        // An empty array has no valid [0, n - 1] range; building it would
        // index into a zero-length tree.
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }

    /** Recursively fills {@code tree[node]} with the sum of {@code arr[start..end]}. */
    private void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = start + (end - start) / 2;
        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        build(leftChild, start, mid);
        build(rightChild, mid + 1, end);
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    /**
     * Sets {@code arr[index]} to {@code value} and refreshes every sum covering it.
     *
     * @param index zero-based position to overwrite
     * @param value new value
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside
     *         {@code [0, n - 1]}; the check runs before anything is modified
     */
    public void update(int index, int value) {
        if (index < 0 || index >= n) {
            throw new ArrayIndexOutOfBoundsException(
                    "Index " + index + " out of bounds for length " + n);
        }
        updateHelper(0, 0, n - 1, index, value);
        arr[index] = value; // Keep the underlying array in sync
    }

    /** Writes {@code value} into the leaf for {@code index}, then re-sums on the way up. */
    private void updateHelper(int node, int start, int end, int index, int value) {
        if (start == end) {
            tree[node] = value;
            return;
        }
        int mid = start + (end - start) / 2;
        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        if (index <= mid) {
            updateHelper(leftChild, start, mid, index, value);
        } else {
            updateHelper(rightChild, mid + 1, end, index, value);
        }
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    /**
     * Returns the sum of {@code arr[left..right]}, both ends inclusive.
     * Any part of the range outside the array contributes 0; an empty tree
     * always yields 0.
     */
    public int query(int left, int right) {
        if (n == 0) {
            return 0;
        }
        return queryHelper(0, 0, n - 1, left, right);
    }

    /** Sums the overlap between this node's segment {@code [start, end]} and {@code [left, right]}. */
    private int queryHelper(int node, int start, int end, int left, int right) {
        if (right < start || end < left) {
            return 0; // No overlap: 0 is the identity for addition
        }
        if (left <= start && end <= right) {
            return tree[node]; // Node segment lies fully inside the query
        }
        int mid = start + (end - start) / 2;
        int leftSum = queryHelper(2 * node + 1, start, mid, left, right);
        int rightSum = queryHelper(2 * node + 2, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

    public static void main(String[] args) {
        int[] arr = {1, 3, 5, 7, 9, 11};
        SegmentTree st = new SegmentTree(arr);
        System.out.println("Sum of range [1, 3]: " + st.query(1, 3)); // Output: 15 (3 + 5 + 7)
        st.update(1, 10); // Update index 1 to value 10
        System.out.println("Sum of range [1, 3]: " + st.query(1, 3)); // Output: 22 (10 + 5 + 7)
    }
}

C++

#include <iostream>
#include <vector>
using namespace std;
/// Segment tree over a private copy of an int array, answering inclusive
/// range-sum queries and applying point updates.
///
/// Nodes are kept in a flat vector: node 0 is the root and node i has its
/// children at 2*i + 1 and 2*i + 2. All positions are zero-based.
class SegmentTree {
private:
    std::vector<int> arr;   // copy of the input values, kept in sync by update()
    std::vector<int> tree;  // tree[node] = sum of the segment that node covers
    int n;                  // number of elements in arr

    /// Fills tree[node] with the sum of arr[start..end], building children first.
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        const int mid = start + (end - start) / 2;
        const int leftChild = 2 * node + 1;
        const int rightChild = 2 * node + 2;
        build(leftChild, start, mid);
        build(rightChild, mid + 1, end);
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    /// Writes value into the leaf for index, then re-sums each node on the way up.
    void updateHelper(int node, int start, int end, int index, int value) {
        if (start == end) {
            tree[node] = value;
            return;
        }
        const int mid = start + (end - start) / 2;
        const int leftChild = 2 * node + 1;
        const int rightChild = 2 * node + 2;
        if (index <= mid) {
            updateHelper(leftChild, start, mid, index, value);
        } else {
            updateHelper(rightChild, mid + 1, end, index, value);
        }
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    /// Sums the overlap between this node's segment [start, end] and [left, right].
    int queryHelper(int node, int start, int end, int left, int right) const {
        if (right < start || end < left) {
            return 0;  // No overlap: 0 is the identity for addition
        }
        if (left <= start && end <= right) {
            return tree[node];  // Node segment lies fully inside the query
        }
        const int mid = start + (end - start) / 2;
        const int leftSum = queryHelper(2 * node + 1, start, mid, left, right);
        const int rightSum = queryHelper(2 * node + 2, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

public:
    /// Copies arr and builds the tree over it. Taking a const reference lets
    /// callers pass const vectors and temporaries as well as plain lvalues.
    SegmentTree(const std::vector<int>& arr) : arr(arr), n(static_cast<int>(arr.size())) {
        tree.resize(4 * arr.size());  // Allocate space for the tree
        // An empty input has no valid [0, n - 1] range; building it would
        // write past the end of the (empty) vectors.
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }

    /// Sets element index to value and refreshes every sum covering it.
    /// Throws std::out_of_range (from vector::at) when index is not in
    /// [0, n - 1]; the check happens before the tree is modified.
    void update(int index, int value) {
        // A negative index converts to a huge unsigned value, so at() rejects it too.
        arr.at(static_cast<std::vector<int>::size_type>(index)) = value;
        updateHelper(0, 0, n - 1, index, value);
    }

    /// Returns the sum of elements left..right, both ends inclusive. Any part
    /// of the range outside the array contributes 0; an empty tree yields 0.
    int query(int left, int right) const {
        if (n == 0) {
            return 0;
        }
        return queryHelper(0, 0, n - 1, left, right);
    }
};
// Example usage: query a range, overwrite one element, query the same range again.
int main() {
    std::vector<int> values = {1, 3, 5, 7, 9, 11};
    SegmentTree st(values);
    std::cout << "Sum of range [1, 3]: " << st.query(1, 3) << std::endl;  // 3 + 5 + 7 = 15
    st.update(1, 10);  // index 1 now holds 10 instead of 3
    std::cout << "Sum of range [1, 3]: " << st.query(1, 3) << std::endl;  // 10 + 5 + 7 = 22
    return 0;
}

Explanation:

  • build(node, start, end): Recursively builds the segment tree. The base case is when start == end, where the leaf node is assigned the value of the array element. Otherwise, the node’s value is the sum of its children’s values.
  • update(index, value): Updates the value at a given index. This involves updating the leaf node corresponding to the index and then propagating the changes up to the root.
  • query(left, right): Performs a range sum query. It recursively traverses the tree, checking for three cases:
    • Out of range: The query range is completely outside the node’s range. Return 0 (or the appropriate identity value for the operation).
    • Completely within range: The node’s range is completely within the query range. Return the node’s value.
    • Partial overlap: Recursively query the left and right children and combine their results.
  • Range Sum Query (RSQ): The basic example above. The tree stores the sum of elements in each segment.
  • Range Minimum Query (RMQ): The tree stores the minimum element in each segment. The build, update, and query functions need to be modified to use min instead of +.
  • Range Maximum Query (RMaxQ): The tree stores the maximum element in each segment. The build, update, and query functions need to be modified to use max instead of +.
  • Lazy Propagation: Used for range updates (e.g., add a value to all elements in a range). It avoids updating all nodes in the range immediately, instead storing the update information in a “lazy” node and propagating it down only when necessary. This improves the time complexity of range updates from O(N) to O(log N).

Lazy Propagation Example (Conceptual):

  1. Update Range: When updating a range, if a node’s range is completely within the update range, mark the node as “lazy” and store the update value.
  2. Query Range or Update Point: Before accessing a node, check if it’s marked as “lazy.” If it is, apply the lazy update to the node and propagate the update to its children. Then, clear the lazy flag.
  • Zero-Based Indexing: The code assumes zero-based indexing. Adjust accordingly if your problem uses one-based indexing.
  • Tree Size: The size of the tree array should be at least 4 * N. This ensures enough space for all nodes in the segment tree.
  • Identity Element: When handling out-of-range queries, return the identity element for the operation. For sum, it’s 0; for min, it’s infinity; for max, it’s negative infinity.
  • Leaf Nodes: Leaf nodes represent single elements of the original array.
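  • Binary Tree Shape: A segment tree is a full binary tree (every internal node has exactly two children). It is not necessarily a complete binary tree, so with the 2 * node + 1 / 2 * node + 2 indexing some array slots can stay unused — which is why the tree array is allocated with 4 * N entries.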
  • Recursive vs. Iterative: While the implementation is often recursive for clarity, an iterative approach can sometimes be more efficient (avoiding function call overhead).
  • Lazy Propagation Gotchas: Remember to propagate lazy updates before performing any queries or updates on the node. Also, be careful about the order of applying multiple lazy updates. For example, if you have both addition and multiplication updates, you need to apply them in the correct order (usually multiplication first).
  • Array Modification: Make sure to update the original array when you update the segment tree. This ensures that the segment tree and the underlying array stay synchronized.
  1. Range Sum Query - Mutable (LeetCode 307): Implement a segment tree to support range sum queries and point updates.
  2. Range Minimum Query (SPOJ RMQ): Implement a segment tree to support range minimum queries. (This problem can also be solved with sparse tables, but segment trees are a good exercise.)
  3. Range Sum Query 2D - Mutable (LeetCode 308): Implement a 2D segment tree (or a binary indexed tree) to support range sum queries and point updates in a 2D matrix. This is a more advanced problem.

These templates provide a starting point for implementing segment trees in different languages. They focus on the basic structure and can be easily adapted for different query types (min, max, etc.) and updates (range addition, etc.).

C++

#include <iostream>
#include <vector>
using namespace std;
/// Range-sum segment tree with point updates, built over a private copy of
/// the input array.
///
/// The tree is stored in a flat vector: the root is node 0 and the children
/// of node i are 2*i + 1 and 2*i + 2. Indices and ranges are zero-based and
/// inclusive.
class SegmentTree {
private:
    std::vector<int> tree;  // The segment tree array: tree[node] = sum of node's segment
    std::vector<int> arr;   // Copy of the input array, kept in sync by update()
    int n;                  // Size of the input array

    /// Builds the subtree rooted at node so that tree[node] = sum of arr[start..end].
    void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];  // Base case: leaf node
            return;
        }
        const int mid = start + (end - start) / 2;
        const int lhs = 2 * node + 1;
        const int rhs = 2 * node + 2;
        build(lhs, start, mid);      // Build left subtree
        build(rhs, mid + 1, end);    // Build right subtree
        tree[node] = tree[lhs] + tree[rhs];  // Combine results (sum)
    }

    /// Descends to the leaf for idx, stores val in both arr and the leaf, and
    /// recomputes each ancestor's sum while unwinding. The caller must have
    /// verified that idx is within [start, end].
    void update(int node, int start, int end, int idx, int val) {
        if (start == end) {
            arr[idx] = val;    // Update the stored array
            tree[node] = val;  // Update the leaf node
            return;
        }
        const int mid = start + (end - start) / 2;
        const int lhs = 2 * node + 1;
        const int rhs = 2 * node + 2;
        if (idx <= mid) {
            update(lhs, start, mid, idx, val);      // Update left subtree
        } else {
            update(rhs, mid + 1, end, idx, val);    // Update right subtree
        }
        tree[node] = tree[lhs] + tree[rhs];  // Update current node
    }

    /// Returns the sum over the overlap of this node's segment [start, end]
    /// and the requested range [l, r].
    int query(int node, int start, int end, int l, int r) const {
        if (r < start || end < l) {
            return 0;  // Out of range (identity element for sum)
        }
        if (l <= start && end <= r) {
            return tree[node];  // Completely within range
        }
        const int mid = start + (end - start) / 2;
        const int p1 = query(2 * node + 1, start, mid, l, r);      // Query left subtree
        const int p2 = query(2 * node + 2, mid + 1, end, l, r);    // Query right subtree
        return p1 + p2;  // Combine results
    }

public:
    /// Constructor: copies the input array and builds the tree over it.
    /// Accepting a const reference allows const vectors and temporaries.
    SegmentTree(const std::vector<int>& a) : arr(a), n(static_cast<int>(a.size())) {
        tree.resize(4 * a.size());  // Allocate memory for the tree
        // Nothing to build for an empty array: [0, n - 1] would be an invalid
        // range and building would write outside the vectors.
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }

    /// Public update: sets element idx to val.
    /// Throws std::out_of_range (from vector::at) when idx is not in
    /// [0, n - 1], before any tree node is touched.
    void update(int idx, int val) {
        // Bounds check only; a negative idx becomes a huge unsigned value and is rejected.
        (void)arr.at(static_cast<std::vector<int>::size_type>(idx));
        update(0, 0, n - 1, idx, val);
    }

    /// Public query: sum of elements l..r inclusive. Portions of the range
    /// outside the array contribute 0; an empty tree always returns 0.
    int query(int l, int r) const {
        if (n == 0) {
            return 0;
        }
        return query(0, 0, n - 1, l, r);
    }
};
// Example usage: query a range, overwrite one element, query the same range again.
int main() {
    std::vector<int> values = {1, 3, 5, 7, 9, 11};
    SegmentTree st(values);
    std::cout << "Sum of range [1, 3]: " << st.query(1, 3) << std::endl;  // 3 + 5 + 7 = 15
    st.update(1, 10);  // index 1 now holds 10 instead of 3
    std::cout << "Sum of range [1, 3]: " << st.query(1, 3) << std::endl;  // 10 + 5 + 7 = 22
    return 0;
}

Python

class SegmentTree:
    """Range-sum segment tree with point updates.

    The tree is a flat list in which node 0 is the root and the children of
    ``node`` sit at ``2 * node + 1`` and ``2 * node + 2``. Indices and ranges
    are zero-based and inclusive.
    """

    def __init__(self, arr):
        """Build the tree over ``arr``.

        The list is kept by reference rather than copied, so ``update`` also
        modifies the caller's list.
        """
        self.arr = arr
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # Allocate space for the tree
        # With no elements there is no valid [0, n - 1] range and the
        # recursion in build() would never hit its base case.
        if self.n > 0:
            self.build(0, 0, self.n - 1)

    def build(self, node, start, end):
        """Set ``tree[node]`` to the sum of ``arr[start..end]``, children first."""
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        lhs, rhs = 2 * node + 1, 2 * node + 2
        self.build(lhs, start, mid)      # Left child
        self.build(rhs, mid + 1, end)    # Right child
        self.tree[node] = self.tree[lhs] + self.tree[rhs]

    def update(self, index, value):
        """Replace ``arr[index]`` with ``value`` and refresh the affected sums.

        Raises:
            IndexError: if ``index`` is not in ``[0, n - 1]``. Validation
                happens first, so nothing is modified on a bad index.
        """
        if not 0 <= index < self.n:
            raise IndexError("segment tree index out of range")
        self._update(0, 0, self.n - 1, index, value)
        self.arr[index] = value

    def _update(self, node, start, end, index, value):
        """Store ``value`` in the leaf for ``index`` and re-sum each ancestor."""
        if start == end:
            self.tree[node] = value
            return
        mid = (start + end) // 2
        lhs, rhs = 2 * node + 1, 2 * node + 2
        if index <= mid:
            self._update(lhs, start, mid, index, value)
        else:
            self._update(rhs, mid + 1, end, index, value)
        self.tree[node] = self.tree[lhs] + self.tree[rhs]

    def query(self, left, right):
        """Return the sum of ``arr[left..right]``, both ends inclusive.

        Portions of the range outside the array add 0; an empty tree
        always returns 0.
        """
        if self.n == 0:
            return 0
        return self._query(0, 0, self.n - 1, left, right)

    def _query(self, node, start, end, left, right):
        """Sum the part of ``[left, right]`` covered by this node's ``[start, end]``."""
        if right < start or end < left:
            return 0  # Disjoint: contribute the additive identity
        if left <= start and end <= right:
            return self.tree[node]  # Fully covered: use the stored sum
        mid = (start + end) // 2
        left_sum = self._query(2 * node + 1, start, mid, left, right)
        right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
        return left_sum + right_sum
# Example usage: runs only when the file is executed as a script.
if __name__ == "__main__":
    values = [1, 3, 5, 7, 9, 11]
    st = SegmentTree(values)
    print(f"Sum of range [1, 3]: {st.query(1, 3)}")  # 3 + 5 + 7 = 15
    st.update(1, 10)  # index 1 now holds 10 instead of 3
    print(f"Sum of range [1, 3]: {st.query(1, 3)}")  # 10 + 5 + 7 = 22

Java

/**
 * Range-sum segment tree with point updates over an int array.
 *
 * <p>The tree is a flat array in which node 0 is the root and the children of
 * node {@code i} are {@code 2 * i + 1} and {@code 2 * i + 2}. Indices and
 * ranges are zero-based and inclusive.
 */
class SegmentTree {
    private int[] tree;  // tree[node] = sum of the segment covered by node
    private int[] arr;   // the caller's array (stored by reference, not copied)
    private int n;       // number of elements in arr

    /**
     * Builds the tree over {@code arr}. The array is not copied, so
     * {@link #update(int, int)} also writes into the caller's array.
     *
     * @param arr input values; may be empty
     */
    public SegmentTree(int[] arr) {
        this.arr = arr;
        this.n = arr.length;
        this.tree = new int[4 * n]; // Allocate space for the tree
        // With no elements there is no valid [0, n - 1] range; building
        // would index into a zero-length tree.
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }

    /** Sets {@code tree[node]} to the sum of {@code arr[start..end]}, children first. */
    private void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = start + (end - start) / 2;
        int lhs = 2 * node + 1;
        int rhs = 2 * node + 2;
        build(lhs, start, mid);     // Left child
        build(rhs, mid + 1, end);   // Right child
        tree[node] = tree[lhs] + tree[rhs];
    }

    /**
     * Replaces {@code arr[index]} with {@code value} and refreshes the affected sums.
     *
     * @param index zero-based position to overwrite
     * @param value new value
     * @throws ArrayIndexOutOfBoundsException if {@code index} is not in
     *         {@code [0, n - 1]}; validation happens before any modification
     */
    public void update(int index, int value) {
        if (index < 0 || index >= n) {
            throw new ArrayIndexOutOfBoundsException(
                    "Index " + index + " out of bounds for length " + n);
        }
        updateHelper(0, 0, n - 1, index, value);
        arr[index] = value;
    }

    /** Stores {@code value} in the leaf for {@code index} and re-sums each ancestor. */
    private void updateHelper(int node, int start, int end, int index, int value) {
        if (start == end) {
            tree[node] = value;
            return;
        }
        int mid = start + (end - start) / 2;
        int lhs = 2 * node + 1;
        int rhs = 2 * node + 2;
        if (index <= mid) {
            updateHelper(lhs, start, mid, index, value);
        } else {
            updateHelper(rhs, mid + 1, end, index, value);
        }
        tree[node] = tree[lhs] + tree[rhs];
    }

    /**
     * Returns the sum of {@code arr[left..right]}, both ends inclusive.
     * Portions of the range outside the array add 0; an empty tree always
     * returns 0.
     */
    public int query(int left, int right) {
        if (n == 0) {
            return 0;
        }
        return queryHelper(0, 0, n - 1, left, right);
    }

    /** Sums the part of {@code [left, right]} covered by this node's {@code [start, end]}. */
    private int queryHelper(int node, int start, int end, int left, int right) {
        if (right < start || end < left) {
            return 0; // Disjoint: contribute the additive identity
        }
        if (left <= start && end <= right) {
            return tree[node]; // Fully covered: use the stored sum
        }
        int mid = start + (end - start) / 2;
        int leftSum = queryHelper(2 * node + 1, start, mid, left, right);
        int rightSum = queryHelper(2 * node + 2, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

    public static void main(String[] args) {
        int[] arr = {1, 3, 5, 7, 9, 11};
        SegmentTree st = new SegmentTree(arr);
        System.out.println("Sum of range [1, 3]: " + st.query(1, 3)); // 3 + 5 + 7 = 15
        st.update(1, 10); // index 1 now holds 10 instead of 3
        System.out.println("Sum of range [1, 3]: " + st.query(1, 3)); // 10 + 5 + 7 = 22
    }
}

This cheatsheet provides a comprehensive overview of Segment Trees, covering implementation details, common patterns, and practical tips. It should be a helpful reference for developers working with this powerful data structure. Remember to adapt the code and concepts to the specific requirements of your problem. Good luck!