Skip to content

Commit 741b2d1

Browse files
Feat/segment tree 2d (#7363)
* feat: add 2D segment tree implementation * test: add comprehensive unit tests for 2D segment tree * style: format code using clang-format
1 parent 13aaad2 commit 741b2d1

File tree

2 files changed

+272
-0
lines changed

2 files changed

+272
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
/**
4+
* 2D Segment Tree (Tree of Trees) implementation.
5+
* This data structure supports point updates and submatrix sum queries
6+
* in a 2D grid. It achieves this by nesting 1D Segment Trees within a 1D Segment Tree.
7+
*
8+
* Time Complexity:
9+
* - Build/Initialization: O(N * M)
10+
* - Point Update: O(log N * log M)
11+
* - Submatrix Query: O(log N * log M)
12+
*
13+
* @see <a href="https://cp-algorithms.com/data_structures/segment_tree.html#2d-segment-tree">2D Segment Tree</a>
14+
*/
15+
public class SegmentTree2D {
16+
17+
/**
18+
* Represents a 1D Segment Tree.
19+
* This is equivalent to your 'Sagara' struct. It manages the columns (X-axis).
20+
*/
21+
public static class SegmentTree1D {
22+
private int n;
23+
private final int[] tree;
24+
25+
/**
26+
* Initializes the 1D Segment Tree with the nearest power of 2.
27+
*
28+
* @param size The expected number of elements (columns).
29+
*/
30+
public SegmentTree1D(int size) {
31+
n = 1;
32+
while (n < size) {
33+
n *= 2;
34+
}
35+
tree = new int[n * 2];
36+
}
37+
38+
/**
39+
* Recursively updates a point in the 1D tree.
40+
*/
41+
private void update(int index, int val, int node, int lx, int rx) {
42+
if (rx - lx == 1) {
43+
tree[node] = val;
44+
return;
45+
}
46+
47+
int mid = lx + (rx - lx) / 2;
48+
int leftChild = node * 2 + 1;
49+
int rightChild = node * 2 + 2;
50+
51+
if (index < mid) {
52+
update(index, val, leftChild, lx, mid);
53+
} else {
54+
update(index, val, rightChild, mid, rx);
55+
}
56+
57+
tree[node] = tree[leftChild] + tree[rightChild];
58+
}
59+
60+
/**
61+
* Public wrapper to update a specific index.
62+
*
63+
* @param index The column index to update.
64+
* @param val The new value.
65+
*/
66+
public void update(int index, int val) {
67+
update(index, val, 0, 0, n);
68+
}
69+
70+
/**
71+
* Retrieves the exact value at a specific leaf node.
72+
*
73+
* @param index The column index.
74+
* @return The value at the given index.
75+
*/
76+
public int get(int index) {
77+
return query(index, index + 1, 0, 0, n);
78+
}
79+
80+
/**
81+
* Recursively queries the sum in a 1D range.
82+
*/
83+
private int query(int l, int r, int node, int lx, int rx) {
84+
if (lx >= r || rx <= l) {
85+
return 0; // Out of bounds
86+
}
87+
if (lx >= l && rx <= r) {
88+
return tree[node]; // Fully inside
89+
}
90+
91+
int mid = lx + (rx - lx) / 2;
92+
int leftSum = query(l, r, node * 2 + 1, lx, mid);
93+
int rightSum = query(l, r, node * 2 + 2, mid, rx);
94+
95+
return leftSum + rightSum;
96+
}
97+
98+
/**
99+
* Public wrapper to query the sum in the range [l, r).
100+
*
101+
* @param l Left boundary (inclusive).
102+
* @param r Right boundary (exclusive).
103+
* @return The sum of the range.
104+
*/
105+
public int query(int l, int r) {
106+
return query(l, r, 0, 0, n);
107+
}
108+
}
109+
110+
// --- Start of 2D Segment Tree (equivalent to 'Sagara2D') ---
111+
112+
private int n;
113+
private final SegmentTree1D[] tree;
114+
115+
/**
116+
* Initializes the 2D Segment Tree.
117+
*
118+
* @param rows The number of rows in the matrix.
119+
* @param cols The number of columns in the matrix.
120+
*/
121+
public SegmentTree2D(int rows, int cols) {
122+
n = 1;
123+
while (n < rows) {
124+
n *= 2;
125+
}
126+
tree = new SegmentTree1D[n * 2];
127+
for (int i = 0; i < n * 2; i++) {
128+
// Every node in the outer tree is a full 1D tree!
129+
tree[i] = new SegmentTree1D(cols);
130+
}
131+
}
132+
133+
/**
134+
* Recursively updates a point in the 2D grid.
135+
*/
136+
private void update(int row, int col, int val, int node, int lx, int rx) {
137+
if (rx - lx == 1) {
138+
tree[node].update(col, val);
139+
return;
140+
}
141+
142+
int mid = lx + (rx - lx) / 2;
143+
int leftChild = node * 2 + 1;
144+
int rightChild = node * 2 + 2;
145+
146+
if (row < mid) {
147+
update(row, col, val, leftChild, lx, mid);
148+
} else {
149+
update(row, col, val, rightChild, mid, rx);
150+
}
151+
152+
// The value of the current node's column is the sum of its children's column values
153+
int leftVal = tree[leftChild].get(col);
154+
int rightVal = tree[rightChild].get(col);
155+
tree[node].update(col, leftVal + rightVal);
156+
}
157+
158+
/**
159+
* Public wrapper to update a specific point (row, col).
160+
*
161+
* @param row The row index.
162+
* @param col The column index.
163+
* @param val The new value.
164+
*/
165+
public void update(int row, int col, int val) {
166+
update(row, col, val, 0, 0, n);
167+
}
168+
169+
/**
170+
* Recursively queries the sum in a submatrix.
171+
*/
172+
private int query(int top, int bottom, int left, int right, int node, int lx, int rx) {
173+
if (lx >= bottom || rx <= top) {
174+
return 0; // Out of bounds
175+
}
176+
if (lx >= top && rx <= bottom) {
177+
// Fully inside the row range, so delegate the column query to the 1D tree
178+
return tree[node].query(left, right);
179+
}
180+
181+
int mid = lx + (rx - lx) / 2;
182+
int leftSum = query(top, bottom, left, right, node * 2 + 1, lx, mid);
183+
int rightSum = query(top, bottom, left, right, node * 2 + 2, mid, rx);
184+
185+
return leftSum + rightSum;
186+
}
187+
188+
/**
189+
* Public wrapper to query the sum of a submatrix.
190+
* Note: boundaries are [top, bottom) and [left, right).
191+
*
192+
* @param top Top row index (inclusive).
193+
* @param bottom Bottom row index (exclusive).
194+
* @param left Left column index (inclusive).
195+
* @param right Right column index (exclusive).
196+
* @return The sum of the submatrix.
197+
*/
198+
public int query(int top, int bottom, int left, int right) {
199+
return query(top, bottom, left, right, 0, 0, n);
200+
}
201+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.thealgorithms.datastructures.trees;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import org.junit.jupiter.api.Test;
6+
7+
public class SegmentTree2DTest {
8+
9+
@Test
10+
void testInitialEmptyQueries() {
11+
SegmentTree2D segmentTree = new SegmentTree2D(4, 4);
12+
13+
// Initial tree should return 0 for any query
14+
assertEquals(0, segmentTree.query(0, 4, 0, 4));
15+
assertEquals(0, segmentTree.query(1, 3, 1, 3));
16+
}
17+
18+
@Test
19+
void testUpdateAndPointQuery() {
20+
SegmentTree2D segmentTree = new SegmentTree2D(5, 5);
21+
22+
segmentTree.update(2, 3, 10);
23+
segmentTree.update(0, 0, 5);
24+
25+
// Querying single points [row, row+1) x [col, col+1)
26+
assertEquals(10, segmentTree.query(2, 3, 3, 4));
27+
assertEquals(5, segmentTree.query(0, 1, 0, 1));
28+
29+
// Empty point should be 0
30+
assertEquals(0, segmentTree.query(1, 2, 1, 2));
31+
}
32+
33+
@Test
34+
void testSubmatrixQuery() {
35+
SegmentTree2D segmentTree = new SegmentTree2D(4, 4);
36+
37+
// Matrix simulation:
38+
// [1, 2, 0, 0]
39+
// [3, 4, 0, 0]
40+
// [0, 0, 0, 0]
41+
// [0, 0, 0, 0]
42+
segmentTree.update(0, 0, 1);
43+
segmentTree.update(0, 1, 2);
44+
segmentTree.update(1, 0, 3);
45+
segmentTree.update(1, 1, 4);
46+
47+
// Top-left 2x2 sum: 1+2+3+4 = 10
48+
assertEquals(10, segmentTree.query(0, 2, 0, 2));
49+
50+
// First row sum: 1+2 = 3
51+
assertEquals(3, segmentTree.query(0, 1, 0, 4));
52+
53+
// Second column sum: 2+4 = 6
54+
assertEquals(6, segmentTree.query(0, 4, 1, 2));
55+
}
56+
57+
@Test
58+
void testUpdateOverwriting() {
59+
SegmentTree2D segmentTree = new SegmentTree2D(3, 3);
60+
61+
segmentTree.update(1, 1, 5);
62+
assertEquals(5, segmentTree.query(1, 2, 1, 2));
63+
64+
// Overwrite the same point
65+
segmentTree.update(1, 1, 20);
66+
assertEquals(20, segmentTree.query(1, 2, 1, 2));
67+
68+
// Full matrix sum should just be this point
69+
assertEquals(20, segmentTree.query(0, 3, 0, 3));
70+
}
71+
}

0 commit comments

Comments
 (0)