牛客练习赛71 E 神奇的迷宫
程序员文章站
2022-06-07 13:13:32
...
点分治统计路径 NTT优化卷积
设ans[i]代表长度为i的路径的概率,A[i]代表我们已经统计过的路径中,到当前根节点的路径长度i,B[i]代表我己经要统计的子树中,到根节点路径长度为i的概率是多少,这样的话如果要统计当前长度为L的路径的概率我们即可
ans[i] = ∑(0<= k <= i)A[i-k] * B[k],观察A和B的下标和为K,也就是说每一个ans[i]即为一个卷积的形式,然后就可以用NTT来优化这个卷积,点分治去统计长度为L的所有点对。
套的别人的板子,顺手搞一套NTT模板。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 5e5 + 7;
const int MOD = 998244353;
const int inf = 0x3f3f3f3f;
//ntt
const int G = 3;//原根
int len,r[MAXN];
ll x[MAXN],y[MAXN],w[MAXN];
inline ll q_pow(ll x,ll y){
ll res = 1;
while(y){
if(y&1) res = res * x % MOD;
x = x * x % MOD;
y >>= 1;
}
return res;
}
void NTT(ll *a,ll f){
for(int i = 0;i < len;i ++){
if(i < r[i]) swap(a[i],a[r[i]]);
}
w[0] = 1;
for(int i = 2;i <= len;i *= 2){
ll wn;
if(f == 1) wn = q_pow(G,(ll)(MOD-1)/i);
else wn = q_pow(G,(ll)(MOD-1)-(MOD-1)/i);
for(int j = i/2;j >= 0;j -= 2) w[j] = w[j/2];
for(int j = 1;j < i/2;j += 2) w[j] = (w[j-1]*wn)%MOD;
for(int j = 0;j < len;j += i){
for(int k = 0;k < i/2;k ++){
ll u = a[j+k],v = (a[j+k+i/2] * w[k]) % MOD;
a[j+k] = (u + v) % MOD;
a[j+k+i/2] = (u - v + MOD) % MOD;
}
}
}
if(f == -1){
ll inv = q_pow(len,MOD-2);
for(int i = 0;i < len;i ++) a[i] = a[i] * inv % MOD;
}
}
void MUL(ll *a,ll *b,ll *c,ll n,ll m){
len = 1;
while(len <= (n + m)) len *= 2;
int k = trunc(log(len + 0.5) / log(2));
for(int i = 0;i < len;i ++){
r[i] = (r[i>>1]>>1) | ((i&1) << (k-1));
}
for(int i = 0;i < len;i ++){
if(i < n) x[i] = a[i];else x[i] = 0;
if(i < m) y[i] = b[i];else y[i] = 0;
}
NTT(x,1);
NTT(y,1);
for(int i = 0;i < len;i ++) c[i] = x[i] * y[i] % MOD;
NTT(c,-1);
}
int n,SIZE,maxx,root;
int siz[MAXN],maxson[MAXN],dep[MAXN],maxdep[MAXN],vis[MAXN];
ll a[MAXN],cost[MAXN],A[MAXN],B[MAXN],C[MAXN],ans[MAXN];
vector<int>g[MAXN];
void get_root(int u,int fa){
siz[u] = 1,maxson[u] = 0;
int tot = g[u].size();
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(v == fa || vis[v]) continue;
get_root(v,u);
siz[u] += siz[v];
maxson[u] = max(maxson[u],siz[v]);
}
maxson[u] = max(maxson[u],SIZE-siz[u]);
if(maxx > maxson[u]) root = u,maxx = maxson[u];//找出最大子树最小的节点 作为根节点
}
void get_dep(int u,int fa){//获取以当前点为根节点的最大深度 NTT 要用长度信息
siz[u] = 1;
maxdep[u] = dep[u];
int tot = g[u].size();
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(v == fa || vis[v]) continue;
dep[v] = dep[u] + 1;
get_dep(v,u);
siz[u] += siz[v];
maxdep[u] = max(maxdep[u],maxdep[v]);
}
}
void get_B(int u,int fa){
B[dep[u]] = (B[dep[u]] + a[u]) % MOD;//B[i] 记录到达深度为i的节点的概率
int tot = g[u].size();
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(v == fa || vis[v]) continue;
get_B(v,u);
}
}
void divide(int u){
vis[u] = 1;
int L = 0;//A数组的长度
//maxson[u] = 0;
int tot = g[u].size();
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(vis[v]) continue;
dep[v] = 1;//当前子节点的深度为1
get_dep(v,u);
}
A[0] = a[u];
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(vis[v]) continue;
get_B(v,u);
MUL(A,B,C,L+1,maxdep[v]+1);//这个慢板的长度要加一
int Len = L + maxdep[v] + 2;
for(int i = 1;i <= Len;i ++) ans[i] = (ans[i] + C[i]) % MOD;
for(int i = 0;i <= maxdep[v];i ++) A[i] = (A[i] + B[i]) % MOD,B[i] = 0;
for(int i = 0;i <= len;i ++) C[i] = 0;
L = max(L,maxdep[v]);//跟新一下A数组的长度
}
for(int i = 0;i <= L;i ++) A[i] = 0;//准备换根 所以A数组也需要相应的清空
for(int i = 0;i < tot;i ++){
int v = g[u][i];
if(vis[v]) continue;
root = 0;
maxx = inf;
SIZE = siz[v];
get_root(v,u);
divide(root);
}
}
int main()
{
scanf("%d",&n);
ll sum = 0;
ll res = 0;
for(int i = 1;i <= n;i ++) scanf("%lld",&a[i]),sum += a[i];
sum = q_pow(sum,MOD-2);//逆元
for(int i = 1;i <= n;i ++){
a[i] = a[i] * sum % MOD;//每个点的概率
res = (res + a[i] * a[i] % MOD) % MOD;//先把一个点的概率算上
}
for(int i = 0;i < n;i ++) scanf("%lld",&cost[i]);
res = res * cost[0] % MOD;//落在同一个点上的期望
int a,b;
for(int i = 1;i < n;i ++){
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
SIZE = n;
maxx = inf;
//maxson[0] = 0;
get_root(1,0);
divide(root);
for(int i = 1;i < n;i ++){
res = (res + 2ll * ans[i] * cost[i] % MOD) % MOD;//对于不在同一点上的a和b 先选a再选b 和先选b再选a 要计算两次的答案
}
printf("%lld\n",res);
return 0;
}
推荐阅读