欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

[清华集训2017] 生成树计数

程序员文章站 2022-04-15 15:13:11
"题目链接" 分析 一类树(连出的边数集合一定)的贡献 $$ \mathbb{Ans}(\{d_n\}|\sum_id_i=2(n 1))=\prod_ia_i^{d_i}\prod_id_i^m\sum_{i}d_i^m $$ 引入Prufer序列,设$d_i$为点(联通块)在序列中出现的次数,转 ......

题目链接

分析

一类树(连出的边数集合一定)的贡献
\[ \mathbb{ans}(\{d_n\}|\sum_id_i=2(n-1))=\prod_ia_i^{d_i}\prod_id_i^m\sum_{i}d_i^m \]
引入prufer序列,设\(d_i\)为点(联通块)在序列中出现的次数,转换
\[ \begin{aligned} \mathbb{ans}(\{d_n\}|\sum_{i}d_i=n-2) &=\prod_ia_i^{d_i+1}\prod_i{(d_i+1)^m}\sum_{i}(d_i+1)^m\\ &=\prod_ia_i^{d_i+1}(d_i+1)^m\sum_{i}(d_i+1)^m\\ \end{aligned} \]
那么枚举所有的prufer的组合,总答案
\[ \begin{aligned} \mathbb{ans} &=\sum_{\sum_id_i=n-2}\frac{(n-2)!}{\prod_id_i!}\prod_ia_i^{d_i+1}(d_i+1)^m\sum_{i}(d_i+1)^m\\ &=(n-2)!\prod_ia_i\sum_{\sum_id_i=n-2}\prod_i\frac{a_i^{d_i}(d_i+1)^m}{d_i!}\sum_i(d_i+1)^{m}\\ &=(n-2)!\prod_ia_i(\sum_{\sum_id_i=n-2}\sum_i\frac{a_i^{d_i}(d_i+1)^{2m}}{d_i!}\prod_{i\not=j}\frac{a_j^{d_i}(d_j+1)^m}{a_j!}) \end{aligned} \]
g8麻烦……请出生成函数
\[ a(x)=\sum_i\frac{x^i(i+1)^{2m}}{i!}\\ b(x)=\sum_i\frac{x^i(i+1)^m}{i!}\\ f(x)=\sum_ia(a_ix)\prod_{i\not=j}b(a_jx) \]
注意到\([n-2]f(x)\)正是\(\mathbb{ans}\)中非常数部分(括号注明部分)。于是需要处理\(f(x)\)
\[ f(x)=\sum_i\frac{a(a_ix)}{b(a_ix)}\prod_{i}b(a_ix)=\sum_i\frac{a(a_ix)}{b(a_ix)}\exp(\sum_{i}\ln(b(a_ix))) \]
\(\ln\exp\)中视\(a_ix\)整体为变量,处理\(\frac{a(x)}{b(x)}\)\(\ln(b(x))\)的系数,再用\(a_ix\)代换\(x\),求出\(\sum\frac{a(a_ix)}{b(a_ix)}\)以及\(\sum\ln(b(a_ix))\)

这两个过程本质相同:和函数\(\sum\)的第\(i\)向系数是单个函数的\(i\)项系数(常量)乘上\(\sum_ka_k^i\)

于是涉及到一个序列的幂和,它的生成函数
\[ f(x)=\sum_i\sum_ja_j^ix^i=\sum_i\sum_j(a_ix)^j=\sum_i\frac{1}{1-a_ix}\\ g(x)=\sum_i\ln(\frac{1}{1-a_ix})=\sum_i\frac{-a_i}{1-a_ix}=-\sum_i\sum_ja_i^{j+1}x^j\\ f(x)=n-x\times g(x)\\ g(x)=\sum_i\ln((1-a_ix)^{-1})=-\ln(\prod_i1-a_ix) \]
关于这个\(\ln\)\(x\)为变量而非\(a_ix\),于是分治fft处理\(g(x)\),设\(l\)为不小于\(n\)的2的幂,可以粗略的估计为复杂度
\[ \sum_d^{\log l}\frac{l}dd\log d=l\sum_d^{\log l}\log d=l\log(\log(l)!) \]
打表发现\(\log(\log(l)!)\)\(\log(l)\)略大,远小于\(\log^2\)所在规模,于是近似认为复杂度为\(o(n\log n)\)

最后慢慢推回去……

实现

有很多波折……目前洛谷rank1

#include <bits/stdc++.h>
#define ll long long
using namespace std;

const int n=200000;
const int mod=998244353;

inline int qpow(int x,int y) {
    int c=1;
    for(; y; y>>=1,x=(ll)x*x%mod) if(y&1) c=(ll)c*x%mod;
    return c;
}

