欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Codeforce 622 F. The Sum of the k-th Powers(拉格朗日插值求k次幂之和,拉格朗日插值公式)

程序员文章站 2022-07-05 17:25:30
...

Codeforce 622 F. The Sum of the k-th Powers(拉格朗日插值求k次幂之和,拉格朗日插值公式)

题目大意:求 i=1nik\displaystyle\sum_{i = 1}^ni^k
求k次幂有多种求法,例如:
伯努利数求k次幂之和(待补)
斯特林数求k次幂之和
拉格朗日插值法求k次幂之和

这里采用拉格朗日插值法进行求解。
拉格朗日可以通过 k+1k + 1 个点唯一确定一个 kk 次多项式,它的公式为:f(x)=i=1ny[i]ijxx[j]x[i]x[j]f(x) = \sum_{i = 1}^ny[i] \prod_{i \neq j}\frac{x - x[j]}{x[i]-x[j]}
其中x[i],y[i]x[i],y[i]对应已知的点值,对已知的点很容易通过代入验证正确性,带入 x[i]x[i] 将会得到 y[i]y[i]

这个式子在一般情况下的复杂度为 O(n2)O(n^2),比高斯消元的 n3n^3 更加优秀,在已知点的 xx 取值连续的情况下,复杂度能降低到 O(n)O(n),只要预处理阶乘逆元,以及 xx 的 k + 1 项倒阶乘:xfacxfac
f(x)=i=1ny[i]xfacfac[i]fac[ni](xi)f(x)=\sum_{i = 1}^ny[i]*\frac{xfac}{fac[i]*fac[n - i]*(x-i)}

为什么这题可以用拉格朗日插值
当然是因为 i=1nik\displaystyle\sum_{i = 1}^ni^k 是一个以n为自变量的多项式,并且是 k+1k + 1 次多项式
证明:
S(n,k)=i=1nik\displaystyle S(n,k)=\sum_{i = 1}^ni^k
对这个序列两两差分可以得到:(n+1)k+1nk+1=i=0k+1C(k+1,i)nink+1=i=0kC(k+1,i)ni(n + 1)^{k+1} - n^{k+1}=\sum_{i = 0}^{k+1}C(k+1,i)*n^i - n^{k+1}=\sum_{i = 0}^kC(k+1,i)*n^ink+1(n1)k+1=i=0kC(k+1,i)(n1)in^{k+1} - (n-1)^{k+1}=\sum_{i = 0}^{k}C(k+1,i)*(n-1)^i......1k+10k+1=i=0kC(k+1,i)0i1^{k+1}-0^{k+1}=\sum_{i = 0}^{k}C(k+1,i)0^i

逐项求和可以得到 (n+1)k+1=i=0kC(k+1,i)S(n,k)\displaystyle (n+1)^{k+1} =\sum_{i=0}^kC(k+1,i)*S(n,k),即得到S(n,k)S(n,k)是以 nn 为自变量的 k+1k + 1 次多项式

f(x)=xkf(x) = x^k,可以得到一个更一般的推广结论:kk 次多项式的前 nn 项和 g(n)g(n) 是一个以 nn 为自变量的 k+1k + 1 次多项式

回到这题,前k+2k + 2项可以 klogkk \log k 暴力计算,对nk+2n \leq k + 2 直接输出答案,对 n>k+2n > k + 2 只要插值一下,根据插值公式计算即可。


代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int maxn = 1e6 + 100;
int n,k;
int x[maxn],y[maxn];				//拉格朗日差值的计算 
int fac[maxn],ifac[maxn];			//阶乘的逆元 
inline int add(int x, int y) {
  	x += y;
  	if (x >= mod)
    	x -= mod;
  	return x;
}

inline int sub(int x, int y) {
  	x -= y;
  	if (x < 0)
    	x += mod;
  	return x;
}

inline int mul(int x, int y) {
  	return (long long) x * y % mod;
}
int fpow(int a,int b) {
	int r = 1;
	while (b) {
		if (b & 1) r = mul(r,a);
		b >>= 1;
		a = mul(a,a);
	}
	return r;
}
int main() {
	scanf("%d%d",&n,&k);
	for (int i = 1; i <= k + 2; i++) {			//暴力计算 k + 2 个点,根据这 k + 2个点就可以通过插值唯一确定 k + 1次多项式 
		x[i] = i;
		y[i] = add(y[i - 1],fpow(x[i],k));
	}
	if (n <= k + 2) {							//n <= k + 2就直接输出,否则下面的处理会出错 
		printf("%d\n",y[n]);
		return 0;
	}
	fac[0] = 1;
	for (int i = 1; i <= k + 2; i++) {			//由于k+2个点x取值连续,预处理阶乘,使复杂度降低到O(k) 
		fac[i] = mul(fac[i - 1],i);
	}
	ifac[k + 2] = fpow(fac[k + 2],mod - 2);
	for (int i = k + 1; i >= 0; i--) {
		ifac[i] = mul(ifac[i + 1],i + 1);
	}
	int tmp = 1;								//n的倒阶乘 ,同样也是为了加速 
	for (int i = 1; i <= k + 2; i++) {
		tmp = mul(tmp,(n - i) % mod);
	}
	int ans = 0;
	for (int i = 1; i <= k + 2; i++) {			//插值迭代,得到 f(n) 
		int t = k + 2 - i;
		int p = (t & 1) ? -1 : 1;
		int inv = fpow((n - i) % mod,mod - 2);
		int res = mul(mul(ifac[i - 1],ifac[t]),mul(tmp,inv));
		res = mul(mul(res,p),y[i]);
		if (res < 0) res += mod;
		ans = add(ans,res);
	}
	printf("%d\n",ans);
	return 0;
}