SegTree(线段树)

什么是线段树?

线段树是一棵二叉树,每个节点维护一个区间和区间上的值,是一种用来维护 区间信息 的数据结构。

我们先从线段树能提供的操作上来理解。想象一个数组,每个下标上可以存一个值。所谓区间就是一段连续的数组下标。

我们可以进行的操作有:

1)区间赋值:设置一个区间内所有下标对应的值

2)单点赋值:设置单点的值

3)区间查询:查询一个区间内的最大值 / 最小值 / 值的和

4)单点查询:查询一个点的值

如果用数组来实现上述操作,所有区间操作的时间复杂度都是O(n),n是区间长度,单点操作的复杂度都是O(1),而线段树可以把上述全部操作的时间复杂度同时变成O(logn)

线段树的实现

根据维护的信息和支持的操作不同线段树分很多种,常用的有:着色线段树、覆盖线段树、求和线段树、最值线段树。

计数线段树:可以对区间反复覆盖,动态查询单点的覆盖次数。

着色线段树:可以对区间反复着色,动态查询单点的颜色。

求和线段树:可以对单点或区间赋值,动态查询区间数据的和。

最值线段树:可以对单点或区间赋值,动态查询区间数据的最值。

为了方便起见我用求和线段树进行讲解,并且只实现两种操作:

1)对单点赋值。

2)查询区间和。

假设我们有一个数组

1
arr = {3, 8, 3, 3, 5, 6, 8, 7};

通常情况比如我们求区间[2,7)的和需要遍历区间上的元素,O(区间长度),如何减少运算次数呢,最简单的思路是我们预处理前缀和sum[],令

1
sum[i]= arr[0]+arr[1]+...+arr[i-1]

当我们需要算[2,7)的区间和,其实就算[0,7)前缀和减去[0,2)的前缀和,即sum[7]-sum[2],这样就可以用O(1)的时间算出任意区间的和,如果数组元素不会发生动态变化这样是可以的,但如果需要交替修改数组元素和查询区间和,这样处理会导致前缀和维护的成本很高,原本的arr[i]=x,我们不得不修改所有k>i的sum[k]来维护前缀和,这样修改数组元素的时间就从原本的O(1)变成了O(n)。

有没有方法可以在修改单点值的便利性和查询区间和的便利性上做个折中呢,肯定是有的,我们可以预处理一些子段和而不是所有前缀和,让修改单点值和查询区间和都只需要访问O(logn)的元素,树状数组线段树都是类似这个思想,其实树状数组的设计更符合这个思想,但它只能用于求和,树状数组能做的线段树都能做,我们之后有空再讲树状数组,对于线段树我们先在这个数组上生成一个子段和数组,让

1
b = { arr[0]+arr[1],  arr[2]+arr[3], arr[4]+arr[5], arr[6]+arr[7]} = {11, 6, 11, 15};

这样求区间[2, 7)的和的时候,可以取b[1] + b[2] + arr[6],只需要算三个数的和就行了。

更近一步,我们可以在数组b上再合并建立 c = { b[0]+ b[1], b[2]+b[3], …}, d, e, …,直至把整个数组合并成一个元素,如下图

segtree

这样之后,对于这个数组上的任意区间的和都可以只取O(logn)的元素的和来得到。

为了方便,我们通过二叉树来维护这些数据,我们把最上面的整个数组的和43作为二叉树的根,同时记住这个节点对应的左右端点即0和8,我们把这个节点的数据表示为(0, 8, 43),然后把它的左孩子就是如图的和为17的节点,数据为(0, 4, 17),同理右孩子(4, 8, 26),以此类推,这样就可以建立一颗完全二叉树(我这里有意将数组的大小选为2的幂,如果不是2的幂我们也可以扩充到2的幂来建树)。

对于修改单点值的操作,我们要从树根一直修改到叶子,正好修改了树的深度个节点,复杂度也是O(logn)。

因为是完全二叉树,我们可以通过数组来实现,用data[0]作为根,data[k]的孩子是data[2*k+1]data[2*k+2]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
using namespace std;

#define MAXN 8
int bg[MAXN*2]; // 区间左端点,闭
int nd[MAXN*2]; // 区间右端点,开
int sm[MAXN*2]; // 节点值,即区间和

