树链剖分
【概念】
树链剖分,说简单点,就是把一颗树划分成多个序列,然后在序列上进行一些操作的算法
下面是一些必须知道的概念:
- 重儿子:父亲节点的所有儿子中子树结点数目最多的结点;
- 轻儿子:父亲节点中除了重儿子以外的儿子;
- 重边:父亲结点和重儿子连成的边;
- 轻边:父亲节点和轻儿子连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;
例如下图:
其中:加粗了的边是重边,其他的是轻边;
用重边连接的结点是重儿子,其余的是轻儿子;
2-4-9-13-14 是其中的一条重链,2-5是一条轻链;
用红点标记的点就是该结点所在重链中的起点,也就是下文提到的 top;
还有一些变量:
变量 | 解释 |
dep[ x ] | x 在树中的深度 |
size[ x ] | 以 x 为根的子树中节点个数 |
top[ x ] | x 所在重链的顶部节点 |
son[ x ] | x 的重儿子 |
pos[ x ] | x 在序列中的位置(下标) |
idx[ x ] | 序列中第 x 个位置对应的树中节点编号 |
father[ x ] | x 的父亲节点 |
【基本思想】
我们介绍一下最常用的轻重链剖分
我们将重链视作一段区间,以刚才的图为例,我们可以得到(1,4,9,13,14),(2,6,11),(3,7),(5),(8),(10),(12)这几个序列,为了方便我们会把它们放入一个序列里,不过重链一定是其中连续的一段
对于树链剖分,我们有两种方法解决,一种是两遍 dfs,另一种是 bfs
【两遍dfs】
第一遍 dfs 我们得到 father,dep,size,son 的值
第二遍 dfs 我们得到 pos,idx,top 的值
dfs 有一个好处,即以任意点为根子树中所有节点在序列中也是连续的一段,就是 pos[ x ] ~ pos[ x ] + size [ x ] - 1,这样的话处理起有关子树的问题是会方便许多
void add_point(int x)
{
tot++;
pos[x]=tot;
idx[tot]=x;
}
void dfs1(int x)
{
int i,j;
size[x]=1;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father[x])
{
father[j]=x;
dep[j]=dep[x]+1;
dfs1(j);
size[x]+=size[j];
if(size[j]>size[son[x]])
son[x]=j;
}
}
}
void dfs2(int x)
{
int i,j;
if(son[x])
{
add_point(son[x]);
top[son[x]]=top[x];
dfs2(son[x]);
}
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father[x]&&j!=son[x])
{
add_point(j);
top[j]=j;
dfs2(j);
}
}
}
【bfs】
有时为了防止 dfs 栈溢出,我们也可以考虑 bfs,和 dfs 其实差不多
void bfs()
{
int x,y,i,j,end=1;
q[1]=1;
dep[1]=1;
for(i=1;i<=end;++i)
{
x=q[i];
size[i]=1;
for(j=first[x];j;j=next[j])
{
y=v[j];
if(y==father[x])
continue;
father[y]=x;
dep[y]=dep[x]+1;
q[++end]=y;
}
}
for(i=end;i>=2;--i)
{
x=q[i];
y=father[x];
size[y]+=size[x];
if(size[x]>size[son[y]])
son[y]=x;
}
for(i=1;i<=end;++i)
{
x=q[i];
if(top[x])
continue;
for(j=x;j;j=son[j])
{
tot++;
pos[j]=tot;
idx[tot]=j;
top[j]=x;
}
}
}
【修改(或查询)操作】
假如说现在给出(x,y),要对 x 到 y 的路径上进行修改或查询答案,我们可以分别处理 x,y 到它们 lca 的路径。对于重链,它们相当于是一段区间,用线段树维护即可;对于轻边,直接跳过即可,因为轻边的两端点一定在两条重链上
下面以修改(x 到 y 的路径都加上 z)为例(modify 是线段树的区间修改操作):
void change(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
modify(1,1,n,pos[top[x]],pos[x],z);
x=father[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
modify(1,1,n,pos[x],pos[y],z);
}
【复杂度分析】
在证明复杂度之前,先了解一下几条性质:
性质 1:如果(u,v)为轻边,则 size(v) ≤ size(u) / 2 。
证明:如果 size(v) > size(u) / 2 ,则 size(v) 必然比其他儿子的 size 要大,那么(u,v)必然为重边,与(u,v)为轻边矛盾。
性质 2:从根到某一点 x 的路径上的轻边个数不大于 log n。
证明:x 为叶子节点时轻边个数最多。由性质 1 可知,每经过一条轻边,子树的节点个数至少比原来少一半,所以至多经过 log n 条轻边就到达叶子节点了。
性质 3:那么对于每个点到根的路径上都不超过 log n 条轻边和 log n 条重路径。
证明:显然每条重路径的起点和终点都是由轻边构成,而由性质 2 可知,每个点到根节点的轻边个数为 log n,所以重路径个数也为 log n。
dfs 的时间复杂度:O()
总时间复杂度:O()
【例题】
emmm,很简单的板题吧
因为有关于子树的操作,我们用两遍 dfs
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 100005
#define M 200005
using namespace std;
int n,m,r,p,cnt,tot;
int add[4*N],sum[4*N];
int first[N],v[M],next[M];
int a[N],dep[N],idx[N],pos[N],son[N],top[N],size[N],father[N];
void add_edge(int x,int y)
{
cnt++;
next[cnt]=first[x];
first[x]=cnt;
v[cnt]=y;
}
void add_point(int x)
{
tot++;
pos[x]=tot;
idx[tot]=x;
}
void dfs1(int x)
{
int i,j;
size[x]=1;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father[x])
{
father[j]=x;
dep[j]=dep[x]+1;
dfs1(j);
size[x]+=size[j];
if(size[j]>size[son[x]])
son[x]=j;
}
}
}
void dfs2(int x)
{
int i,j;
if(son[x])
{
add_point(son[x]);
top[son[x]]=top[x];
dfs2(son[x]);
}
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father[x]&&j!=son[x])
{
add_point(j);
top[j]=j;
dfs2(j);
}
}
}
void build(int root,int l,int r)
{
if(l==r)
{
sum[root]=a[idx[l]]%p;
return;
}
int mid=(l+r)>>1;
build(root<<1,l,mid);
build(root<<1|1,mid+1,r);
sum[root]=(sum[root<<1]+sum[root<<1|1])%p;
}
void pushdown(int root,int l,int r)
{
int mid=(l+r)>>1;
add[root<<1]=(add[root<<1]+add[root])%p;
add[root<<1|1]=(add[root<<1|1]+add[root])%p;
sum[root<<1]=(sum[root<<1]+1ll*(mid-l+1)*add[root]%p)%p;
sum[root<<1|1]=(sum[root<<1|1]+1ll*(r-mid)*add[root]%p)%p;
add[root]=0;
}
void modify(int root,int l,int r,int x,int y,int z)
{
if(l>=x&&r<=y)
{
add[root]=(add[root]+z)%p;
sum[root]=(sum[root]+1ll*(r-l+1)*z%p)%p;
return;
}
int mid=(l+r)>>1;
if(add[root]) pushdown(root,l,r);
if(x<=mid) modify(root<<1,l,mid,x,y,z);
if(y>mid) modify(root<<1|1,mid+1,r,x,y,z);
sum[root]=(sum[root<<1]+sum[root<<1|1])%p;
}
void change(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
modify(1,1,n,pos[top[x]],pos[x],z);
x=father[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
modify(1,1,n,pos[x],pos[y],z);
}
int query(int root,int l,int r,int x,int y)
{
if(l>=x&&r<=y)
return sum[root];
int ans=0,mid=(l+r)>>1;
if(add[root]) pushdown(root,l,r);
if(x<=mid) ans=(ans+query(root<<1,l,mid,x,y))%p;
if(y>mid) ans=(ans+query(root<<1|1,mid+1,r,x,y))%p;
return ans;
}
int ask(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+query(1,1,n,pos[top[x]],pos[x]))%p;
x=father[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=(ans+query(1,1,n,pos[x],pos[y]))%p;
return ans;
}
int main()
{
int s,i,x,y,z;
scanf("%d%d%d%d",&n,&m,&r,&p);
for(i=1;i<=n;++i)
scanf("%d",&a[i]);
for(i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
add_edge(x,y);
add_edge(y,x);
}
add_point(r);
top[r]=r;
dfs1(r);
dfs2(r);
build(1,1,n);
for(i=1;i<=m;++i)
{
scanf("%d",&s);
if(s==1)
{
scanf("%d%d%d",&x,&y,&z);
change(x,y,z);
}
if(s==2)
{
scanf("%d%d",&x,&y);
printf("%d\n",ask(x,y));
}
if(s==3)
{
scanf("%d%d",&x,&z);
modify(1,1,n,pos[x],pos[x]+size[x]-1,z);
}
if(s==4)
{
scanf("%d",&x);
printf("%d\n",query(1,1,n,pos[x],pos[x]+size[x]-1));
}
}
return 0;
}