-
-
Notifications
You must be signed in to change notification settings - Fork 50.8k
Expand file tree
/
Copy pathsegment_tree_node.py
More file actions
159 lines (134 loc) · 5.47 KB
/
Copy pathsegment_tree_node.py
File metadata and controls
159 lines (134 loc) · 5.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class Node:
    """One node of a segment tree, covering the inclusive index range [start, end].

    ``value`` holds the aggregate for that range and starts at 0 until the
    tree assigns it.  ``left`` and ``right`` initially point back at the node
    itself; the tree replaces them with real children for internal nodes.
    """

    def __init__(self, start: int, end: int) -> None:
        self.start, self.end = start, end
        self.value: int = 0
        # Self-references stand in for "no child yet", which keeps both
        # attributes typed as plain ``Node`` instead of ``Node | None``.
        self.left: Node = self
        self.right: Node = self
class SegmentTree:
    """Recursive segment tree answering range-max or range-sum queries.

    The tree is built once from a list of integers and supports point
    updates.  Exactly one aggregation is active per instance, selected by
    ``mode``.

    >>> st = SegmentTree([1, 3, 2, 5, 4], mode="max")
    >>> st.max_in_range(0, 2)
    3
    >>> st.update(1, 0)
    >>> st.max_in_range(0, 2)
    2
    >>> SegmentTree([1, 2, 3], mode="sum").sum_in_range(1, 2)
    5
    """

    def __init__(self, nums: list[int], mode: str = "max") -> None:
        """
        Initializes the Segment Tree.

        :param nums: List of integers to build the tree from.
        :param mode: Operation mode of the tree ('max' or 'sum'). Any other
            value silently falls back to 'max'.
        """
        self.size = len(nums)
        # Unknown modes fall back to "max" instead of raising.
        self.mode = mode if mode in {"max", "sum"} else "max"
        # Build the tree from the input list
        self.root: Node = self.build(0, self.size - 1, nums)

    def _combine(self, left: int, right: int) -> int:
        """Merge two child aggregates according to the tree's mode."""
        if self.mode == "max":
            return max(left, right)
        return left + right

    def _check_range(self, start_index: int, end_index: int) -> None:
        """
        Validate an inclusive query range against the tree size.

        :raises IndexError: unless 0 <= start_index <= end_index < size.
        """
        if not 0 <= start_index <= end_index < self.size:
            raise IndexError("Invalid index")

    def build(self, start: int, end: int, nums: list[int]) -> Node:
        """
        Recursively builds the segment tree.

        :param start: Start index of the segment (inclusive).
        :param end: End index of the segment (inclusive).
        :param nums: Original input array.
        :return: Root node of the constructed subtree.
        """
        if start > end:
            # When called from __init__ this only happens for an empty input
            # list; a placeholder node keeps ``self.root`` a Node.
            return Node(0, 0)
        node = Node(start, end)
        if start == end:
            # Leaf node holds the original element.
            node.value = nums[start]
            return node
        mid = (start + end) // 2
        node.left = self.build(start, mid, nums)
        node.right = self.build(mid + 1, end, nums)
        node.value = self._combine(node.left.value, node.right.value)
        return node

    def max_in_range(self, start_index: int, end_index: int) -> int:
        """
        Queries the maximum value in the inclusive range [start_index, end_index].

        Only works in 'max' mode.

        :raises ValueError: if the tree is in 'sum' mode.
        :raises IndexError: if the range is empty or out of bounds.
        """
        if self.mode == "sum":
            raise ValueError("Current Segment Tree doesn't support finding max")
        self._check_range(start_index, end_index)
        return self.query(self.root, start_index, end_index, 0, self.size - 1)

    def sum_in_range(self, start_index: int, end_index: int) -> int:
        """
        Queries the sum of values in the inclusive range [start_index, end_index].

        Only works in 'sum' mode.

        :raises ValueError: if the tree is in 'max' mode.
        :raises IndexError: if the range is empty or out of bounds.
        """
        if self.mode == "max":
            raise ValueError("Current Segment Tree doesn't support summing")
        self._check_range(start_index, end_index)
        return self.query(self.root, start_index, end_index, 0, self.size - 1)

    def query(
        self, node: Node, start_index: int, end_index: int, start: int, end: int
    ) -> int:
        """
        Recursively queries a value (max or sum) in a given range.

        The caller must ensure [start_index, end_index] intersects the
        node's segment [start, end]; the public query methods guarantee this.

        :param node: Current node in the tree.
        :param start_index: Query start index.
        :param end_index: Query end index.
        :param start: Node's segment start.
        :param end: Node's segment end.
        :return: Result of query in the range.
        """
        # Complete overlap: the node's precomputed aggregate is the answer.
        if start_index <= start and end <= end_index:
            return node.value
        mid = (start + end) // 2
        if end_index <= mid:
            # Entire range is in the left child
            return self.query(node.left, start_index, end_index, start, mid)
        if start_index > mid:
            # Entire range is in the right child
            return self.query(node.right, start_index, end_index, mid + 1, end)
        # Range straddles the midpoint: merge both halves.
        return self._combine(
            self.query(node.left, start_index, end_index, start, mid),
            self.query(node.right, start_index, end_index, mid + 1, end),
        )

    def update(self, index: int, new_value: int) -> None:
        """
        Updates a value at a specific index in the segment tree.

        :param index: Index to update.
        :param new_value: New value to set.
        :raises IndexError: if index is out of bounds.
        """
        if index < 0 or index >= self.size:
            raise IndexError("Invalid index")
        self.modify(self.root, index, new_value, 0, self.size - 1)

    def modify(
        self, node: Node, index: int, new_value: int, start: int, end: int
    ) -> None:
        """
        Recursively updates the tree to reflect a change at a specific index.

        :param node: Current node being processed.
        :param index: Index to update.
        :param new_value: New value to assign.
        :param start: Start index of node's segment.
        :param end: End index of node's segment.
        """
        if start == end:
            node.value = new_value
            return
        mid = (start + end) // 2
        if index <= mid:
            self.modify(node.left, index, new_value, start, mid)
        else:
            self.modify(node.right, index, new_value, mid + 1, end)
        # Recompute current node's value after the child changed.
        node.value = self._combine(node.left.value, node.right.value)