// 递归构建树
void build(int k, int b, int e) {
bg[k] = b;
nd[k] = e;
sm[k] = 0;
if(e-b>1) {
int mid = b + (e-b) / 2;
build(k*2+1, b, mid);
build(k*2+2, mid, e);
}
}

// 设置数组下标i的元素值,类似arr[i] = val
void change(int k, int i, int val) {
if(nd[k]-bg[k]==1) { // 叶子节点
sm[k] = val; // 直接更新
} else {
int mid = bg[k] + (nd[k]-bg[k]) / 2;
// 递归向下更新
if(i<mid) change(k*2+1, i, val);
else change(k*2+2, i, val);
// 回溯时通过孩子的值更新当前节点值
sm[k] = sm[k*2+1] + sm[k*2+2];
}
}

// 查询[b, e)的元素和
int query(int k, int b, int e) {
if(b<=bg[k] && nd[k]<=e) {
cout << "DEBUG:[" << bg[k] << "," << nd[k] << ") sm[k]=" << sm[k] << endl;
return sm[k];
}
int mid = bg[k] + (nd[k]-bg[k]) / 2;
return (b<mid ?query(k*2+1, b, e):0) + (mid<e ? query(k*2+2, b, e):0);
}

// 把n向上对齐到2的幂
int align(int n) {
if(n & (n-1)) {
while(n & (n-1)) {
n &= (n-1);
}
n = (n << 1);
}
return n;
}

int main()
{
int arr[] = {3,8,3,3,5,6,8,7};
build(0, 0, align(8));
for(int i=0;i<8;i++) {
change(0, i, arr[i]);
}
cout << query(0, 2, 7) << endl;
}

输出

1
2
3
4
DEBUG:[2,4) sm[k]=6
DEBUG:[4,6) sm[k]=11
DEBUG:[6,7) sm[k]=8
25

可以看到,和预期的一样,只通过3个节点的值求和算出了区间[2,7)的和为25。

最值线段树

对于最大值、最小值的线段树,我们只需要把更新节点值的代码改掉。

区间更新

对于区间更新,例如把[a,b)的值全部设置为x,或者让[a,b)的值全部增加x,我们不能对区间内的点逐一更新,否则复杂度会变成O(nlogn)了,区间更新时,我们要引入懒惰标志,我们还是从根开始拆分区间,如果当前节点的区间被完全覆盖,我们就更新当前节点值并不向下继续更新,并把待更新的值记在懒惰标志内,直到查询需要用到节点值时再逐级更新并把懒惰标志一层层推下去。

离散化

我们可以看到上述线段树实现方式使用的数据空间是数组最大下标的2倍,有时候区间的范围很大,但区间的更新和查询操作的数量并不多,我们可以先统计所有的区间端点,比如有n个端点,然后排序,在所有区间端点和 0,1,…,n-1 之间做一一映射,然后就可以把线段树的空间降到 2*n,和区间个数相关而和区间范围无关。

动态开点

我们也可以不用数组表示的完全二叉树来实现线段树,而用一棵记录左右孩子指针的动态申请空间的普通二叉树,根据更新的区间来动态生成区间节点,这一般被称作动态开点。

动态开点线段树实现

下面是动态开点线段树的完整实现,包含了懒惰标志的使用,默认区间范围(-IINF, IINF),支持负下标,支持使用下标读写单点数据,支持使用下标修改区间数据,支持区间求和、最大值、最小值查询。
区间表示前闭后开[b, e)
支持单点修改 tree[i] = x; tree[i] += x;
支持区间修改 tree[{b,e}] = x; tree[{b,e}] += x;
支持取单点值 tree[i];
支持取区间sum/min/max tree.sum(b, e);

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
typedef long long ll;

class SegTree {
private:
enum { INF = ll(4.6e18) };

struct Node {
Node* left;
Node* right;
ll sum;
ll min;
ll max;
ll lazy;
bool set_not_add;
Node():left(0), right(0), sum(0), min(0), max(0), lazy(INF), set_not_add(false) {}
};

void update(Node* node, int start, int end, int l, int r, bool set_not_add, ll val) {
if (l <= start && end <= r) {
if(set_not_add) {
node->sum = (end - start + 1) * val;
node->min = val;
node->max = val;
node->lazy = val;
} else {
node->sum += (end - start + 1) * val;
node->min += val;
node->max += val;
node->lazy = (node->lazy!=INF?node->lazy:0) + val;
}
node->set_not_add = set_not_add;
return;
}
int mid = (start + end) >> 1;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) update(node->left, start, mid, l, r, set_not_add, val);
if (r > mid) update(node->right, mid + 1, end, l, r, set_not_add, val);
pushUp(node);
}

