欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[BZOJ3697]采药人的路径(点分治)

程序员文章站 2022-05-08 18:06:30
...

题目:

我是超链接

题解:

所谓的阴阳平衡,我们可以把0的边设为-1,那么就是说总和为0的路就是阴阳平衡,至于中间的休息站嘛,那也就是说存在一个点前缀和为0
路径相关,那就点分治吧!
getdeep的时候,我们对于每一个点到根求前缀和,称为点权
显然要将这些点两两组合才是合法路径
两条路径点权和为0称为【合法匹配】
如果当前点到根的路径中存在点和这个点的点权相等,那么这条路径已经单独存在休息站,将这个点标记 (前缀和相等代表中间变化为0)

我们分类讨论一下

点权为0的:

①如果没有标记,可以和任意一条点权为0的匹配(即合法的)
②如果有标记,那么ta们不仅可以和任意一条点权为0的匹配(合法的),还可以自己到根节点成一条路径

这个①情况为什么不要求和有标记的在一起呢?因为此时根节点作为休息站

只要我们避免根节点加入没有标记的行列,那所有的0权节点都是有标记的!我们只需要单独加自己到根节点的路径就好了
当我们分治子树的时候,这个节点并不是作为根节点出现的,我们要让ta能够加入到没有标记的行列,暗箱操作一下就好咯

点权不为0的:

①如果没有标记,那么必须和 有标记合法的 匹配
②如果有标记,那就要和合法的匹配

答案

=有标记 * 有标记+无标记 * 有标记

然后就是一个小问题了,你一个一个加好慢的,要是有个区间就好啦,这时候我们设置一个p,和r代表一个区间,这样相等的就不用再+1+1+1了
其他见注释咯!

代码:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define LL long long
#define INF 1e9
using namespace std;
const int N=100005;
const int base=100000;
int tot,nxt[N*2],point[N],v[N*2],c[N*2],size[N],f[N],root,sum,k,app[N*2],ss,nn,nsig[N],sign[N];
LL ans;bool vis[N];
void addline(int x,int y,int z)
{
    ++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
    ++tot; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z;
}
void getroot(int x,int fa)
{
    size[x]=1; f[x]=0;
    for (int i=point[x];i;i=nxt[i])
      if (v[i]!=fa && !vis[v[i]])
      {
        getroot(v[i],x);
        size[x]+=size[v[i]];
        f[x]=max(f[x],size[v[i]]);
      }
    f[x]=max(f[x],sum-f[x]);
    if (f[x]<f[root]) root=x;
}
void getdep(int x,int fa,int quan,int vv)
{
    if (app[quan+base]) sign[++ss]=quan;
    //出现第二次啦 
    else if (app[base]) nsig[++nn]=quan;
    //可以有效的避免根节点加入这个没有标记的行列 

    if (!quan && app[base]>=2) ans+=(LL)vv;
    //点权为0的有标记的和根节点连成路径 
    //这个2一个是根节点一个是x
    app[quan+base]++;
    for (int i=point[x];i;i=nxt[i])
      if (v[i]!=fa && !vis[v[i]])
        getdep(v[i],x,quan+c[i],vv);
    app[quan+base]--;
}
LL calc(int x,int quan,int vv)
{
    nn=0; ss=0;
    getdep(x,0,quan,vv);
    sort(nsig+1,nsig+nn+1);
    sort(sign+1,sign+ss+1);
    sign[0]=nsig[0]=-INF;
    int l=1,r=ss,p=r;LL t=0;
    for (;l<r;)
    {
        while (p>l && sign[p]>=sign[r]) p--;
        p=max(p,l);//注意p一定要大于l,可能说上面已经有限制p>l了,但是l会加 
        if (sign[l]+sign[r]==0){t+=(LL)(r-p); l++;}
        else
        {
            if (sign[l]+sign[r]>0) r--; 
            else l++;
        }
    }//有标记*有标记 
    l=1;r=ss;p=r;
    for (;l<=nn && r>=1;)
    {
        while (p>=1 && sign[p]>=sign[r]) p--;
        if (nsig[l]+sign[r]==0){t+=(LL)(r-p); l++;}
        else
        {
            if (nsig[l]+sign[r]>0) r--; 
            else l++;
        }
    }//无标记*有标记 
    return t;
}
void work(int x)
{
    ans+=calc(x,0,1);
    vis[x]=1;
    for (int i=point[x];i;i=nxt[i])
      if (!vis[v[i]])
      {
        app[base]++;
        ans-=calc(v[i],c[i],-1);
        app[base]--;
        root=0; sum=size[v[i]]; getroot(v[i],x);
        work(root);
      }
}
int main()
{
    int n,i;
    scanf("%d",&n);
    for (i=1;i<n;i++)
    {
        int x,y,z; scanf("%d%d%d",&x,&y,&z);
        if (!z) z=-1;
        addline(x,y,z);
    }
    sum=n; root=0; f[0]=INF; getroot(1,0);
    work(root);printf("%lld",ans);
}