//------- polynomial begin -------
int w[n],rev[n],_inv[n],lmt;
inline void predone(int len) {
    int l=0; lmt=1; _inv[1]=1;
    while(lmt<=len) lmt<<=1,l++;
    for(int i=0; i<lmt; ++i) {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
        if(i>1) _inv[i]=(ll)_inv[mod%i]*(mod-mod/i)%mod;
    }
    int wlmt=qpow(3,(mod-1)>>l),tmp=lmt>>1; w[tmp]=1;
    for(int i=tmp+1; i<lmt; ++i) w[i]=(ll)w[i-1]*wlmt%mod;
    for(int i=tmp-1; i>0; --i) w[i]=w[i<<1];
    lmt=l;
}
inline void dft(int a[],int len) {
    static unsigned long long tmp[n];
    int u=lmt-__builtin_ctz(len),t;
    for(int i=0; i<len; ++i) tmp[rev[i]>>u]=a[i];
    for(int m=1; m<len; m<<=1)
    for(int i=0,s=m<<1; i<len; i+=s)
    for(int j=0; j<m; ++j)
        t=tmp[i+j+m]*w[m+j]%mod,tmp[i+j+m]=tmp[i+j]+mod-t,tmp[i+j]+=t;
    for(int i=0; i<len; ++i) a[i]=tmp[i]%mod;
}
inline void idft(int a[],int len) {
    reverse(a+1,a+len); dft(a,len);
    ll t=mod-(mod-1)/len; 
    for(int i=0; i<len; ++i) a[i]=t*a[i]%mod;
}
inline int getlen(int len) {
    return 1<<(32-__builtin_clz(len));
}
inline void getder(int a[],int b[],int n) {
    for(int i=0; i<n-1; ++i) b[i]=(ll)(i+1)*a[i+1]%mod; b[n-1]=0;
}
inline void getint(int a[],int b[],int n) {
    for(int i=n-1; i>0; --i) b[i]=(ll)_inv[i]*a[i-1]%mod; b[0]=0;
}
inline void getinv(int a[],int b[],int n) {
    static int tmp[n];
    if(n==1) {b[0]=qpow(a[0],mod-2); return;}
    getinv(a,b,(n+1)>>1);
    int len=getlen(n<<1);
    for(int i=0; i<n; ++i) tmp[i]=a[i];
    for(int i=n; i<len; ++i) tmp[i]=0;
    dft(tmp,len); dft(b,len);
    for(int i=0; i<len; ++i) b[i]=(ll)b[i]*(2+mod-(ll)b[i]*tmp[i]%mod)%mod;
    idft(b,len);
    for(int i=n; i<len; ++i) b[i]=0;
}
inline void getln(int a[],int b[],int n) {
    static int tmp[n];
    getinv(a,tmp,n);
    getder(a,b,n);
    int len=getlen(n<<1);
    dft(tmp,len); dft(b,len);
    for(int i=0; i<len; ++i) tmp[i]=(ll)b[i]*tmp[i]%mod;
    idft(tmp,len);
    getint(tmp,b,n);
    for(int i=n; i<len; ++i) b[i]=0;
    for(int i=0; i<len; ++i) tmp[i]=0;
}
inline void getexp(int a[],int b[],int n) {
    static int tmp[n];
    if(n==1) {b[0]=1; return;}
    getexp(a,b,(n+1)>>1);
    getln(b,tmp,n);
    int len=getlen(n<<1);
    for(int i=0; i<n; ++i) tmp[i]=((i==0)+mod-tmp[i]+a[i])%mod;
    for(int i=n; i<len; ++i) tmp[i]=0;
    dft(tmp,len); dft(b,len);
    for(int i=0; i<len; ++i) b[i]=(ll)tmp[i]*b[i]%mod;
    idft(b,len);
    for(int i=n; i<len; ++i) b[i]=0;
    for(int i=0; i<len; ++i) tmp[i]=0;
}
//------- polynomial end -------

int n,m,con,a[n],d[20][n];
void solve(int l,int r,int p) {
    if(l==r) {
        d[p][0]=1;
        d[p][1]=mod-a[l];
        return;
    }
    int mid=(l+r)>>1; 
    solve(l,mid,p);
    solve(mid+1,r,p+1);
    int len=getlen(r-l+1);
    for(int i=mid-l+2; i<len; ++i) d[p][i]=0;
    for(int i=r-mid+1; i<len; ++i) d[p+1][i]=0;
    dft(d[p],len); dft(d[p+1],len);
    for(int i=0; i<len; ++i) d[p][i]=(ll)d[p][i]*d[p+1][i]%mod;
    idft(d[p],len);
}
int fc[n],fv[n];
int s[n],a[n],b[n],p[n],q[n];

int main() {
    //freopen("filename.in","r",stdin);
    fc[0]=fc[1]=fv[0]=fv[1]=1;
    for(int i=2; i<n; ++i) fv[i]=(ll)fv[mod%i]*(mod-mod/i)%mod;
    for(int i=2; i<n; ++i) fv[i]=(ll)fv[i-1]*fv[i]%mod;
    for(int i=2; i<n; ++i) fc[i]=(ll)fc[i-1]*i%mod;
    
    scanf("%d%d",&n,&m); 
    predone((n+1)*4); con=fc[n-2];
    for(int i=1; i<=n; ++i) scanf("%d",a+i),con=(ll)con*a[i]%mod;
    solve(1,n,0); 
    getln(d[0],s,getlen(n)); //改成二的幂次似乎能缓解数组清空问题……
    s[0]=n;
    for(int i=1; i<=n; ++i) s[i]=(mod-(ll)s[i]*i%mod)%mod;

    for(int i=0; i<=n-2; ++i) {
        a[i]=(ll)fv[i]*qpow(i+1,2*m)%mod;
        b[i]=(ll)fv[i]*qpow(i+1,m)%mod;
    }
    getinv(b,p,n-1);
    getln(b,q,n-1);
    int len=getlen(n<<1);
    dft(a,len); dft(p,len);
    for(int i=0; i<len; ++i) p[i]=(ll)a[i]*p[i]%mod;
    idft(p,len);
    for(int i=n-1; i<len; ++i) p[i]=0;
    for(int i=0; i<=n-2; ++i) {
        p[i]=(ll)p[i]*s[i]%mod;
        q[i]=(ll)q[i]*s[i]%mod;
    }
    memset(b,0,sizeof b); //这儿也是……
    getexp(q,b,n-1);
    dft(p,len); dft(b,len);
    for(int i=0; i<len; ++i) p[i]=(ll)p[i]*b[i]%mod;
    idft(p,len);
    
    printf("%lld\n",(ll)con*p[n-2]%mod);
    return 0;
}