欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

bzoj3160 万径人踪灭 FFT+manacher

程序员文章站 2022-06-29 21:49:25
...

Description


好长啊
bzoj3160 万径人踪灭 FFT+manacher
bzoj3160 万径人踪灭 FFT+manacher

Solution


可以先算出回文的答案,然后减去连续的回文的答案

注意到两个位置i和j的字符关于k对称满足si=sj(i+j=k)
考虑用FFT加速这个过程(好像也可以叫生成函数什么的,我们做两次FFT分别求出a的对称和b的对称,这样算出来的就是包含不合法方案的答案

然后变成求每个位置为中心有多少个连续回文,这个可以用回文树也可以manacher
手写复数类的确会快,而且好像快很多囧
可能我昨天没有get到正确的姿势

Code


#include <stdio.h>
#include <string.h>
#include <math.h>
#include <algorithm>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)

typedef long long LL;
const int MOD=1000000007;
const int N=1048576;
const double pi=acos(-1);

struct com {
    double real,imag;
    com operator +(const com &b) const {
        return (com) {real+b.real,imag+b.imag};
    }
    com operator -(const com &b) const {
        return (com) {real-b.real,imag-b.imag};
    }
    com operator *(const com &b) const {
        return (com) {real*b.real-imag*b.imag,real*b.imag+imag*b.real};
    }
    com operator /(const double b) const {
        return (com) {real/b,imag/b};
    }
} b[N],c[N];

char str[N];

int rev[N],rad[N],cnt[N];

int read() {
    int x=0,v=1; char ch=getchar();
    for (;ch<'0'||ch>'9';v=(ch=='-')?(-1):(v),ch=getchar());
    for (;ch<='9'&&ch>='0';x=x*10+ch-'0',ch=getchar());
    return x*v;
}

int ksm(int x,int dep) {
    int ret=1;
    for (;dep;dep>>=1) {
        (dep&1)?ret=1LL*ret*x%MOD:0;
        x=1LL*x*x%MOD;
    }
    return ret;
}

void FFT(com *a,int len,double f) {
    for (int i=0;i<len;i++) if (i<rev[i]) std:: swap(a[i],a[rev[i]]);
    for (int i=1;i<len;i<<=1) {
        com wn=(com){cos(pi/i),f*sin(pi/i)};
        for (int j=0;j<len;j+=i*2) {
            com w=(com){1,0};
            for (int k=0;k<i;k++) {
                com u=a[j+k],v=a[j+k+i]*w;
                a[j+k]=u+v; a[j+k+i]=u-v;
                w=w*wn;
            }
        }
    }
    if (f==-1) for (int i=0;i<len;i++) a[i]=a[i]/len;
}

int manacher(char *ptr) {
    int n=strlen(ptr+1); LL ans=0;
    static char str[N];
    rep(i,1,n) {
        str[i*2-1]='#';
        str[i*2]=ptr[i];
    } str[n*2+1]='#';
    int pos=0;
    rep(i,1,n*2+1) {
        if (i<=pos+rad[pos]) rad[i]=std:: min(rad[pos*2-i],pos+rad[pos]-i);
        for (;i+rad[i]+1<=n*2+1&&i-rad[i]-1>=1&&str[i+rad[i]+1]==str[i-rad[i]-1];) ++rad[i];
        if (rad[i]+i>=pos+rad[pos]) {
            pos=i;
        }
        ans=(ans+((rad[i]+1)>>1))%MOD;
    }
    return ans;
}

void solve(char *str) {
    int n=strlen(str+1); int ans=0;;
    int len,lg; for (len=1,lg=0;len<=n*2;len<<=1,lg++);
    for (int i=0;i<len;i++) rev[i]=(rev[i/2]/2)|((i&1)<<(lg-1));

    rep(i,1,n) {
        b[i]=(com){(double)(str[i]=='a'),0};
        c[i]=(com){(double)(str[i]=='b'),0};
    }
    FFT(b,len,1); FFT(c,len,1);
    rep(i,0,len) {
        b[i]=b[i]*b[i];
        c[i]=c[i]*c[i];
    }
    FFT(b,len,-1); FFT(c,len,-1);
    rep(i,0,len) {
        cnt[i]=(cnt[i]+(int)(b[i].real+0.5)+(int)(c[i].real+0.5));
        cnt[i]=(cnt[i]+1)>>1;
    }

    rep(i,0,len) ans=(ans+ksm(2,cnt[i])-1)%MOD;
    ans=(ans-manacher(str)+MOD)%MOD;
    printf("%d\n", ans);
}

int main(void) {
    scanf("%s",str+1);
    solve(str);
    return 0;
}
相关标签: bzoj