Skip to content

Commit a42663f

Browse files
committed
feat: add 2D segment tree implementation
1 parent 13aaad2 commit a42663f

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
/**
4+
* 2D Segment Tree (Tree of Trees) implementation.
5+
* This data structure supports point updates and submatrix sum queries
6+
* in a 2D grid. It achieves this by nesting 1D Segment Trees within a 1D Segment Tree.
7+
*
8+
* Time Complexity:
9+
* - Build/Initialization: O(N * M)
10+
* - Point Update: O(log N * log M)
11+
* - Submatrix Query: O(log N * log M)
12+
*
13+
* @see <a href="https://cp-algorithms.com/data_structures/segment_tree.html#2d-segment-tree">2D Segment Tree</a>
14+
*/
15+
public class SegmentTree2D {
16+
17+
/**
18+
* Represents a 1D Segment Tree.
19+
* This is equivalent to your 'Sagara' struct. It manages the columns (X-axis).
20+
*/
21+
public static class SegmentTree1D {
22+
private int n;
23+
private final int[] tree;
24+
25+
/**
26+
* Initializes the 1D Segment Tree with the nearest power of 2.
27+
*
28+
* @param size The expected number of elements (columns).
29+
*/
30+
public SegmentTree1D(int size) {
31+
n = 1;
32+
while (n < size) {
33+
n *= 2;
34+
}
35+
tree = new int[n * 2];
36+
}
37+
38+
/**
39+
* Recursively updates a point in the 1D tree.
40+
*/
41+
private void update(int index, int val, int node, int lx, int rx) {
42+
if (rx - lx == 1) {
43+
tree[node] = val;
44+
return;
45+
}
46+
47+
int mid = lx + (rx - lx) / 2;
48+
int leftChild = node * 2 + 1;
49+
int rightChild = node * 2 + 2;
50+
51+
if (index < mid) {
52+
update(index, val, leftChild, lx, mid);
53+
} else {
54+
update(index, val, rightChild, mid, rx);
55+
}
56+
57+
tree[node] = tree[leftChild] + tree[rightChild];
58+
}
59+
60+
/**
61+
* Public wrapper to update a specific index.
62+
*
63+
* @param index The column index to update.
64+
* @param val The new value.
65+
*/
66+
public void update(int index, int val) {
67+
update(index, val, 0, 0, n);
68+
}
69+
70+
/**
71+
* Retrieves the exact value at a specific leaf node.
72+
*
73+
* @param index The column index.
74+
* @return The value at the given index.
75+
*/
76+
public int get(int index) {
77+
return query(index, index + 1, 0, 0, n);
78+
}
79+
80+
/**
81+
* Recursively queries the sum in a 1D range.
82+
*/
83+
private int query(int l, int r, int node, int lx, int rx) {
84+
if (lx >= r || rx <= l) {
85+
return 0; // Out of bounds
86+
}
87+
if (lx >= l && rx <= r) {
88+
return tree[node]; // Fully inside
89+
}
90+
91+
int mid = lx + (rx - lx) / 2;
92+
int leftSum = query(l, r, node * 2 + 1, lx, mid);
93+
int rightSum = query(l, r, node * 2 + 2, mid, rx);
94+
95+
return leftSum + rightSum;
96+
}
97+
98+
/**
99+
* Public wrapper to query the sum in the range [l, r).
100+
*
101+
* @param l Left boundary (inclusive).
102+
* @param r Right boundary (exclusive).
103+
* @return The sum of the range.
104+
*/
105+
public int query(int l, int r) {
106+
return query(l, r, 0, 0, n);
107+
}
108+
}
109+
110+
// --- Start of 2D Segment Tree (equivalent to 'Sagara2D') ---
111+
112+
private int n;
113+
private final SegmentTree1D[] tree;
114+
115+
/**
116+
* Initializes the 2D Segment Tree.
117+
*
118+
* @param rows The number of rows in the matrix.
119+
* @param cols The number of columns in the matrix.
120+
*/
121+
public SegmentTree2D(int rows, int cols) {
122+
n = 1;
123+
while (n < rows) {
124+
n *= 2;
125+
}
126+
tree = new SegmentTree1D[n * 2];
127+
for (int i = 0; i < n * 2; i++) {
128+
// Every node in the outer tree is a full 1D tree!
129+
tree[i] = new SegmentTree1D(cols);
130+
}
131+
}
132+
133+
/**
134+
* Recursively updates a point in the 2D grid.
135+
*/
136+
private void update(int row, int col, int val, int node, int lx, int rx) {
137+
if (rx - lx == 1) {
138+
tree[node].update(col, val);
139+
return;
140+
}
141+
142+
int mid = lx + (rx - lx) / 2;
143+
int leftChild = node * 2 + 1;
144+
int rightChild = node * 2 + 2;
145+
146+
if (row < mid) {
147+
update(row, col, val, leftChild, lx, mid);
148+
} else {
149+
update(row, col, val, rightChild, mid, rx);
150+
}
151+
152+
// The value of the current node's column is the sum of its children's column values
153+
int leftVal = tree[leftChild].get(col);
154+
int rightVal = tree[rightChild].get(col);
155+
tree[node].update(col, leftVal + rightVal);
156+
}
157+
158+
/**
159+
* Public wrapper to update a specific point (row, col).
160+
*
161+
* @param row The row index.
162+
* @param col The column index.
163+
* @param val The new value.
164+
*/
165+
public void update(int row, int col, int val) {
166+
update(row, col, val, 0, 0, n);
167+
}
168+
169+
/**
170+
* Recursively queries the sum in a submatrix.
171+
*/
172+
private int query(int top, int bottom, int left, int right, int node, int lx, int rx) {
173+
if (lx >= bottom || rx <= top) {
174+
return 0; // Out of bounds
175+
}
176+
if (lx >= top && rx <= bottom) {
177+
// Fully inside the row range, so delegate the column query to the 1D tree
178+
return tree[node].query(left, right);
179+
}
180+
181+
int mid = lx + (rx - lx) / 2;
182+
int leftSum = query(top, bottom, left, right, node * 2 + 1, lx, mid);
183+
int rightSum = query(top, bottom, left, right, node * 2 + 2, mid, rx);
184+
185+
return leftSum + rightSum;
186+
}
187+
188+
/**
189+
* Public wrapper to query the sum of a submatrix.
190+
* Note: boundaries are [top, bottom) and [left, right).
191+
*
192+
* @param top Top row index (inclusive).
193+
* @param bottom Bottom row index (exclusive).
194+
* @param left Left column index (inclusive).
195+
* @param right Right column index (exclusive).
196+
* @return The sum of the submatrix.
197+
*/
198+
public int query(int top, int bottom, int left, int right) {
199+
return query(top, bottom, left, right, 0, 0, n);
200+
}
201+
}

0 commit comments

Comments
 (0)