欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

[NOI2017] 泳池

程序员文章站 2022-06-22 12:55:03
分析 設$g[i,j]$表示底邊長爲$i$,底部恰好有$j$行全部安全,沒有面積大於$k$的子矩陣(下稱:合法)的概率,$s[i, ]$為$g[i, ]$的前綴和,意義為最下面至少有$j$行全部安全的合法概率。 我們枚舉第一個在第$i+1$行的危險點來轉移$g$和$s$,設這個危險點在第$k+1$列 ......

分析

\(g[i,j]\)表示底邊長爲\(i\),底部恰好有\(j\)行全部安全,沒有面積大於\(k\)的子矩陣(下稱:合法)的概率,\(s[i,*]\)\(g[i,*]\)的前綴和,意義為最下面至少有\(j\)行全部安全的合法概率。

我們枚舉第一個在第\(i+1\)行的危險點來轉移\(g\)\(s\),設這個危險點在第\(k+1\)列,那麽左邊\(k\)列的底部就至少有\(j+1\)行是全安全的,而右邊至少有\(j\)行(不重複計數)是。然後乘上第\(k+1\)列上的相關概率得到:
\[ g[i,j]=\sum_{k=0}^{i-1}q^j(1-q)s[k,j+1]*s[i-1-k,j]\\ s[i,j]=s[i,j+1]+g[i,j]\\ s[0,0\cdots k+1]=1 \]
需要留意的是儅\(s,g[i,j]​\)\(ij\ge k​\)的狀態一定是不合法的,基於這樣的邊界的處理複雜度就不是\(n^3​\)了。

定義\(f[i]​\)表示底邊長爲\(i​\)的泳池的合法概率。若第\(i​\)列底部的格子不安全,則
\[ f[i]\leftarrow f[i-1](1-q) \]
再考慮第\(i-j\)列底部格子不安全,第\([i-j+1,i]\)列底部全部安全且產生的子矩陣面積不超過\(k\)的情形,則
\[ f[i]\leftarrow f[i-j-1](1-q)s[j,1] \]

綜合起來記作
\[ f[i]=\sum_{j=1}^{k+1}f[i-j](1-q)s[j-1,1]\\ f[0]=1 \]
顯然是個常係數齊次遞推,可以用特徵多項式水過。答案就是\(f[n]-f[n-1]\)囖。

總結

這題滿巧妙地把二維上的問題通過預處理變爲一個一維的綫性遞推問題,主要是可以抓住“所判別的矩陣的底部一定是靠在泳池的底部”這一要求;另外,邏輯關係上\(s\)應該在\(g\)之前,就是説,我們使用差分的方式來處理\(s\)這種”至少“的問題,這也是個常見的套路:d

暴力算\(f[0\to k]\)時不要拿\(a\)進去呀,而且此時的轉移一個純的\(s[i,1]\),即沒有任何危險點的情況 tat

代碼實現

#include <bits/stdc++.h>
using namespace std;
const int n=1e5+10;
const int p=998244353;

inline int read() {
    int d=0,f=0;
    char ch=getchar();
    while(!isdigit(ch)&&ch!='-') ch=getchar();
    if(ch=='-') f=1,ch=getchar();
    while(isdigit(ch)) d=d*10+ch-'0',ch=getchar();
    if(f) d=-d;
    return d;
}
inline int qpow(int x,int y) {
    int c=1;
    for(; y; y>>=1,x=1ll*x*x%p) 
        if(y&1) c=1ll*c*x%p;
    return c;
}

