欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

[CF960G] Bandit Blues

程序员文章站 2022-03-04 08:29:08
problem 求满足$\sum_i[p_i=\max_{j=1}^i p_j]=a$,$\sum_i[p_i=\max_{j=i}^n p_j]=b$的1到n的排列p的个数。 solution 设f[i,j]为从大到小地向序列中加入i个数,形成了j个前缀最大值的情况,转移有 $$ \begin{a ......

problem

求满足\(\sum_i[p_i=\max_{j=1}^i p_j]=a\)\(\sum_i[p_i=\max_{j=i}^n p_j]=b\)的1到n的排列p的个数。

solution

设f[i,j]为从大到小地向序列中加入i个数,形成了j个前缀最大值的情况,转移有
\[ \begin{aligned} f[0,0]=1,&&f[i,j]=f[i-1,j-1]+(i-1)f[i-1,j] \end{aligned} \]
显然这恰是第一类斯特林数,即\(f[i,j]=s(i,j)\)

一个数集与一个操作方案能对应一个序列。考虑枚举数n的位置,那么答案为
\[ \sum_{i=1}^ns(i-1,a-1)s(n-i,b-1)\times c(n-1,i-1) \]
这相当于是把1到n-1给分成a+b-2个环的方案数(其中环有两类,每类分别由a+1个和b+1个)即答案
\[ s(n-1,a+b-2)\times c(a+b-2,a-1) \]
至此问题已完结。

#include <bits/stdc++.h>
#define ll long long 
using namespace std;

const int n=2e5+10;
const int mod=998244353;
const int inf=0x3f3f3f3f;

inline ll qpow(ll x,ll y) {
    ll c=1;
    for(; y; y>>=1,x=x*x%mod)
        if(y&1) c=x*c%mod;
    return c;
}
int p,pcur,rev[n];
inline void ntt_init(int len) {
    for(p=1,pcur=0; p<(len<<1);) p<<=1,pcur++;
    for(int i=0; i<p; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(pcur-1));
}
inline void ntt(ll*a,int tp) {
    for(int i=0; i<p; ++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int m=1; m<p; m<<=1) {
        int wm=qpow(3,(mod-1)/(m<<1)); if(tp<0) wm=qpow(wm,mod-2);
        for(int i=0; i<p; i+=(m<<1)) { ll w=1,tmp;
            for(int j=0; j<m; ++j,w=w*wm%mod) {
                tmp=w*a[i+j+m]%mod;
                a[i+j+m]=(a[i+j]-tmp+mod)%mod;
                a[i+j]=(a[i+j]+tmp)%mod;
            }
        } 
    }
    if(tp<0) {
        ll tmp=qpow(p,mod-2);
        for(int i=0; i<p; ++i) a[i]=tmp*a[i]%mod;
    }
}
inline void chm(ll*a,ll*b) {
    ntt(a,1); ntt(b,1);
    for(int i=0; i<p; ++i) (a[i]*=b[i])%=mod;
    ntt(a,-1); 
}
ll fac[n],fav[n],a[n],b[n];
void calc(int n,ll*s) {
    if(n==0) {s[0]=1; return;}
    if(n==1) {s[1]=1; return;}
    int m(n/2); calc(m,s); ntt_init(m+1);
    for(int i=0; i<=m; ++i) a[m-i]=fac[i]*s[i]%mod;
    for(int i=0; i<=m; ++i) b[i]=fav[i]*qpow(m,i)%mod;
    for(int i=m+1; i<p; ++i) a[i]=b[i]=0;
    chm(a,b);
    for(int i=0; i<=m; ++i) b[i]=a[m-i]*fav[i]%mod;
    for(int i=0; i<=m; ++i) a[i]=s[i];
    for(int i=m+1; i<p; ++i) a[i]=b[i]=0;
    chm(a,b);
    for(int i=0; i<=m+m; ++i) s[i]=a[i];
    if(n&1)
    for(int i=n; i>=0; --i) s[i]=((i?s[i-1]:0)+(n-1)*s[i]%mod)%mod;
}

ll s[n];
int main() {
    fac[0]=fac[1]=fav[0]=fav[1]=1;
    for(int i=2; i<n; ++i) fav[i]=fav[mod%i]*(mod-mod/i)%mod;
    for(int i=2; i<n; ++i) fav[i]=fav[i-1]*fav[i]%mod,fac[i]=fac[i-1]*i%mod;
    //int n; scanf("%d",&n); calc(n,s);
    //for(int i=0; i<=n; ++i) printf("s(%d,%d)=%d\n",n,i,s[i]);
    int n,a,b;
    scanf("%d%d%d",&n,&a,&b);
    calc(n-1,s);
    printf("%lld",fac[a+b-2]*fav[a-1]%mod*fav[b-1]%mod*s[a+b-2]%mod);
    return 0;
}