欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

多项式多点求值

程序员文章站 2024-03-21 20:12:46
...

给定一个nn次多项式f(x)f(x),现在请你对于 i[1,m]i\in [1,m] ,求出 f(ai)(mod998244353)f(a_i)\pmod {998244353}

好像有一个啥定理:
f(ai)=f(x)mod  (xai)f(a_i) = f(x) \mod (x-a_i)
那么我们可以想到一个分治的做法。
fl,r(x)=f(x)mod  i=lr(xai)f_{l,r}(x)=f(x) \mod \prod_{i=l}^r (x-a_i)
那么f1,n(x)=f(x)f_{1,n}(x) = f(x)
fl,mid=fl,rmod  i=lmid(xai)f_{l,mid} = f_{l,r} \mod \prod_{i=l}^{mid} (x-a_i)
fmid+1,rf_{mid+1,r}同理。
那么我们这么分治下去就可以得到所有的fi,if_{i,i}
也就是我们要的那个点的值。
考虑复杂度:
T(n)=2T(n2)+O(nlogn)T(n) = 2T(\frac n2) + O(n\log n)
时间复杂度O(nlog2n)O(n\log^2n)
不过很慢,稳稳的1e51e5要跑一秒(以上)。

AC Code\mathrm {AC \ Code}

#include<bits/stdc++.h>
#define maxn 300005
#define mod 998244353
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
using namespace std;

int Wl,W[maxn],lg[maxn],inv[maxn],r[maxn];
int Pow(int b,int k){ int r=1;for(;k;k>>=1,b=1ll*b*b%mod) if(k&1) r=1ll*r*b%mod; return r; }
void init(int n){
	for(W[0]=inv[0]=inv[1]=Wl=1;n>=Wl<<1;Wl<<=1);int pw=Pow(3,(mod-1)/Wl/2);
	rep(i,1,Wl<<1) W[i]=1ll*W[i-1]*pw%mod,(i>1)&&(inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod,lg[i]=lg[i>>1]+1);
}
void NTT(int *A,int n,int tp){
	rep(i,0,n-1) i<(r[i]=(r[i>>1]>>1)|((i&1)<<(lg[n]-1)))&&(swap(A[i],A[r[i]]),0);
	for(int L=1,B=Wl;L<n;L<<=1,B>>=1) for(int s=0;s<n;s+=L<<1) for(int k=s,x=0,t;k<s+L;k++,x+=B)
		t=1ll*(tp==1?W[x]:W[(Wl<<1)-x])*A[k+L]%mod,A[k+L]=(A[k]-t)%mod,A[k]=(A[k]+t)%mod;
	if(tp^1) rep(i,0,n-1) A[i]=1ll*A[i]*inv[n]%mod;
}
void MUL(int *A,int *B,int *C,int n,int m){
	static int t[2][maxn];int L=1<<lg[n+m]+1;
	rep(i,0,L-1) t[0][i]=i<=n?A[i]:0,t[1][i]=i<=m?B[i]:0;NTT(t[0],L,1),NTT(t[1],L,1);
	rep(i,0,L-1) C[i]=1ll*t[0][i]*t[1][i]%mod;NTT(C,L,-1);
}
void INV(int *A,int *B,int n){
	B[B[1]=0]=Pow(A[0],mod-2);static int t[maxn];
	for(int k=2,L=4;k<(n<<1);k<<=1,L<<=1){
		rep(i,0,L-1) t[i]=i<k?A[i]:B[i]=0;NTT(B,L,1),NTT(t,L,1);
		rep(i,0,L-1) B[i]=B[i]*(2-1ll*B[i]*t[i]%mod)%mod;NTT(B,L,-1);
		rep(i,min(n,k),L-1) B[i]=0;
	}
}
void DIV(int *A,int *B,int *C,int *R,int n,int m){
	if(n<m){ rep(i,0,m-1) R[i]=A[i];return; }
	reverse(A,A+n+1),reverse(B,B+m+1),INV(B,C,n-m+1);
	MUL(A,C,C,n-m,n-m),fill(C+n-m+1,C+2*n-2*m+1,0);
	reverse(A,A+n+1);reverse(B,B+m+1);reverse(C,C+n-m+1);
	MUL(B,C,R,m,n-m);rep(i,0,n) R[i]=(A[i]-R[i])%mod;
}
#define lc u<<1
#define rc u<<1|1
int *M[maxn],*F[maxn],dM[maxn];
#define NAF(a) new int[1<<lg[a]+2]
void BDM(int u,int l,int r,int *X){
	M[u]=NAF(r-l+1),dM[u]=r-l+1;
	if(l==r) return (void)(M[u][0]=-X[l],M[u][1]=1);
	int m=(l+r)>>1;BDM(lc,l,m,X),BDM(rc,m+1,r,X);
	MUL(M[lc],M[rc],M[u],dM[lc],dM[rc]);
}
void EVL(int u,int l,int r,int *C){
	if(u>1) DIV(F[u>>1],M[u],F[0],F[u]=NAF(dM[u>>1]),dM[u>>1]-1,dM[u]);
	if(l==r) return (void)(C[l]=F[u][0]);
	int m=(l+r)>>1;EVL(lc,l,m,C),EVL(rc,m+1,r,C);
}
void MEV(int *A,int *X,int *C,int n,int m){//0...m-1
	BDM(1,0,m-1,X);
	DIV(A,M[1],F[0]=NAF(max(n,m)),F[1]=NAF(max(n,m)),n,dM[1]);
	EVL(1,0,m-1,C);
}

int n,m,A[maxn],X[maxn],C[maxn];

int main(){
	scanf("%d%d",&n,&m);init(2*max(n,m));
	rep(i,0,n) scanf("%d",&A[i]);
	rep(i,0,m-1) scanf("%d",&X[i]);
	MEV(A,X,C,n,m);
	rep(i,0,m-1) printf("%d\n",(C[i]+mod)%mod);
}

相关标签: FFT