虚树
程序员文章站
2024-03-17 08:13:46
...
前言:
舒老师表示虚树不用学,但是看到ZYX在看,所以就了解一下吧
引入
有这样一类问题:给出一棵个结点的树(为1e5级别),每次指定个结点,给予ta们一些性质,求出某答案,保证与同阶。
从保证同阶这一点就可以看出,每次询问的复杂度必须基于,这样才能保证总复杂度
然而这类问题往往需要通过树形dp来解决,这样复杂度就是基于的,朴素算法会超时
由此我们可以利用单调栈建出一棵“虚树”,保证了每次求解只会在一棵节点数与同阶的树上进行,且建树过程只需要的时间
建树
为了形象的理解虚树,我从blog上盗了一张图:
左图是原树,蓝点是询问点;右图中的红点存在于虚树中
我们呢可以发现,任意两点的都会存在于虚树中
确实,在构建虚树的时候,我们就有寻找的过程
首先我们进行预处理:
原树,维护结点的序,深度,倍增父结点和倍增路径长度,
这样就方便计算(倍增算法),计算两点之间的距离
之后读入所有的关键点
把关键点按照序排序(深度小的排在前)
任何树形结构都需要有一个根结点,一般我们把结点1强制作为根结点
之后我们遍历所有的关键点,把ta扔到一个栈里
在入栈之前,我们求出关键点和栈顶元素的
- 如果在栈顶元素代表的子树内
我们就把当前元素和栈顶元素相连,并把结点push进栈里 - 如果不在栈顶元素代表的子树内
就连接栈顶的两个相邻元素,之后把栈顶元素pop出来
直到栈顶元素的子树包含了关键点
之后把关键点push到栈里
最后把栈里的所有元素都pop出来,相邻结点相连
简单的写了一下虚树的模板
因为虚树一般都是和树形dp结合,所以在建完树之后可能还会进行dp之类的操作
还有一点要注意:原树和虚树用的都是一个边数组,所以在dp的过程中有st的初始化操作
dalao们竟然表示虚树好写,mmp不懂啊
const int N=10000;
const int lg=20;
struct node{
int y,nxt,v;
}way[N<<1];
int S[N],top,mark[N],a[N],k,n;
int tot=0,st[N],pre[N][lg],len[N][lg],deep[N],in[N],out[N],clo;
//deep:每个结点的深度
//order:记录每个点的访问次序
//clo:时间戳
//建立原树
void add(int u,int w)
{
tot++;
way[tot].y=w;way[tot].nxt=st[u];st[u]=tot;
tot++;
way[tot].y=u;way[tot].nxt=st[w];st[w]=tot;
}
void dfs(int now,int fa,int dep)
{
deep[now]=dep;
in[now]=++clo;
pre[now][0]=fa;
for (int i=st[now];i;i=way[i].nxt)
if (way[i].y!=fa)
{
len[way[i].y][0]=1; //路径长度
dfs(way[i].y,now,dep+1);
}
}
int lca(int x,int y)
{
if (deep[x]<deep[y]) swap(x,y);
int d=deep[x]-deep[y];
if (d)
for (int i=0;i<lg&&d;i++,d>>=1)
if (d&1)
x=pre[x][i];
if (x==y) return x;
for (int i=lg-1;i>=0;i--)
if (pre[x][i]!=pre[y][i])
{
x=pre[x][i];
y=pre[y][i];
}
return pre[x][0];
}
void prepare()
{
clo=0;
dfs(1,0,0);
for (int i=1;i<lg;i++)
for (int j=1;j<=n;j++)
pre[j][i]=pre[pre[j][i-1]][i-1],
len[j][i]=len[j][i-1]+len[pre[j][i-1]][i-1];
}
int getlen(int x,int y)
{
int sum=0;
if (deep[x]<deep[y]) swap(x,y);
int d=deep[x]-deep[y];
if (d)
for (int i=0;i<lg&&d;i++,d>>=1)
if (d&1)
sum+=len[x][i],x=pre[x][i];
if (x==y) return sum;
for (int i=lg-1;i>=0;i--)
if (pre[x][i]!=pre[y][i])
{
sum+=len[x][i];
x=pre[x][i];
sum+=len[y][i];
y=pre[y][i];
}
sum+=len[x][0]; sum+=len[y][0];
return sum;
}
void build(int x,int y)
{
if (x==y) return;
tot++;
way[tot].y=y;way[tot].nxt=st[x];st[x]=tot;
way[tot].v=getlen(x,y);
}
int cmp(int a,int b)
{
return in[a]<in[b];
}
void solve()
{
scanf("%d",&k);
for (int i=1;i<=k;i++) scanf("%d",&a[i]),mark[a[i]]=1; //读入关键点
sort(a+1,a+1+k,cmp); //按照dfs序排序
int cnt=0; //
a[++cnt]=a[1]; //
for (int i=2;i<=k;i++) //
if (lca(a[cnt],a[i])!=a[cnt]) a[++cnt]=a[i]; //
k=cnt; //
top=0; top=0; S[++top]=1; //强制结点1为根
for (int i=1;i<=k;i++)
{
int now=a[i];
int p=lca(now,S[top]); //求lca
while (1)
{
if (deep[p]>=deep[S[top-1]]) //lca在S[top]的子树内
{
build(p,S[top--]);
if (p!=S[top]) S[++top]=p;
break;
}
build(S[top-1],S[top]); //弹栈
top--;
}
if (now!=S[top]) S[++top]=now;
}
while (top-1) build(S[top-1],S[top]),top--;
//该干嘛干嘛,该dp就dp
}