jzoj6407 【NOIP2019模拟11.05】小 D 与随机 (容斥计数)
程序员文章站
2024-03-17 16:52:58
...
题意
一棵树,现在随机分配一个排列,设一种方案中,到根没有比他大的点(好点)的个数是,那么这种方案有的贡献。求贡献和。
思路
题解看不懂- 将k-1,就可以转化成“至少”的计数问题。
- 问题现在变成了:分配排列,方案权值和。
- 这个问题可以dp解决。考虑到当前点的取值范围依赖于子树内最小好点,设表示i的子树内,最小的好点的是第j小的所有方案权值和。
- 转移方法是枚举两颗子树合并后,最小好点的位置,然后再讨论当前点是不是好点,与之前点的相对大小。
- 观察发现可以前缀和优化,注意循环顺序与范围,不要退化了。
#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5050, mo = 998244353;
ll n, k;
int final[N], nex[N * 2], to[N * 2], tot;
ll C[N][N];
ll f[N][N];
int sz[N];
void link(int x, int y) {
to[++tot] = y, nex[tot] = final[x], final[x] = tot;
}
ll sum[N];
void merge(ll *tmp, int x, int y) {
memset(sum, 0, sizeof sum);
for(int v = sz[y] + 1; v; v--) {
sum[v] = (sum[v + 1] + f[y][v]) % mo;
}
for(int u = 1; u <= sz[x]; u++) {
for(int e = u; e <= u + sz[y]; e++) {
tmp[e] = (tmp[e] + f[x][u] * sum[e - u + 1] % mo
* C[e - 1][u - 1] % mo
* C[sz[x] + sz[y] - e][sz[x] - u]) % mo;
}
}
}
ll tmp[N];
void dfs(int x, int from) {
int has = 0;
for(int i = final[x]; i; i = nex[i]) {
int y = to[i]; if (y != from) {
dfs(y, x);
if (has == 0) {
memcpy(f[x], f[y], sizeof f[y]);
} else {
memset(tmp, 0, (5 + sz[x] + sz[y]) * 8);
merge(tmp, x, y);
merge(tmp, y, x);
tmp[sz[x] + sz[y] + 1] = f[x][sz[x] + 1] * f[y][sz[y] + 1] % mo * C[sz[x] + sz[y]][sz[x]] % mo;
memcpy(f[x], tmp, (5 + sz[x] + sz[y]) * 8);
}
sz[x] += sz[y];
has = 1;
}
}
sz[x] ++;
if (!has) {
f[x][1] = k - 1;
f[x][2] = 1;
return;
}
memset(tmp, 0, sizeof tmp);
for(int i = 1; i <= sz[x]; i++) {
tmp[1] = (tmp[1] + f[x][i] * (k - 1)) % mo;
tmp[i + 1] = (tmp[i + 1] - f[x][i] * (k - 1)) % mo;
tmp[i + 1] = (tmp[i + 1] + i * f[x][i]) % mo;//x不是好点
tmp[i + 2] = (tmp[i + 2] - i * f[x][i]) % mo;
}
for(int i = 1; i <= sz[x] + 1; i ++) {
tmp[i] = (tmp[i - 1] + tmp[i]) % mo;
f[x][i] = tmp[i];
}
}
int main() {
freopen("random.in", "r", stdin);
// freopen("random.out", "w", stdout);
cin >> n >> k;
C[0][0] = 1;
for(int i = 1; i <= n; i++) {
C[i][0] = 1;
for(int j = 1; j <= i; j++)
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mo;
}
for(int i = 1; i < n; i++) {
int u, v; scanf("%d %d", &u, &v);
link(u, v), link(v, u);
}
dfs(1, 0);
ll ans = 0;
for(int i = 1; i <= n + 1; i++) {
ans = (ans + f[1][i]) % mo;
}
cout << (ans + mo) % mo << endl;
}