bzoj3697 采药人的路径 (点分治)
程序员文章站
2022-05-08 18:27:26
...
bzoj3697 采药人的路径
原题地址:http://www.lydsy.com/JudgeOnline/problem.php?id=3697
题意:
采药人的药田是一个树状结构,每条路径上种植一种药,有0/1两种药。
草药人希望选择一条两种药材数目相等的路径,且选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是两种药数量相等。
他一共可以选择多少种不同的路径。
数据范围
N ≤ 100,000
题解:
如果把0的药看做-1,那么就是选择一条路径,他的边权和为零,且存在一个点,使分成路径的两侧边权和也为零。
1.休息站在一条从根出发的路径上:
那么该路径上存在一个点,根到他的边权和与根到端点的边权和相等。 他可以与其他子树的边权和为相反数的路径组成合法路径。
2.休息站在根。
两条不在同一子树的边权和为0的路径互相组合。
于是在点分治时,统计满足以上的路径条数。
为了区分子树,用f,g两个数组。
f[i][0/1]表示当前子树路径和为i且这条路径上方是否有过和为i的前缀,
g[i][0/1]分别表示之前遍历的所有子树路径和为i且这条路径上方是否有过和为i的前缀,
初值为g[i][0]=1,要考虑到由根节点出发的路径。
ans+=f[i][1]*(g[-i][0]+g[-i][1])+f[i][0]*g[-i][1]
ans+=f[0][0]*(g[0][0]-1)
清空只清空赋了值的。
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define LL long long
using namespace std;
const int N=100005;
LL f[2*N][2],g[2*N][2],ans=0,cnt[2*N];
int n,head[N],to[2*N],w[2*N],nxt[2*N],num=0,size[N],ss[N],sz,root=0,dis[N],dep[N],mxdep=0;
bool del[N];
void build(int u,int v,int ww)
{
num++;
to[num]=v;
nxt[num]=head[u];
w[num]=ww;
head[u]=num;
}
void getroot(int u,int f)
{
size[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==f||del[v]) continue;
getroot(v,u);
size[u]+=size[v];
ss[u]=max(ss[u],size[v]);
}
ss[u]=max(size[u],sz-size[u]);
if(ss[u]<ss[root]) root=u;
}
void getdis(int u,int fa)
{
dep[u]=dep[fa]+1; mxdep=max(dep[u],mxdep);
if(cnt[dis[u]+n]) f[dis[u]+n][1]++; else f[dis[u]+n][0]++; cnt[dis[u]+n]++;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa||del[v]) continue;
dis[v]=dis[u]+w[i];
getdis(v,u);
}
cnt[dis[u]+n]--;
}
LL cal(int u)
{
g[0+n][0]=1; dep[u]=1; int mx=0;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==u||del[v]) continue;
dis[v]=w[i]; mxdep=0;
getdis(v,u); mx=max(mx,mxdep);
for(int i=-mxdep;i<=mxdep;i++)
ans+=f[i+n][1]*g[-i+n][0]+f[i+n][1]*g[-i+n][1]+f[i+n][0]*g[-i+n][1];
ans+=(g[0+n][0]-1LL)*f[0+n][0];
for(int i=-mxdep;i<=mxdep;i++)
{
g[i+n][0]+=f[i+n][0]; g[i+n][1]+=f[i+n][1];
f[i+n][0]=f[i+n][1]=0;
}
}
for(int i=-mx;i<=mx;i++)
g[i+n][0]=g[i+n][1]=0;
}
void dfs(int x)
{
root=0; getroot(x,0);
cal(root);
del[root]=1;
for(int i=head[root];i;i=nxt[i])
{
int v=to[i];
if(del[v]) continue;
sz=size[v];
dfs(v);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v,ww; scanf("%d%d%d",&u,&v,&ww); if(!ww) ww=-1;
build(u,v,ww); build(v,u,ww);
}
sz=n; ss[0]=n+10; dfs(1);
printf("%I64d\n",ans);
return 0;
}