欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

jzoj6407 【NOIP2019模拟11.05】小 D 与随机 (容斥计数)

程序员文章站 2024-03-17 16:52:58
...

题意

一棵树,现在随机分配一个排列,设一种方案中,到根没有比他大的点(好点)的个数是cc,那么这种方案有KcK^c的贡献。求贡献和。
n5000,K109n\leq 5000,K\leq10^9

思路

  • 题解看不懂
  • Kc=(K+11)c=(K1)i×(ci)\sum K^c=\sum (K+1-1)^c=\sum(K-1)^i\times \binom{c}{i}
  • 将k-1,就可以转化成“至少”的计数问题。
  • 问题现在变成了:分配排列,方案权值和。
  • 这个问题可以dp解决。考虑到当前点的取值范围依赖于子树内最小好点,设f[i][j]f[i][j]表示i的子树内,最小的好点的是第j小的所有方案权值和。
  • 转移方法是枚举两颗子树合并后,最小好点的位置,然后再讨论当前点是不是好点,与之前点的相对大小。
  • 观察发现可以前缀和优化,注意循环顺序与范围,不要退化了。
  • O(n2)O(n^2)
#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5050, mo = 998244353;
ll n, k;
int final[N], nex[N * 2], to[N * 2], tot;
ll C[N][N];
ll f[N][N];
int sz[N];

void link(int x, int y) {
	to[++tot] = y, nex[tot] = final[x], final[x] = tot; 
}

ll sum[N];
void merge(ll *tmp, int x, int y) {
	memset(sum, 0, sizeof sum);
	for(int v = sz[y] + 1; v; v--) {
		sum[v] = (sum[v + 1] + f[y][v]) % mo;
	}
	for(int u = 1; u <= sz[x]; u++) {
		for(int e = u; e <= u + sz[y]; e++) {
			tmp[e] = (tmp[e] + f[x][u] * sum[e - u + 1] % mo 
					* C[e - 1][u - 1] % mo 
					* C[sz[x] + sz[y] - e][sz[x] - u]) % mo;
		}
	}
}

ll tmp[N];
void dfs(int x, int from) {
	int has = 0;
	for(int i = final[x]; i; i = nex[i]) {
		int y = to[i]; if (y != from) {
			dfs(y, x);
			if (has == 0) {
				memcpy(f[x], f[y], sizeof f[y]);
			} else {
				memset(tmp, 0, (5 + sz[x] + sz[y]) * 8);
				merge(tmp, x, y);
				merge(tmp, y, x);
				tmp[sz[x] + sz[y] + 1] = f[x][sz[x] + 1] * f[y][sz[y] + 1] % mo * C[sz[x] + sz[y]][sz[x]] % mo;
				memcpy(f[x], tmp, (5 + sz[x] + sz[y]) * 8);
			}
			sz[x] += sz[y];
			has = 1;
		}
	}
	sz[x] ++;
	if (!has) {
		f[x][1] = k - 1;
		f[x][2] = 1;
		return;
	}
	memset(tmp, 0, sizeof tmp);
	for(int i = 1; i <= sz[x]; i++) {
		tmp[1] = (tmp[1] + f[x][i] * (k - 1)) % mo;
		tmp[i + 1] = (tmp[i + 1] - f[x][i] * (k - 1)) % mo;
		tmp[i + 1] = (tmp[i + 1] + i * f[x][i]) % mo;//x不是好点
		tmp[i + 2] = (tmp[i + 2] - i * f[x][i]) % mo;
	}
	for(int i = 1; i <= sz[x] + 1; i ++) {
		tmp[i] = (tmp[i - 1] + tmp[i]) % mo;
		f[x][i] = tmp[i];
	}
}

int main() {
	freopen("random.in", "r", stdin);
	// freopen("random.out", "w", stdout);
	cin >> n >> k;
	C[0][0] = 1;
	for(int i = 1; i <= n; i++) {
		C[i][0] = 1;
		for(int j = 1; j <= i; j++) 
			C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mo;
	}
	for(int i = 1; i < n; i++) {
		int u, v; scanf("%d %d", &u, &v);
		link(u, v), link(v, u);
	}
	dfs(1, 0);
	ll ans = 0;
	for(int i = 1; i <= n + 1; i++) {
		ans = (ans + f[1][i]) % mo;
	}
	cout << (ans + mo) % mo << endl;
}