欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

bzoj3697 采药人的路径 点分治

程序员文章站 2022-05-08 17:58:01
...

一道拖了很久的点分治,现在把他搞定了。
来自出题人hta的题解:
本题可以考虑树的点分治。问题就变成求过根满足条件的路径数。
路径上的休息站一定是在起点到根的路径上,或者根到终点的路径上。
如何判断一条从根出发的路径是否包含休息站?只要在dfs中记录下这条路径的和x,同时用个标志数组判断这条路径是否存在前缀和为x的节点。
这样我们枚举根节点的每个子树。用f[i][0…1],g[i][0…1]分别表示前面几个子树以及当前子树和为i的路径数目,0和1用于区分路径上是否存在前缀和为i的节点。那么当前子树的贡献就是f[0][0] * g[0][0] + Σf [i][0] * g [-i][1] + f[i][1] * g[-i][0] + f[i][1] * g[-i][1],其中i的范围[-d,d],d为当前子树的深度。

其实这个式子挺好理解的,就是所有的可能性组合起来,注意一下g要把f的数值加上,然后f清空一下,不然会影响到后面。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=1e6+5;
int n,m;
int head[N],go[N],dis[N],next[N],val[N],f[N];
int son[N];
int len,sum,tot,root,mxdep,dep[N],t[N];
bool vis[N];
typedef long long ll;
ll ans,F[N][2],g[N][2];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'&&ch>'9'){if (ch=='-')f=-1;ch=getchar();}
    while (ch<='9'&&ch>='0'){x=x*10+ch-'0',ch=getchar();}
    return x*f; 
}
inline void add(int x,int y,int z)
{
    go[++tot]=y;
    next[tot]=head[x];
    val[tot]=z;
    head[x]=tot;
}
inline void getroot(int x,int fa)
{
    son[x]=1,f[x]=0;
    for(int i=head[x];i;i=next[i])
    {
        int v=go[i];
        if (v!=fa&&!vis[v])
        {
            getroot(v,x);
            son[x]+=son[v];
            f[x]=max(f[x],son[v]);
        }
    } 
    //printf("%d\n",x); 
    f[x]=max(f[x],sum-son[x]);
    if (f[x]<f[root])root=x;
}
inline void dfs(int x,int fa)
{
    mxdep=max(mxdep,dep[x]);
    if (t[dis[x]])F[dis[x]][1]++;
    else F[dis[x]][0]++;
    t[dis[x]]++;
    for(int i=head[x];i;i=next[i])
    {
        int v=go[i];
        if (!vis[v]&&v!=fa)
        {
            dis[v]=dis[x]+val[i];
            dep[v]=dep[x]+1;
            dfs(v,x);
        }
    }
    t[dis[x]]--;
}
inline void work(int x)
{
    g[n][0]=1;
    vis[x]=1;
    int mx=0;
    for(int i=head[x];i;i=next[i])
    {
        int v=go[i];
        if (!vis[v])
        {
            dis[v]=n+val[i];
            dep[v]=1;
            mxdep=1;
            dfs(v,0);
            mx=max(mx,mxdep);
            ans+=(g[n][0]-1)*F[n][0];
            fo(j,-mxdep,mxdep)
            ans+=g[n-j][1]*F[n+j][1]+g[n-j][0]*F[n+j][1]+g[n-j][1]*F[n+j][0];
            fo(j,n-mxdep,n+mxdep)
            {
                g[j][0]+=F[j][0];
                g[j][1]+=F[j][1];
                F[j][0]=F[j][1]=0;
            }
        }
    }
    fo(i,n-mx,n+mx)
    g[i][0]=g[i][1]=0;
    for(int i=head[x];i;i=next[i])
    {
        int v=go[i];
        if (!vis[v])
        {
            root=0;
            sum=son[v];
            getroot(v,0);
            work(root);
        }
    }
}
int main()
{
    n=read();
    fo(i,1,n-1)
    {
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        if (!z)z=-1;
        add(x,y,z);
        add(y,x,z);
    }
    sum=f[0]=n;
    getroot(1,0);
    work(root);
    printf("%lld\n",ans);
    return 0;
}