luogu P2664 树上游戏
程序员文章站
2022-05-12 15:58:19
...
背景:
题目传送门:
https://www.luogu.org/problemnew/show/P2664
题意:
一棵树,每一个点有一个颜色,定义为到的颜色数量,求所有。
思路:
考虑点分治。
到的颜色数量可以分为子树内的和子树外的贡献,用一个桶记录即可(因为一个颜色只有第一次才有贡献)。细节很多。
具体来说,对当前的每个儿子的子树,如果当前点的颜色是到当前点这条链上第一次出现,那么就把当前点的加入桶。
先把所有儿子的子树全处理完,弄出来一个桶,注意根的颜色要特判。
然后统计答案,枚举根的儿子,先消除当前子树对桶的贡献,然后对当前子树,若当前点颜色第一次出现,就把当前颜色的桶的值改为,为当前儿子。
最后回溯时还原,每个子树统计完答案把影响加回来。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define LL long long
using namespace std;
int n,len=0;
int col[200010],last[200010],size[200010],msize[200010];
LL sum[200010],t[200010],cnt[200010];
bool bz[200010];
struct node{int x,y,next;} a[200010];
int SIZE,MIN,ROOT;
LL tot,del_size;
void ins(int x,int y)
{
a[++len]=(node){x,y,last[x]}; last[x]=len;
}
void find_root(int x,int fa)
{
size[x]=1;
msize[x]=0;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa||bz[y]) continue;
find_root(y,x);
size[x]+=size[y];
msize[x]=max(msize[x],size[y]);
}
msize[x]=max(msize[x],SIZE-size[x]);
if(MIN>msize[x]) MIN=msize[x],ROOT=x;
}
void solve(int x,int fa,int op)
{
size[x]=1;
cnt[col[x]]++;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa||bz[y]) continue;
solve(y,x,op),size[x]+=size[y];
}
cnt[col[x]]--;
if(!cnt[col[x]]&&col[x]!=col[ROOT]) t[col[x]]+=size[x]*op,tot+=size[x]*op;
}
void clear(int x,int fa)
{
t[col[x]]=cnt[col[x]]=0;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa||bz[y]) continue;
clear(y,x);
}
}
void get_ans(int x,int fa)
{
LL tmp=t[col[x]];
if(!cnt[col[x]]&&col[x]!=col[ROOT]) t[col[x]]=del_size,tot=tot-tmp+del_size;
sum[x]+=tot;
cnt[col[x]]++;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y==fa||bz[y]) continue;
get_ans(y,x);
}
cnt[col[x]]--;
if(!cnt[col[x]]&&col[x]!=col[ROOT]) tot=tot-t[col[x]]+tmp,t[col[x]]=tmp;
}
void dfs(int x)
{
bz[x]=true;
clear(x,0);
tot=0;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(!bz[y]) solve(y,x,1);
}
del_size=1;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(!bz[y]) del_size+=size[y];
}
t[col[x]]=del_size;
tot+=del_size;
sum[x]+=tot;
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(bz[y]) continue;
solve(a[i].y,x,-1);del_size-=size[y],t[col[x]]-=size[y],tot-=size[y];
get_ans(y,x);
solve(a[i].y,x,1);del_size+=size[y],t[col[x]]+=size[y],tot+=size[y];
}
for(int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(!bz[y]) SIZE=size[y],MIN=n,ROOT=0,find_root(y,x),dfs(ROOT);
}
}
int main()
{
int x,y;
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&col[i]);
for(int i=1;i<n;i++)
{
scanf("%d %d",&x,&y);
ins(x,y),ins(y,x);
}
SIZE=n,MIN=n,ROOT=0,find_root(1,0),dfs(ROOT);
for(int i=1;i<=n;i++)
printf("%lld\n",sum[i]);
}