HDU 6338 2018HDU多校赛 第四场 Depth-First Search(组合数学+平衡树/pbds)
大致题意:给你一个dfs序列B和一棵树,现在让你在这个树上随机选择一个点,然后按照随机的dfs顺序走。问你最后能走出几个dfs序列,是得该dfs序列字典序小于给定的dfs序B。
首先,我们考虑一棵树有根树他的dfs序有多少种。我们可以这么考虑,对于任意点x,我都可以任意的向它的所有儿子走去,那么就会对应 种方法。我们注意到,除了根之外,所有的点的儿子的数目等于其度数减一,那么,我们便可以得出一棵有根树的dfs序列为:。进一步,我们可以令 ,那么对于不同的根,其对应树的方案数就是 res*deg[root],也即res就是所谓的公共部分。
接着,我们来考虑这道题目。由于题目要求是字典序比给定的要小,而且是dfs序,所以我么考虑按照它给定的顺序进行dfs,逐位计算种类数。初始根的时候,我们先利用上面的公式,计算所有以编号小于B[0]的点为根的方案。然后开始dfs,当我们走到树上的x节点,序列的第i位的时候,在x的所有的可选儿子中,查询有多少个的编号小于B[i]。不妨设此时恰好有t个可选儿子的编号小于B[i],那么这个点的贡献就是,即字典序小于B[i]我可以在t个中选择一个,选完之后,相当于x的所有儿子中少了一个可以任意选择位置的点,因此要少去一些方案。计算完贡献,我们继续顺着给定的dfs序列往下走,同时要把每次按照这个dfs序经过的所有点标记为不可选,对应父亲的可用节点数目要减一,这里可以对应deg[fa]减一,因为相当于,我已经确定这个点的位置了。res的值也要改变,原因同计算贡献的时候。
再整理一下,就是遇到一个点首先计算贡献,然后顺着往下走的同时,把这个点在它父亲的可选点中删除,对应的res也要少一个产生贡献的*点。由于这是dfs,所以经过某一个点之后,后面可能再次经过,然后第二次经过的时候,我们还是需要对于一个定值,看有多少可行的点比它小,这个过程如果没有高效的方法,复杂度将会比较大。
由此,我们就需要一个,支持插入、删除和查询有多少个点比某一个定值小的数据结构。用一个Treap平衡树即可,需要一个rank操作。另外,今天还新学到一个pbds库,里面有可以用到的红黑树(red-black tree),直接就可以支持插入、删除和查询rank的操作,可以省去大量的代码,值得正式比赛用。具体见代码:
#include<bits/stdc++.h>
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;
LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];
int rt[N],sz;
bool flag;
struct Treap
{
#define ls T[i].ch[0]
#define rs T[i].ch[1]
struct treap{int ch[2],sz,val,cnt,fix;} T[N];
void up(int i)
{
T[i].sz=T[i].cnt+T[ls].sz+T[rs].sz;
}
void Rotate(int &x,bool d)
{
int y=T[x].ch[d];
T[x].ch[d]=T[y].ch[d^1];
T[y].ch[d^1]=x; up(x),up(x=y);
}
void ins(int &i,int x)
{
if (!i)
{
i=++sz;T[i].fix=rand();
T[i].sz=T[i].cnt=1;
T[i].val=x;ls=rs=0;
return;
}
T[i].sz++;
if (x==T[i].val) {T[i].cnt++;return;}
int d=x>T[i].val; ins(T[i].ch[d],x);
if (T[T[i].ch[d]].fix<T[i].fix) Rotate(i,d);
}
void del(int &i,int x)
{
if (!i) return;
if (T[i].val==x)
{
if (T[i].cnt>1){T[i].cnt--,T[i].sz--;return;}
int d=T[ls].fix>T[rs].fix;
if (ls==0||rs==0) i=ls+rs;
else Rotate(i,d),del(i,x);
} else T[i].sz--,del(T[i].ch[x>T[i].val],x);
}
int rank(int i,int x)
{
if (!i) return 0;
if (T[i].val>x) return rank(ls,x);
if (T[i].val==x) return T[ls].sz+1;
return rank(rs,x)+T[ls].sz+T[i].cnt;
}
} treap;
void init()
{
fac[1]=fac[0]=1;
inv[1]=inv[0]=1;
for(int i=2;i<N;i++)
{
fac[i]=fac[i-1]*i%mod;
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=2;i<N;i++)
inv[i]=inv[i-1]*inv[i]%mod;
}
void dfs(int x,int fa)
{
f[x]=fa;
if (fa) treap.ins(rt[fa],x);
for(int i=0;i<g[x].size();i++)
{
int y=g[x][i];
if (y==fa) continue;
dfs(y,x); d[y]--;
}
}
void dfs(int x)
{
if (m>n||flag||!x) return;
if (d[x]!=0)
{
int t=treap.rank(rt[x],b[m+1]-1); //查找有多少点小于b[m+1]
ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod; //计算贡献
if (f[b[m+1]]!=x) {flag=1;return;} m++;
res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--; //改变res,同时x要少一个可选点
treap.del(rt[x],b[m]); dfs(b[m]); //把b[m]从x的可选点中删除
} else dfs(f[x]);
}
int main()
{
init();
IO;int T;cin>>T;
while(T--)
{
cin>>n; res=1,ans=sz=0;
for(int i=1;i<=n;i++)
cin>>b[i],d[i]=rt[i]=0,g[i].clear();
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
d[x]++,d[y]++;
}
for(int i=1;i<=n;i++)
res=res*fac[d[i]-1]%mod;
for(int i=1;i<b[1];i++)
ans=(ans+res*d[i]%mod)%mod;
res=res*d[b[1]]%mod; m=1; flag=0;
dfs(b[1],0); dfs(b[1]);
cout<<ans<<endl;
}
return 0;
}
然后我们还有pbds库版本的代码,这个更加的简洁,适合比赛用,但是速度可能就会慢一点点。
#include<bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define mod 1000000007
#define LL long long
#define N 1000010
using namespace std;
using namespace __gnu_pbds;
tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> rbt[N];
LL inv[N],fac[N],ans,res;
int d[N],b[N],f[N],n,m;
vector<int> g[N];
bool flag;
void init()
{
fac[1]=fac[0]=1;
inv[1]=inv[0]=1;
for(int i=2;i<N;i++)
{
fac[i]=fac[i-1]*i%mod;
inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=2;i<N;i++)
inv[i]=inv[i-1]*inv[i]%mod;
}
void dfs(int x,int fa)
{
f[x]=fa;
rbt[fa].insert(x);
for(int i=0;i<g[x].size();i++)
if (g[x][i]!=fa) dfs(g[x][i],x);
}
void dfs(int x)
{
if (m>n||flag||!x) return;
if (d[x]!=0)
{
int t=rbt[x].order_of_key(b[m+1]);
ans=(ans+res*inv[d[x]]%mod*t%mod*fac[d[x]-1]%mod)%mod;
if (f[b[m+1]]!=x) {flag=1;return;} m++;
res=res*inv[d[x]]%mod*fac[d[x]-1]%mod; d[x]--;
rbt[x].erase(b[m]); dfs(b[m]);
} else dfs(f[x]);
}
int main()
{
init();
IO;int T;cin>>T;
while(T--)
{
cin>>n; res=1,ans=0;
for(int i=1;i<=n;i++)
cin>>b[i],d[i]=0,rbt[i].clear(),g[i].clear();
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
d[x]++,d[y]++;
}
for(int i=1;i<=n;i++)
res=res*fac[d[i]-1]%mod;
for(int i=1;i<b[1];i++)
ans=(ans+res*(LL)d[i]%mod)%mod;
for(int i=1;i<=n;i++)
if (i!=b[1]) d[i]--;
res=res*d[b[1]]%mod; m=1; flag=0;
dfs(b[1],0); dfs(b[1]);
cout<<ans<<endl;
}
return 0;
}