bzoj3160 万径人踪灭 FFT+manacher
程序员文章站
2022-06-29 21:49:25
...
Description
好长啊
Solution
可以先算出回文的答案,然后减去连续的回文的答案
注意到两个位置i和j的字符关于k对称满足
考虑用FFT加速这个过程(好像也可以叫生成函数什么的,我们做两次FFT分别求出a的对称和b的对称,这样算出来的就是包含不合法方案的答案
然后变成求每个位置为中心有多少个连续回文,这个可以用回文树也可以manacher
手写复数类的确会快,而且好像快很多囧
可能我昨天没有get到正确的姿势
Code
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <algorithm>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)
typedef long long LL;
const int MOD=1000000007;
const int N=1048576;
const double pi=acos(-1);
struct com {
double real,imag;
com operator +(const com &b) const {
return (com) {real+b.real,imag+b.imag};
}
com operator -(const com &b) const {
return (com) {real-b.real,imag-b.imag};
}
com operator *(const com &b) const {
return (com) {real*b.real-imag*b.imag,real*b.imag+imag*b.real};
}
com operator /(const double b) const {
return (com) {real/b,imag/b};
}
} b[N],c[N];
char str[N];
int rev[N],rad[N],cnt[N];
int read() {
int x=0,v=1; char ch=getchar();
for (;ch<'0'||ch>'9';v=(ch=='-')?(-1):(v),ch=getchar());
for (;ch<='9'&&ch>='0';x=x*10+ch-'0',ch=getchar());
return x*v;
}
int ksm(int x,int dep) {
int ret=1;
for (;dep;dep>>=1) {
(dep&1)?ret=1LL*ret*x%MOD:0;
x=1LL*x*x%MOD;
}
return ret;
}
void FFT(com *a,int len,double f) {
for (int i=0;i<len;i++) if (i<rev[i]) std:: swap(a[i],a[rev[i]]);
for (int i=1;i<len;i<<=1) {
com wn=(com){cos(pi/i),f*sin(pi/i)};
for (int j=0;j<len;j+=i*2) {
com w=(com){1,0};
for (int k=0;k<i;k++) {
com u=a[j+k],v=a[j+k+i]*w;
a[j+k]=u+v; a[j+k+i]=u-v;
w=w*wn;
}
}
}
if (f==-1) for (int i=0;i<len;i++) a[i]=a[i]/len;
}
int manacher(char *ptr) {
int n=strlen(ptr+1); LL ans=0;
static char str[N];
rep(i,1,n) {
str[i*2-1]='#';
str[i*2]=ptr[i];
} str[n*2+1]='#';
int pos=0;
rep(i,1,n*2+1) {
if (i<=pos+rad[pos]) rad[i]=std:: min(rad[pos*2-i],pos+rad[pos]-i);
for (;i+rad[i]+1<=n*2+1&&i-rad[i]-1>=1&&str[i+rad[i]+1]==str[i-rad[i]-1];) ++rad[i];
if (rad[i]+i>=pos+rad[pos]) {
pos=i;
}
ans=(ans+((rad[i]+1)>>1))%MOD;
}
return ans;
}
void solve(char *str) {
int n=strlen(str+1); int ans=0;;
int len,lg; for (len=1,lg=0;len<=n*2;len<<=1,lg++);
for (int i=0;i<len;i++) rev[i]=(rev[i/2]/2)|((i&1)<<(lg-1));
rep(i,1,n) {
b[i]=(com){(double)(str[i]=='a'),0};
c[i]=(com){(double)(str[i]=='b'),0};
}
FFT(b,len,1); FFT(c,len,1);
rep(i,0,len) {
b[i]=b[i]*b[i];
c[i]=c[i]*c[i];
}
FFT(b,len,-1); FFT(c,len,-1);
rep(i,0,len) {
cnt[i]=(cnt[i]+(int)(b[i].real+0.5)+(int)(c[i].real+0.5));
cnt[i]=(cnt[i]+1)>>1;
}
rep(i,0,len) ans=(ans+ksm(2,cnt[i])-1)%MOD;
ans=(ans-manacher(str)+MOD)%MOD;
printf("%d\n", ans);
}
int main(void) {
scanf("%s",str+1);
solve(str);
return 0;
}