欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【AGC008F】Black Radius【树形dp】

程序员文章站 2022-06-03 14:30:02
...

题目链接
这道题。。。神题,花了半天的时间终于搞懂了。
我们设二元组(u,d)表示以u为染色中心,d为染色半径染黑的点的集合。为了方便,我们先钦定这些二元组染出的点的集合不为全集,这样答案最后加一就好了。
我们先钦定一个点u为染色中心。
1.这个点可以作为染色中心。则这个点使染色不会发生重复的染色半径d的范围为d0,且

{d<dep1d<dep2+2

其中dep1为u子树中距离u最远的距离,dep2为u删去任意一个子树剩下的最远的距离,也就是不经过某个子树的最远的距离。注意此时u为根,它原树中的父亲也在这时算它的子树。

为什么?
d<dep1很好理解,因为我们钦定点的染色集合不能是全集。
为什么要满足第二个条件呢?
看一看这两张图。
【AGC008F】Black Radius【树形dp】
我们假设原本有一种染色方法(u,d)。如果染色中心向子树A移动一个单位,而且要使子树A内的染色情况不变,染色方法一定要变为(v,d+1)。这样子树B,C内的染色深度就会往上升两个单位。
考虑如果d=dep2+2
【AGC008F】Black Radius【树形dp】
所以要使得每种染色方案只被统计一次,必须要满足d<dep2+2
2.这个点u不能作为染色中心。
那么若存在一个可以作为染色中心的节点v满足方案(u,d)v所在子树内所有节点均被染成黑色,(u,d) 就是一个合法的染色方案。故我们只需要求出从u出发的某一个儿子v满足v子树内有可以作为染色中心的节点,到达v子树中的最远节点的距离的最小值。这个就是可行的d的最小值。
这样我们就可以限定上下界,不重不漏地算出所有方案数。
代码

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=200005,inf=0x7f7f7f7f;
int n,u,v,cnt,head[N],fa[N],siz[N],f[N][5],to[N*2],nxt[N*2];
bool flag;
char s[N];
long long ans;
void adde(int u,int v){
    to[++cnt]=v;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
void dfs1(int u){
    siz[u]=s[u]-'0';
    if(siz[u]){
        flag=true;
    }
    if(!head[u]){
        f[u][0]=f[u][1]=0;
        f[u][3]=s[u]=='1'?0:inf;
    }
    int v;
    f[u][0]=f[u][1]=0;
    f[u][3]=inf;
    for(int i=head[u];i;i=nxt[i]){
        v=to[i];
        if(v!=fa[u]){
            fa[v]=u;
            dfs1(v);
            siz[u]+=siz[v];
            if(f[v][0]+1>f[u][0]){
                f[u][1]=f[u][0];
                f[u][0]=f[v][0]+1;
            }else if(f[v][0]+1>=f[u][1]){
                f[u][1]=f[v][0]+1;
            }
            if(f[v][3]<inf){
                f[u][3]=min(f[u][3],f[v][0]+1);
            }
        }
    }
    if(s[u]=='1'){
        f[u][3]=min(f[u][3],f[u][0]);
    }
}
void dfs2(int u){
    if(fa[u]){
        f[u][4]=max(f[u][4],f[u][2]-1);
    }
    int v;
    for(int i=head[u];i;i=nxt[i]){
        v=to[i];
        if(v!=fa[u]){
            if(f[v][0]+1==f[u][0]){
                f[v][2]=max(f[u][2],f[u][1])+1;
            }else{
                f[v][2]=max(f[u][2],f[u][0])+1;
            }
            dfs2(v);
        }
    }
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        adde(u,v);
        adde(v,u);
    }
    scanf("%s",s+1);
    dfs1(1);
    dfs2(1);
    if(!flag){
        puts("0");
        return 0;
    }
    for(int i=1;i<=n;i++){
        int up=min(max(f[i][0],f[i][2])-1,f[i][0]+1),down=s[i]=='1'?0:min(f[i][3],siz[i]==siz[1]?inf:f[i][2]);
        for(int j=head[i];j;j=nxt[j]){
            if(to[j]!=fa[i]){
                up=min(up,f[to[j]][4]+1);
            }
        }
        if(up>=down){
            ans+=up-down+1;
        }
    }
    printf("%lld\n",ans+1);
    return 0;
}