0%

线段树

什么是线段树

​ 线段树,是一种二叉搜索树。它将一段区间划分为若干单位区间,每一个节点都储存着一个区间。它支持区间求和,区间最大值,区间修改,单点修改等操作。线段树的思想和分治思想很相像。
​ 线段树的每一个节点都储存着一段区间[L…R]的信息,其中叶子节点L=R。它的大致思想是:将一段大区间平均地划分成2个小区间,每一个小区间都再平均分成2个更小区间……以此类推,直到每一个区间的L等于R(这样这个区间仅包含一个节点的信息,无法被划分)。通过对这些区间进行修改、查询,来实现对大区间的修改、查询。
​ 这样一来,每一次修改、查询的时间复杂度都只为$O ( log n ) $。

线段树原理

​ 线段树主要是把一段大区间平均地划分成两段小区间进行维护,再用小区间的值来更新大区间。这样既能保证正确性,又能使时间保持在log级别(因为这棵线段树是平衡的)。也就是说,一个[ L , R ] 的区间会被划分成[ L , ⌊ (L + R)/2 ⌋ ] 和[ ⌊ (L + R)/ 2 ⌋ + 1 , R ] 这两个小区间进行维护,直到 L = R 。
​ 下图就是一棵[ 1 , 10 ] 的线段树的分解过程(相同颜色的节点在同一层)

存储方式

通常用的都是堆式储存法,即编号为k的节点的左儿子编号为2k,右儿子编号为2k + 1,父节点编号为⌊ k/2 ⌋,用位运算优化一下,以上的节点编号就变成了k<<1, k<<1|1, k>>1

在表示线段树的时候都要在数据量n的情况下多开更大的空间,一般都是开四倍。

线段树定义

1
2
long long sum[MAX<<2];
long long lazy[MAX<<2]; //懒惰标记

初始化

1
2
3
inline void pushup(int k){
sum[k] = sum[k<<1]+sum[k<<1|1];
}
1
2
3
4
5
6
7
8
9
10
void build(int k, int l, int r){
if (l==r){
sum[k] = number[l];
return;
}
int mid = (l+r)/2;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}

单点修改

1
2
3
4
5
6
7
8
9
10
11
void change(int k,int x,int y)//k为当前节点的编号,要把编号为x的数字修改成y
{
if(a[k].l==a[k].r){
a[k].sum=y;
return;
}
int mid=(a[k].l+a[k].r)/2;//计算下一层子区间的左右边界
if(x<=mid) change(k*2,x,y);//递归到左儿子
else change(k*2+1,x,y);//递归到右儿子
update(k);//记得更新点k的值
}

区间修改

区间修改大体可以分为两步:

  1. 找到区间中全部都是要修改的点的线段树中的区间
  2. 修改这一段区间的所有点

从根节点出发,一直往下走,直到当前区间中的元素全部都是被修改元素。

  • 若当前区间全都是要修改的区间,则修改sum和懒惰标记,return

  • 当左区间包含修改的区间时,就递归到左区间;

  • 当区右间包含修改的区间时,就递归到右区间;

!修改要用到懒惰标记!

懒惰标记

标记的含义:本区间已经被更新过了,但是子区间却没有被更新过,被更新的信息是什么(区间求和只用记录有没有被访问过,而区间加减乘除等多种操作的问题则要记录进行的是哪一种操作)
这里再引入两个很重要的东西:相对标记绝对标记

相对标记

相对标记指的是可以共存的标记,且打标记的顺序与答案无关,即标记可以叠加。 比如说给一段区间中的所有数字都+a,我们就可以把标记叠加一下,比如上一次打了一个+1的标记,这一次要给这一段区间+2,那么就把+1的标记变成+3。

绝对标记

绝对标记是指不可以共存的标记,每一次都要先把标记下传,再给当前节点打上新的标记。这些标记不能改变次序,否则会出错。 比如说给一段区间的数字重新赋值,或是给一段区间进行多种操作。

这样一来,我们每一次修改区间时只要找到目标区间就可以了,给它打上懒惰标记,不用再向下递归到叶节点。

区间+x的代码:

1
2
3
4
5
6
7
8
9
10
11
12
void changeSegment(int k, int l, int r, int x, int nl, int nr){
if (l<=nl && nr<=r){
sum[k]+=(nr-nl+1)*x;
lazy[k]+=x;
return;
}
int mid = (nl+nr)/2;
pushdown(k,nl,nr);
if (l<=mid) changeSegment(k<<1,l,r,x,nl,mid);
if (r>mid) changeSegment(k<<1|1,l,r,x,mid+1,nr);
pushup(k);
}

下传标记

1
2
3
4
5
6
7
8
9
10
void pushdown(int k, int nl, int nr){
if (lazy[k]){
int mid = (nl+nr)>>1;
sum[k<<1] +=(long long)(mid-nl+1)*lazy[k];
sum[k<<1|1]+=(long long)(nr-mid)*lazy[k];
lazy[k<<1]+=lazy[k];
lazy[k<<1|1]+=lazy[k];
lazy[k] = 0;
}
}

区间查询

  1. 当查找区间在当前区间的左子区间时,递归到左子区间;
  2. 当查找区间在当前区间的右子区间时,递归到右子区间;
  3. 否则,这个区间一定是跨越两个子区间的,我们就把它切成2块,分在两个子区间查询。最后把答案合起来处理。

