-
Notifications
You must be signed in to change notification settings - Fork 1
/
segment_tree.py
298 lines (252 loc) · 10.1 KB
/
segment_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
#! 线段树
https://oi.wiki/ds/seg/
可以多维护一些区间来进行单点/区间修改和查询
"""
"""
单点修改+区间查询
"""
from typing import List
class SegmentTreeSingleModifyAdd:
"""线段树单点增加, 区间求和, 可以用树状数组平替
"""
def __init__(self, n: int, nums: List[int]) -> None:
"""初始化线段树类
Args:
n (int): 初始数组的长度
nums (List[int]): 初始的数组
"""
self.sm = [0] * (2 << n.bit_length())
self.nums = nums
self.build(1, 1, n)
def build(self, o: int, l: int, r: int)->None:
"""构建[l, r]内的线段树
Args:
o (int): 当前的节点编号
l (int): 要构建的区间的左端点
r (int): 要构建的区间的右端点
"""
if l == r: #* 如果左右端点相等, 说明是叶子节点, 直接赋值即可
#* 因为l的范围是[1, n], 所以作为数组的索引时需要减一
self.sm[o] = self.nums[l-1]
return
#* 否则需要递归遍历左右子树
mid = (l+r)//2
self.build(o*2, l, mid)
self.build(o*2+1, mid+1, r)
self.sm[o] = self.sm[o*2] + self.sm[o*2+1]
#* 具体操作就是add[1, 1, n, idx, val], 前面三个参数是不变的
def add(self, o: int, l: int, r: int, idx: int, val: int)->None:
"""在idx位置增加val
#! 需要注意的是idx的范围是[1, n]并不是从0开始
Args:
o (int): 当前的节点编号
l (int): 当前的节点所对应的区间左端点
r (int): 当前的节点所对应的区间右端点
idx (int): 需要增加值的位置, 注意范围是[1, n]
val (int): 需要增加的值
"""
if l == r:
self.sm[o]+=val
return
mid = (l+r)//2
#* 需要判断idx在左右哪个区间内
if(idx <= mid):
self.add(o*2, l, mid, idx, val)
else:
self.add(o*2+1, mid+1, r, idx, val)
self.sm[o] = self.sm[o*2] + self.sm[o*2+1]
#* 具体操作就是query_sum[1, 1, n, L ,R], 前面三个参数是不变的
def query_sum(self, o: int, l: int ,r: int, L: int, R: int)->int:
"""求[L, R]的区间和
Args:
o (int): 当前的节点编号
l (int): 当前节点所对应的区间的左端点
r (int): 当前节点所对应的区间的右端点
L (int): 查询的区间的左端点
R (int): 查询的区间的右端点
Returns:
int: 区间和
"""
if L<=l and R>=r:
return self.sm[o]
sm = 0
mid = (l+r)//2
if(L <= mid):
sm += self.query_sum(o*2, l, mid, L, R)
if(R > mid):
sm += self.query_sum(o*2+1, mid+1, r, L, R)
return sm
from math import *
class SegmentTreeSingleModifyMax:
"""线段树单点更新, 区间求最大值
"""
def __init__(self, n: int, nums: List[int]) -> None:
"""初始化线段树类
Args:
n (int): 初始数组的长度
nums (List[int]): 初始的数组
"""
self.mx = [-inf] * (2 << n.bit_length())
self.nums = nums
self.build(1, 1, n)
def build(self, o: int, l: int, r: int)->None:
"""构建[l, r]内的线段树
Args:
o (int): 当前的节点编号
l (int): 要构建的区间的左端点
r (int): 要构建的区间的右端点
"""
if l == r: #* 如果左右端点相等, 说明是叶子节点, 直接赋值即可
#* 因为l的范围是[1, n], 所以作为数组的索引时需要减一
self.mx[o] = self.nums[l-1]
return
#* 否则需要递归遍历左右子树
mid = (l+r)//2
self.build(o*2, l, mid)
self.build(o*2+1, mid+1, r)
self.mx[o] = max(self.mx[o*2], self.mx[o*2+1])
#* 具体操作就是update[1, 1, n, idx, val], 前面三个参数是不变的
def update(self, o: int, l: int, r: int, idx: int, val: int)->None:
"""将idx位置的值更新为val
#! 需要注意的是idx的范围是[1, n]并不是从0开始
Args:
o (int): 当前的节点编号
l (int): 当前的节点所对应的区间左端点
r (int): 当前的节点所对应的区间右端点
idx (int): 需要更新的位置
val (int): 需要更新的值
"""
if l == r:
self.mx[o]=val
return
mid = (l+r)//2
#* 需要判断idx在左右哪个区间内
if(idx <= mid):
self.update(o*2, l, mid, idx, val)
else:
self.update(o*2+1, mid+1, r, idx, val)
self.mx[o] = max(self.mx[o*2], self.mx[o*2+1])
#* 具体操作就是query_max[1, 1, n, L ,R], 前面三个参数是不变的
def query_max(self, o: int, l: int ,r: int, L: int, R: int)->int:
"""求[L, R]的最大值
Args:
o (int): 当前的节点编号
l (int): 当前节点所对应的区间的左端点
r (int): 当前节点所对应的区间的右端点
L (int): 查询的区间的左端点
R (int): 查询的区间的右端点
Returns:
int: 区间最大值
"""
if L<=l and R>=r:
return self.mx[o]
mx = -inf
mid = (l+r)//2
if(L <= mid):
mx = max(mx, self.query_max(o*2, l, mid, L, R))
if(R > mid):
mx = max(mx, self.query_max(o*2+1, mid+1, r, L, R))
return mx
class SegmentTreeRangeModifyAdd:
"""线段树区间修改, 区间求和, 可以用树状数组平替
需要用到lazy tag
lazy tag: 用一个数组维护每个区间需要更新的值
如果这个值 = 0, 表示不需要更新
如果这个值 != 0, 表示更新操作在这个区间停住了, 不继续递归更新子区间了
如果后面又来了一个更新破坏了lazy tag的区间, 那就得继续更新
#! 当添加懒标记的时候对当前区间进行更新, 但是不对子区间进行更新
#! 不一定用0来做判断, 但是对于求和来说是0
"""
def __init__(self, n: int, nums: List[int]) -> None:
"""初始化线段树类
Args:
n (int): 初始数组的长度
nums (List[int]): 初始的数组
"""
self.sm = [0] * (2 << n.bit_length())
self.nums = nums
self.build(1, 1, n)
self.todo = [0] * (2 << n.bit_length()) #! 存储区间更新是否停住了, 其他情况不一定是1(比如如果是区间乘, 就应该设置成1)
def maintain(self, o: int)->None:
"""更新完左右子树后对当前节点进行维护
Args:
o (int): 当前的节点编号
"""
self.sm[o] = self.sm[o*2] + self.sm[o*2+1]
def build(self, o: int, l: int, r: int)->None:
"""构建[l, r]内的线段树
Args:
o (int): 当前的节点编号
l (int): 要构建的区间的左端点
r (int): 要构建的区间的右端点
"""
if l == r: #* 如果左右端点相等, 说明是叶子节点, 直接赋值即可
#* 因为l的范围是[1, n], 所以作为数组的索引时需要减一
self.sm[o] = self.nums[l-1]
return
#* 否则需要递归遍历左右子树
mid = (l+r)//2
self.build(o*2, l, mid)
self.build(o*2+1, mid+1, r)
self.maintain(o)
#* 具体操作就是add[1, 1, n, L, R, add], 前面三个参数是不变的
def update(self, o: int, l: int, r: int, L: int, R: int, add: int)->None:
"""在[L, R]区间都加上add
#! 需要注意的是idx的范围是[1, n]并不是从0开始
Args:
o (int): 当前的节点编号
l (int): 当前的节点所对应的区间左端点
r (int): 当前的节点所对应的区间右端点
L (int): 需要增加的区间的左端点
R (int): 需要增加的区间的右端点
add (int): 需要增加的数
"""
if L <= l and r <= R:
self.todo[o] += add #* 不再继续递归更新了
self.sm[o] += (r-l+1)*add
return
mid = (l+r)//2
self.do(o, l, r)
if mid>=L: self.update(o*2, l, mid, L, R, add)
if mid < R: self.update(o*2+1, mid+1, r, L, R, add)
self.maintain(o)
def do(self, o: int, l: int, r: int):
"""将值传给子节点
#! 其他的代码需要更新这个
Args:
o (int): 子节点的编号
l (int): 当前节点所代表的区间的左端点
r (int): 当前节点所代表的区间的右端点
"""
mid = (l+r)//2
if l!=r and self.todo[o] != 0: #* 如果不是叶子节点并且有懒标记
#* 传给左右儿子
self.sm[o*2] += (mid-l+1) * self.todo[o]
self.sm[o*2+1] += (r-mid) * self.todo[o]
self.todo[o*2] += self.todo[o]
self.todo[o*2+1] += self.todo[o]
#* 自身清空
self.todo[o] = 0
#* 具体操作就是query_sum[1, 1, n, L ,R], 前面三个参数是不变的
def query_sum(self, o: int, l: int ,r: int, L: int, R: int)->int:
"""求[L, R]的区间和
Args:
o (int): 当前的节点编号
l (int): 当前节点所对应的区间的左端点
r (int): 当前节点所对应的区间的右端点
L (int): 查询的区间的左端点
R (int): 查询的区间的右端点
Returns:
int: 区间和
"""
if L<=l and R>=r:
return self.sm[o]
sm = 0
self.do(o, l, r)
mid = (l+r)//2
if(L <= mid):
sm += self.query_sum(o*2, l, mid, L, R)
if(R > mid):
sm += self.query_sum(o*2+1, mid+1, r, L, R)
return sm