欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Berlekamp-Massey算法学习小记

程序员文章站 2022-07-12 13:48:32
...

简介

Berlekamp-Massey算法,简称BM算法,可以在O(N2)O(N^2)时间内求解一个数列的最短线性递推式。

教程

一篇讲的很详细的博客

Berlekamp-Massey算法

我们采用增量法构造数列{an}\{a_n\}最短线性递推式。

假设现在已经得到了数列a1,a2,...,ai1a_1,a_2,...,a_{i-1}的最短线性递推式,且在这之前递推式被修改过cntcnt次,设第kk次修改后的递推式为RkR_k

如果在加入aia_i后原来的递推式仍然满足,也就是有k=1mrkaik=ai\sum_{k=1}^mr_k*a_{i-k}=a_i
那么数列a1,a2,...,ana_1,a_2,...,a_n的最短线性递推式就是RcntR_{cnt}

不然的话,表明第cntcnt次修改后的递推式在位置ii处出错。定义failkfail_{k}表示第kk次修改后的递推式在failkfail_{k}处出错,显然有failcnt=ifail_{cnt}=i,同时定义deltacnt=aik=1mrkaikdelta_{cnt}=a_i-\sum_{k=1}^mr_k*a_{i-k}

如果cnt=0cnt=0,表明aia_i是数列中第一个非零元素,不难证明当前最短线性递推数列就是{0,0,...,0}\{0,0,...,0\}也就是ii00

不然的话,我们可以考虑构造一个新的递推式RR',若对于任意m+1ji1m'+1\le j\le i-1k=1mrkajk=0\sum_{k=1}^{m'}r'_{k}*a_{j-k}=0
且满足k=1mrkaik=deltacnt\sum_{k=1}^{m'}r'_k*a_{i-k}=delta_{cnt}

那么我们只要把RcntR_{cnt}RR'的对应位分别相加就可以得到一个合法的线性递推式。

考虑通过以下方法构造RR'

mul=deltacntdeltacnt1mul=\frac{delta_{cnt}}{delta_{cnt-1}},使R={0,0,...,0,mul,mulRcnt1}R'=\{0,0,...,0,mul,-mul*R_{cnt-1}\},也就是最前面有ifailcnt11i-fail_{cnt-1}-100,再把数列{1,Rcnt1}\{1,-R_{cnt-1}\}乘上mulmul后接到后面。

为什么这样的RR'是合法的呢?

首先k=1mrkaik=muldeltacnt1=deltacnt\sum_{k=1}^{m'}r'_k*a_{i-k}=mul*delta_{cnt-1}=delta_{cnt}
其次对于所有m+1ji1m'+1\le j\le i-1,设p=ji+failcnt1p=j-i+fail_{cnt-1},有k=1mrkajk=mul(apk=1mcnt1rcnt1,kapk)\sum_{k=1}^{m'}r'_k*a_{j-k}=mul*(a_p-\sum_{k=1}^{m_{cnt-1}}r_{cnt-1,k}*a_{p-k})
又因为Rcnt1R_{cnt-1}对于数列a1,...,afailcnt1a_1,...,a_{fail_{cnt-1}}是一个合法的递推式,故上式等于00

所以新的递推式就是R+RcntR'+R_{cnt}

代码

这是模1e9+7意义下的版本
(话说最后求出来线性递推式貌似并不是最短的?)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>

typedef long long LL;

const int MOD=1000000007;
const int N=3005;

int n,a[N],r[N][N],delta[N],len[N],fail[N];

int ksm(int x,int y)
{
	int ans=1;
	while (y)
	{
		if (y&1) ans=(LL)ans*x%MOD;
		x=(LL)x*x%MOD;y>>=1;
	}
	return ans;
}

int main()
{
	scanf("%d",&n);
	for (int i=1;i<=n;i++) scanf("%d",&a[i]);
	int cnt=0;
	for (int i=1;i<=n;i++)
	{
		int t=a[i];
		for (int j=1;j<=len[cnt];j++) (t+=MOD-(LL)r[cnt][j]*a[i-j]%MOD)%=MOD;
		if (!t) continue;
		delta[cnt]=t;
		fail[cnt]=i;
		if (!cnt)
		{
			len[++cnt]=i;
			continue;
		}
		int mul=(LL)delta[cnt]*ksm(delta[cnt-1],MOD-2)%MOD;
		cnt++;
		len[cnt]=std::max(len[cnt-1],i-fail[cnt-2]+len[cnt-2]);
		for (int j=1;j<=len[cnt-1];j++) r[cnt][j]=r[cnt-1][j];
		(r[cnt][i-fail[cnt-2]]+=mul)%=MOD;
		for (int j=1;j<=len[cnt-2];j++) (r[cnt][j+i-fail[cnt-2]]+=MOD-(LL)r[cnt-2][j]*mul%MOD)%=MOD;
	}
	printf("%d\n",len[cnt]);
	for (int i=1;i<=len[cnt];i++) printf("%d ",r[cnt][i]);
	return 0;
}