记得在查询之前下传标记!!!

1
2
3
4
5
6
7
8
9
long long query(int k, int l, int r, int nl, int nr){
if (l<=nl && nr<=r) return sum[k];
if (lazy[k]) pushdown(k,nl,nr);
int mid = (nl+nr)>>1;
long long ans = 0;
if (l<=mid) ans+=query(k<<1,l,r,nl,mid);
if (r>mid) ans+=query(k<<1|1,l,r,mid+1,nr);
return ans;
}

区间乘和区间加

先乘后加!!

所谓先乘后加就是在做乘法的时候把加法标记也乘上这个数,在后面做加法的时候直接加就行了。

【模板题】 洛谷 线段树2

区间*x的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
void update2(int k, int l, int r, int x, int nl, int nr){
if (l<=nl && nr<=r){
sum[k]*=x;
mul[k]*=x;
add[k]*=x;
return;
}
int mid = (nl+nr)/2;
pushdown(k,nl,nr);
if (l<=mid) update2(k<<1,l,r,x,nl,mid);
if (r>mid) update2(k<<1|1,l,r,x,mid+1,nr);
pushup(k);
}

下传标记

1
2
3
4
5
6
7
8
9
10
11
12
13
void pushdown(int k, int nl, int nr){
if (add[k]!=0||mul[k]!=1){
int mid = (nl+nr)>>1;
sum[k<<1] = sum[k<<1]*mul[k]+(mid-nl+1)*add[k];
sum[k<<1|1] = sum[k<<1|1]*mul[k]+(nr-mid)*add[k];
mul[k<<1] = mul[k]*mul[k<<1];
mul[k<<1|1] = mul[k]*mul[k<<1|1];
add[k<<1] = add[k<<1]*mul[k]+add[k];
add[k<<1|1] = add[k<<1|1]*mul[k]+add[k];
mul[k] = 1;
add[k] = 0;
}
}

参考代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>
using namespace std;
const int MAX = 1e5+5;
long long sum[MAX<<2];
long long add[MAX<<2];
long long mul[MAX<<2];
long long number[MAX];
int p;
inline void pushup(int k){
sum[k] = (sum[k<<1]+sum[k<<1|1])%p;
}
void build(int k, int l, int r){
add[k] = 0;
mul[k] = 1;
if (l==r){
sum[k] = number[l]%p;
return;
}
int mid = (l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}
void pushdown(int k, int nl, int nr){
int mid = (nl+nr)>>1;
sum[k<<1] =(sum[k<<1]*mul[k]+(mid-nl+1)*add[k])%p;
sum[k<<1|1] =(sum[k<<1|1]*mul[k]+(nr-mid)*add[k])%p;
mul[k<<1] = (mul[k]*mul[k<<1])%p;
mul[k<<1|1] = (mul[k]*mul[k<<1|1])%p;
add[k<<1] = (add[k<<1]*mul[k]+add[k])%p;
add[k<<1|1] = (add[k<<1|1]*mul[k]+add[k])%p;
mul[k] = 1;
add[k] = 0;

}
void changeSegment(int k, int l, int r, int x, int nl, int nr){
if (l<=nl && nr<=r){
sum[k] = (sum[k]+(nr-nl+1)*x)%p;
add[k] = (x+add[k])%p;
return;
}
int mid = (nl+nr)/2;
pushdown(k,nl,nr);
if (l<=mid) changeSegment(k<<1,l,r,x,nl,mid);
if (r>mid) changeSegment(k<<1|1,l,r,x,mid+1,nr);
pushup(k);
}
void update2(int k, int l, int r, int x, int nl, int nr){
if (l<=nl && nr<=r){
sum[k] = (sum[k]*x)%p;
mul[k] = (mul[k]*x)%p;
add[k] = (add[k]*x)%p;
return;
}
int mid = (nl+nr)/2;
pushdown(k,nl,nr);
if (l<=mid) update2(k<<1,l,r,x,nl,mid);
if (r>mid) update2(k<<1|1,l,r,x,mid+1,nr);
pushup(k);
}

long long query(int k, int l, int r, int nl, int nr){
if (l<=nl && nr<=r) return sum[k];
pushdown(k,nl,nr);
int mid = (nl+nr)>>1;
long long ans = 0;
if (l<=mid) ans = (ans+query(k<<1,l,r,nl,mid))%p;
if (r>mid) ans = (ans+query(k<<1|1,l,r,mid+1,nr))%p;
return ans;
}
int main(){
int n,m,choice;
int x,y;
long long k;
scanf("%d%d%d",&n,&m,&p);
for (int i=1; i<=n; i++){
scanf("%d",number+i);
}
build(1,1,n);
for (int i=0; i<m; i++){
scanf("%d",&choice);
if (choice==1){
scanf("%d%d%lld",&x,&y,&k);
update2(1,x,y,k,1,n);
}
else if (choice==2){
scanf("%d%d%lld",&x,&y,&k);
changeSegment(1,x,y,k,1,n);
}
else{
scanf("%d%d",&x,&y);
printf("%lld\n",query(1,x,y,1,n)%p);
}
}
return 0;
}