struct linear { 
    void nuthtr(int*a,int p,int tp) {
        static int pl=0,r[n];
        if(pl!=p) {
            int l=log2(pl=p);
            for(int i=0; i<p; ++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        }
        for(int i=0; i<p; ++i) if(i<r[i]) swap(a[i],a[r[i]]);
        for(int m=1; m<p; m<<=1) {
            int wm=qpow(3,(p-1)/(m<<1));
            if(tp==-1) wm=qpow(wm,p-2);
            for(int i=0; i<p; i+=(m<<1)) {
                long long w=1,tmp;
                for(int j=0; j<m; ++j,w=w*wm%p) {
                    tmp=w*a[i+j+m]%p;
                    a[i+j+m]=(a[i+j]-tmp+p)%p;
                    a[i+j]=(a[i+j]+tmp)%p;
                }
            }
        }
        if(tp==-1) {
            long long tmp=qpow(p,p-2);
            for(int i=0; i<p; ++i) a[i]=tmp*a[i]%p;
        }
    }
    void poin(int*a,int*b,int n) {
        static int c[n],p;
        if(n==1) {
            b[0]=qpow(a[0],p-2);
            return;
        }
        poin(a,b,(n+1)>>1);
        for(p=1; p<(n<<1); p<<=1);
        copy(a,a+n,c);
        fill(c+n,c+p,0);
        nuthtr(c,p,1);
        nuthtr(b,p,1);
        for(int i=0; i<p; ++i) b[i]=(2ll-1ll*c[i]*b[i]%p+p)*b[i]%p;
        nuthtr(b,p,-1);
        fill(b+n,b+p,0);
    }
    
    int p,k;
    int c[n],g[n],ig[n]; 
    
    void pomo(int*f,int*r) {
        static int q[n],t[n];
        reverse(f,f+k+k-1);
        copy(f,f+k,t);
        nuthtr(t,p,1);
        for(int i=0; i<p; ++i) q[i]=1ll*t[i]*ig[i]%p;
        nuthtr(q,p,-1);
        fill(q+k-1,q+p,0);
        reverse(f,f+k+k-1);
        reverse(q,q+k-1); 
        nuthtr(q,p,1);
        for(int i=0; i<p; ++i) q[i]=1ll*g[i]*q[i]%p;
        nuthtr(q,p,-1);
        for(int i=0; i<k; ++i) r[i]=(f[i]-q[i]+p)%p;
        fill(r+k,r+p,0);
        fill(q,q+p,0);
        fill(t,t+p,0);
    }
    void init(int*a,int k) {
        this->k=k;
        memset(g,0,sizeof g);
        memset(ig,0,sizeof ig);
        
        for(int i=1; i<=k; ++i) g[k-i]=(p-a[i]); g[k]=1;
        for(p=1; p<=k; p<<=1); p<<=1;
        reverse(g,g+k+1);
        poin(g,ig,p);
        fill(ig+k+1,ig+p,0);
        reverse(g,g+k+1);
        nuthtr(g,p,1);
        nuthtr(ig,p,1);
    }
    int calc(int*f,int n) {
        if(n<k) return f[n];
        static int s[n];
        memset(s,0,sizeof s);
        memset(c,0,sizeof c);
        
        s[1]=1;
        c[0]=1;
        for(; n; n>>=1) {
            if(n&1) {
                nuthtr(c,p,1);
                nuthtr(s,p,1);
                for(int i=0; i<p; ++i) c[i]=1ll*c[i]*s[i]%p;
                nuthtr(c,p,-1);
                nuthtr(s,p,-1);
                pomo(c,c);
            }
            nuthtr(s,p,1);
            for(int i=0; i<p; ++i) s[i]=1ll*s[i]*s[i]%p;
            nuthtr(s,p,-1);
            pomo(s,s);
        }
        int ans=0;
        for(int i=0; i<k; ++i) ans=(ans+1ll*c[i]*f[i])%p;
        return ans;
    }
} c;

int n,k,q,x,y;
int s[1002][1002],g[1002][1002];
int f[1002],a[1002];

int solve(int k) {
    if(k<0) return 0;
    if(k==0) return qpow((p+1-q)%p,n);
    memset(s,0,sizeof s);
    memset(g,0,sizeof g);
    memset(f,0,sizeof f);
    
    for(int i=0; i<=k+1; ++i) s[0][i]=1;
    for(int i=1; i<=k; ++i) {
        for(int j=k/i; j; --j) {
            for(int k=0; k<i; ++k) {
                g[i][j]=(g[i][j]+1ll*qpow(q,j)*(p+1-q)%p*s[k][j+1]%p*s[i-1-k][j])%p;
            }
            s[i][j]=(s[i][j+1]+g[i][j])%p;
        }
    }
    for(int j=1; j<=k+1; ++j) {
        a[j]=1ll*(p+1-q)*s[j-1][1]%p;
    }
    f[0]=1;
    for(int i=1; i<=k; ++i) {
        f[i]=s[i][1];
        for(int j=1; j<=i; ++j) {
            f[i]=(f[i]+1ll*f[i-j]*(p+1-q)%p*s[j-1][1]%p)%p;
        }
    }
    c.init(a,k+1);
    return c.calc(f,n);
}
int main() {
    n=read(),k=read(),x=read(),y=read();
    q=1ll*x*qpow(y,p-2)%p;
    printf("%d\n",(solve(k)-solve(k-1)+p)%p);
    return 0;
}

特別感謝: