欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

P4593 [TJOI2018]教科书般的*(拉格朗日插值 + k幂次之和)

程序员文章站 2022-06-07 09:52:11
...

P4593 [TJOI2018]教科书般的*(拉格朗日插值 + k幂次之和)
洛谷题目链接


题目大意:有点绕,有 nmn - m 个怪兽,它们的血量在 [1,n][1,n] 值域上且每个怪兽的血量都不同 (其中 m 个点被挖掉),你每使用一次*会给所有怪兽血量 减一,若有怪兽死亡(血量为0),则自动再使用一次*(这并不算你使用,求答案时不考虑自动使用的贡献),设最多要使用 k 张*才能杀死所有怪兽,你每使用一次*,每个活着的怪兽都会产生 xkx^k 的分数贡献,其中 xx 是使用*之前的血量。求最后获得的总分数。

首先 k=m+1k = m + 1,第一次当成在0位置使用一次*,每有一个空位,都要使用一次*,连续的段使用一次就可以全部杀死。由于空位较少,考虑在每个空位使用*计算全部的贡献最后再扣掉空位的贡献。

在 0 空位使用一次*的贡献显然是 :i=1nik0\displaystyle\sum_{i = 1}^ni^k-(0空位后面所有空位的贡献)

下次在 p 空位使用一次*的贡献是 :i=1npikp\displaystyle\sum_{i = 1}^{n-p}i^k-(p空位后面所有空位的贡献)

枚举所有空位,在每个空位计算贡献 :i=1npik\displaystyle\sum_{i=1}^{n-p}i^k,这个式子可以用拉格朗日插值快速计算。

对每个空位需要扣掉的贡献:枚举这个空位 p 之前的所有空位 x,扣掉 (px)k(p-x)^k

复杂度 O((m2+mn)log(mod))O((m^2 + mn)*\log(mod))


代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e3 + 10;
const int mod = 1e9 + 7;
typedef long long ll;
int mx,t,a[maxn];
inline ll add(ll x, ll y) {
  	x += y;
  	if (x >= mod) x -= mod;
  	return x;
}

inline ll sub(ll x, ll y) {
	x -= y;
	if (x < 0) x += mod;
	return x;
}

inline ll mul(ll x, ll y) {
  	return x * y % mod;
}
ll fpow(ll a,ll b) {
	ll r = 1;
	while(b) {
		if (b & 1) r = mul(r,a);
		b >>= 1;
		a = mul(a,a);
	}
	return r;
}
ll g[maxn],f[maxn],u[maxn],r[maxn],k,n,m;
ll fac[maxn],ifac[maxn];
ll cal(ll g[maxn],ll x) {			//拉格朗日插值计算多项式
	if (x <= mx) return g[x];
	ll tmp = 1,inv,ans = 0;
	for (int i = 1; i <= mx; i++)
		tmp = mul(tmp,x - i);
	for (int i = 1; i <= mx; i++) {
		ll res = 1, inv = fpow(x - i,mod - 2);
		res = mul(res,g[i]);
		res = mul(res,ifac[i - 1]);
		res = mul(res,ifac[mx - i]);
		res = mul(res,inv);
		res = mul(res,tmp);
		if ((mx - i) & 1) res = mul(res,-1);
		if (res < 0) res += mod;
		ans = add(ans,res);
	}
	return ans;
}
ll C(int x,int y) {
	if (y > x || x < 0) return 0;
	return mul(mul(ifac[y],ifac[x - y]),fac[x]);
}
int main() {
	fac[0] = 1;
	for (int i = 1; i <= 1000; i++)
		fac[i] = mul(fac[i - 1],i);
	ifac[1000] = fpow(fac[1000],mod - 2);
	for (int i = 1000 - 1; i >= 0; i--)
		ifac[i] = mul(ifac[i + 1],i + 1);
	scanf("%d",&t);
	while (t--) {
		scanf("%d%d",&n,&m);
		for (int i = 1; i <= m; i++) {
			scanf("%d",&a[i]);
		}
		sort(a + 1,a + m + 1);
		mx = m + 3;
		a[0] = g[0] = 0;
		for (int i = 1; i <= mx; i++)
			g[i] = add(g[i - 1],fpow(i,m + 1));
		ll ans = 0;
		for (int i = 0; i <= m; i++) {
			ans = add(ans,cal(g,n - a[i]));
			for (int j = 0; j <= i; j++)
				ans = sub(ans,fpow(a[i] - a[j],m + 1));
		}
		printf("%lld\n",ans);
	}
	return 0;
}