欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds)

程序员文章站 2022-06-09 19:37:23
...

HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds)

 

 

 

大致题意:给你一个dfs序列B和一棵树,现在让你在这个树上随机选择一个点,然后按照随机的dfs顺序走。问你最后能走出几个dfs序列,是得该dfs序列字典序小于给定的dfs序B。

首先,我们考虑一棵树有根树他的dfs序有多少种。我们可以这么考虑,对于任意点x,我都可以任意的向它的所有儿子走去,那么就会对应 HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds) 种方法。我们注意到,除了根之外,所有的点的儿子的数目等于其度数减一,那么,我们便可以得出一棵有根树的dfs序列为:HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds)。进一步,我们可以令HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds) ,那么对于不同的根,其对应树的方案数就是 res*deg[root],也即res就是所谓的公共部分。

接着,我们来考虑这道题目。由于题目要求是字典序比给定的要小,而且是dfs序,所以我么考虑按照它给定的顺序进行dfs,逐位计算种类数。初始根的时候,我们先利用上面的公式,计算所有以编号小于B[0]的点为根的方案。然后开始dfs,当我们走到树上的x节点,序列的第i位的时候,在x的所有的可选儿子中,查询有多少个的编号小于B[i]。不妨设此时恰好有t个可选儿子的编号小于B[i],那么这个点的贡献就是HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds),即字典序小于B[i]我可以在t个中选择一个,选完之后,相当于x的所有儿子中少了一个可以任意选择位置的点,因此要少去一些方案。计算完贡献,我们继续顺着给定的dfs序列往下走,同时要把每次按照这个dfs序经过的所有点标记为不可选,对应父亲的可用节点数目要减一,这里可以对应deg[fa]减一,因为相当于,我已经确定这个点的位置了。res的值也要改变,原因同计算贡献的时候。

再整理一下,就是遇到一个点首先计算贡献,然后顺着往下走的同时,把这个点在它父亲的可选点中删除,对应的res也要少一个产生贡献的*点。由于这是dfs,所以经过某一个点之后,后面可能再次经过,然后第二次经过的时候,我们还是需要对于一个定值,看有多少可行的点比它小,这个过程如果没有高效的方法,复杂度将会比较大。

由此,我们就需要一个,支持插入、删除和查询有多少个点比某一个定值小的数据结构。用一个Treap平衡树即可,需要一个rank操作。另外,今天还新学到一个pbds库,里面有可以用到的红黑树(red-black tree),直接就可以支持插入、删除和查询rank的操作,可以省去大量的代码,值得正式比赛用。具体见代码:

#include<bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;

LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];
int rt[N],sz;
bool flag;

struct Treap
{
    #define ls T[i].ch[0]
    #define rs T[i].ch[1]

    struct treap{int ch[2],sz,val,cnt,fix;} T[N];

    void up(int i)
    {
        T[i].sz=T[i].cnt+T[ls].sz+T[rs].sz;
    }

    void Rotate(int &x,bool d)
    {
        int y=T[x].ch[d];
        T[x].ch[d]=T[y].ch[d^1];
        T[y].ch[d^1]=x; up(x),up(x=y);
    }

    void ins(int &i,int x)
    {
        if (!i)
        {
            i=++sz;T[i].fix=rand();
            T[i].sz=T[i].cnt=1;
            T[i].val=x;ls=rs=0;
            return;
        }
        T[i].sz++;
        if (x==T[i].val) {T[i].cnt++;return;}
        int d=x>T[i].val; ins(T[i].ch[d],x);
        if (T[T[i].ch[d]].fix<T[i].fix) Rotate(i,d);
    }

    void del(int &i,int x)
    {
        if (!i) return;
        if (T[i].val==x)
        {
            if (T[i].cnt>1){T[i].cnt--,T[i].sz--;return;}
            int d=T[ls].fix>T[rs].fix;
            if (ls==0||rs==0) i=ls+rs;
            else Rotate(i,d),del(i,x);
        } else T[i].sz--,del(T[i].ch[x>T[i].val],x);
    }

    int rank(int i,int x)
    {
        if (!i) return 0;
        if (T[i].val>x) return rank(ls,x);
        if (T[i].val==x) return T[ls].sz+1;
        return rank(rs,x)+T[ls].sz+T[i].cnt;
    }

} treap;


void init()
{
    fac[1]=fac[0]=1;
    inv[1]=inv[0]=1;
    for(int i=2;i<N;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    }
    for(int i=2;i<N;i++)
        inv[i]=inv[i-1]*inv[i]%mod;
}

void dfs(int x,int fa)
{
    f[x]=fa;
    if (fa) treap.ins(rt[fa],x);
    for(int i=0;i<g[x].size();i++)
    {
        int y=g[x][i];
        if (y==fa) continue;
        dfs(y,x); d[y]--;
    }
}

void dfs(int x)
{
    if (m>n||flag||!x) return;
    if (d[x]!=0)
    {
        int t=treap.rank(rt[x],b[m+1]-1);                    //查找有多少点小于b[m+1]
        ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;        //计算贡献
        if (f[b[m+1]]!=x) {flag=1;return;} m++;
        res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;       //改变res,同时x要少一个可选点
        treap.del(rt[x],b[m]); dfs(b[m]);                //把b[m]从x的可选点中删除
    } else dfs(f[x]);
}

int main()
{
    init();
    IO;int T;cin>>T;
    while(T--)
    {
        cin>>n; res=1,ans=sz=0;
        for(int i=1;i<=n;i++)
            cin>>b[i],d[i]=rt[i]=0,g[i].clear();
        for(int i=1;i<n;i++)
        {
            int x,y;
            cin>>x>>y;
            g[x].push_back(y);
            g[y].push_back(x);
            d[x]++,d[y]++;
        }
        for(int i=1;i<=n;i++)
            res=res*fac[d[i]-1]%mod;
        for(int i=1;i<b[1];i++)
            ans=(ans+res*d[i]%mod)%mod;
        res=res*d[b[1]]%mod; m=1; flag=0;
        dfs(b[1],0); dfs(b[1]);
        cout<<ans<<endl;
    }
    return 0;
}

然后我们还有pbds库版本的代码,这个更加的简洁,适合比赛用,但是速度可能就会慢一点点。

#include<bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;
using namespace __gnu_pbds;
tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> rbt[N];
LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];

bool flag;

void init()
{
    fac[1]=fac[0]=1;
    inv[1]=inv[0]=1;
    for(int i=2;i<N;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        inv[i]=(mod-mod/i)*inv[mod%i]%mod;
    }
    for(int i=2;i<N;i++)
        inv[i]=inv[i-1]*inv[i]%mod;
}

void dfs(int x,int fa)
{
    f[x]=fa;
    rbt[fa].insert(x);
    for(int i=0;i<g[x].size();i++)
        if (g[x][i]!=fa) dfs(g[x][i],x);
}

void dfs(int x)
{
    if (m>n||flag||!x) return;
    if (d[x]!=0)
    {
        int t=rbt[x].order_of_key(b[m+1]);
        ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;
        if (f[b[m+1]]!=x) {flag=1;return;} m++;
        res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;
        rbt[x].erase(b[m]); dfs(b[m]);
    } else dfs(f[x]);
}

int main()
{
    init();
    IO;int T;cin>>T;
    while(T--)
    {
        cin>>n; res=1,ans=0;
        for(int i=1;i<=n;i++)
            cin>>b[i],d[i]=0,rbt[i].clear(),g[i].clear();
        for(int i=1;i<n;i++)
        {
            int x,y;
            cin>>x>>y;
            g[x].push_back(y);
            g[y].push_back(x);
            d[x]++,d[y]++;
        }
        for(int i=1;i<=n;i++)
            res=res*fac[d[i]-1]%mod;
        for(int i=1;i<b[1];i++)
            ans=(ans+res*(LL)d[i]%mod)%mod;
        for(int i=1;i<=n;i++)
            if (i!=b[1]) d[i]--;
        res=res*d[b[1]]%mod; m=1; flag=0;
        dfs(b[1],0); dfs(b[1]);
        cout<<ans<<endl;
    }
    return 0;
}