ll sum(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->sum;
int mid = (start + end) >> 1;
ll ans = 0;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans += sum(node->left, start, mid, l, r);
if (r > mid) ans += sum(node->right, mid + 1, end, l, r);
return ans;
}

ll min(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->min;
int mid = (start + end) >> 1;
ll ans = INF;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans = std::min(ans, min(node->left, start, mid, l, r));
if (r > mid) ans = std::min(ans, min(node->right, mid + 1, end, l, r));
return ans;
}

ll max(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->max;
int mid = (start + end) >> 1;
ll ans = -INF;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans = std::max(ans, max(node->left, start, mid, l, r));
if (r > mid) ans = std::max(ans, max(node->right, mid + 1, end, l, r));
return ans;
}

void pushUp(Node* node) {
node->sum = node->left->sum + node->right->sum;
node->min = std::min(node->left->min, node->right->min);
node->max = std::max(node->left->max, node->right->max);
}

void pushDown(Node* node, int leftNum, int rightNum) {
if (node->left == nullptr) node->left = new Node();
if (node->right == nullptr) node->right = new Node();
if (node->lazy == INF) return ;
if(node->set_not_add) {
node->left->sum = node->lazy * leftNum;
node->left->min = node->lazy;
node->left->max = node->lazy;
node->left->lazy = node->lazy;
node->right->sum = node->lazy * rightNum;
node->right->min = node->lazy;
node->right->max = node->lazy;
node->right->lazy = node->lazy;
} else {
node->left->sum += node->lazy * leftNum;
node->left->min += node->lazy;
node->left->max += node->lazy;
node->left->lazy += node->lazy;
node->right->sum += node->lazy * rightNum;
node->right->min += node->lazy;
node->right->max += node->lazy;
node->right->lazy += node->lazy;
}
node->left->set_not_add = node->set_not_add;
node->right->set_not_add = node->set_not_add;
node->lazy = INF;
}

void destory(Node* p) {
if(p == NULL) return;
auto left = p->left;
auto right = p->right;
delete p;
destory(left);
destory(right);
}

public:
Node* root = NULL;
int start;
int end;
SegTree(int start=0, int end=1e9) : root(new Node()), start(start), end(end) {}
~SegTree() { destory(root); }
private:
class ItemAccessor {
public:
ItemAccessor(SegTree& st, int i): st(st), i(i) {}
ItemAccessor& operator = (ll x) {st.update(st.root, st.start, st.end, i, i, true, x);return *this;}
ItemAccessor& operator += (ll x) {st.update(st.root, st.start, st.end, i, i, false, x);return *this;}
ItemAccessor& operator -= (ll x) {st.update(st.root, st.start, st.end, i, i, false, -x);return *this;}
operator ll() const {return st.min(st.root, st.start, st.end, i, i);}
private:
SegTree& st;
int i;
};

class RangeAccessor {
public:
RangeAccessor(SegTree& st, int b, int e):st(st), b(b), e(e) {}
void operator = (ll x) {st.update(st.root, st.start, st.end, b, e-1, true, x);}
void operator += (ll x) {st.update(st.root, st.start, st.end, b, e-1, false, x);}
void operator -= (ll x) {st.update(st.root, st.start, st.end, b, e-1, false, -x);}
private:
SegTree& st;
int b;
int e;
};

public:
// 通过下标访问单点值的语法糖
ItemAccessor operator [] (int n) {
return ItemAccessor(*this, n);
}

// 通过下标访问区间值的语法糖
RangeAccessor operator [] (const tuple<int, int>& range) {
return RangeAccessor(*this, std::get<0>(range), std::get<1>(range));
}

ll sum(int l, int r) { return sum(root, start, end, l, r-1); }
ll min(int l, int r) { return min(root, start, end, l, r-1); }
ll max(int l, int r) { return max(root, start, end, l, r-1); }
};