欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

牛客练习赛71 E 神奇的迷宫

程序员文章站 2022-06-07 13:13:32
...

牛客练习赛71 E 神奇的迷宫
点分治统计路径 NTT优化卷积
设ans[i]代表长度为i的路径的概率,A[i]代表我们已经统计过的路径中,到当前根节点的路径长度i,B[i]代表我己经要统计的子树中,到根节点路径长度为i的概率是多少,这样的话如果要统计当前长度为L的路径的概率我们即可
ans[i] = ∑(0<= k <= i)A[i-k] * B[k],观察A和B的下标和为K,也就是说每一个ans[i]即为一个卷积的形式,然后就可以用NTT来优化这个卷积,点分治去统计长度为L的所有点对。
套的别人的板子,顺手搞一套NTT模板。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int MAXN = 5e5 + 7;
const int MOD = 998244353;
const int inf = 0x3f3f3f3f;

//ntt
const int G = 3;//原根
int len,r[MAXN];
ll x[MAXN],y[MAXN],w[MAXN];
inline ll q_pow(ll x,ll y){
    ll res = 1;
    while(y){
        if(y&1) res = res * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}

void NTT(ll *a,ll f){
	for(int i = 0;i < len;i ++){
		if(i < r[i]) swap(a[i],a[r[i]]);
	}
	w[0] = 1;
	for(int i = 2;i <= len;i *= 2){
		ll wn;
		if(f == 1) wn = q_pow(G,(ll)(MOD-1)/i);
		else wn = q_pow(G,(ll)(MOD-1)-(MOD-1)/i);
		for(int j = i/2;j >= 0;j -= 2) w[j] = w[j/2];
		for(int j = 1;j < i/2;j += 2) w[j] = (w[j-1]*wn)%MOD;
		for(int j = 0;j < len;j += i){
			for(int k = 0;k < i/2;k ++){
				ll u = a[j+k],v = (a[j+k+i/2] * w[k]) % MOD;
				a[j+k] = (u + v) % MOD;
				a[j+k+i/2] = (u - v + MOD) % MOD;
			}
		}
	}
	if(f == -1){
		ll inv = q_pow(len,MOD-2);
		for(int i = 0;i < len;i ++) a[i] = a[i] * inv % MOD;
	}
}

void MUL(ll *a,ll *b,ll *c,ll n,ll m){
	len = 1;
	while(len <= (n + m)) len *= 2;
	int k = trunc(log(len + 0.5) / log(2));
	for(int i = 0;i < len;i ++){
		r[i] = (r[i>>1]>>1) | ((i&1) << (k-1));
	}
	for(int i = 0;i < len;i ++){
		if(i < n) x[i] = a[i];else x[i] = 0;
		if(i < m) y[i] = b[i];else y[i] = 0;
	}
	NTT(x,1);
	NTT(y,1);
	for(int i = 0;i < len;i ++) c[i] = x[i] * y[i] % MOD;
	NTT(c,-1);
}

int n,SIZE,maxx,root;
int siz[MAXN],maxson[MAXN],dep[MAXN],maxdep[MAXN],vis[MAXN];
ll a[MAXN],cost[MAXN],A[MAXN],B[MAXN],C[MAXN],ans[MAXN];
vector<int>g[MAXN];

void get_root(int u,int fa){
    siz[u] = 1,maxson[u] = 0;
    int tot = g[u].size();
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(v == fa || vis[v]) continue;
        get_root(v,u);
        siz[u] += siz[v];
        maxson[u] = max(maxson[u],siz[v]);
    }
    maxson[u] = max(maxson[u],SIZE-siz[u]);
    if(maxx > maxson[u]) root = u,maxx = maxson[u];//找出最大子树最小的节点 作为根节点
}

void get_dep(int u,int fa){//获取以当前点为根节点的最大深度 NTT 要用长度信息
    siz[u] = 1;
    maxdep[u] = dep[u];
    int tot = g[u].size();
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(v == fa || vis[v]) continue;
        dep[v] = dep[u] + 1;
        get_dep(v,u);
        siz[u] += siz[v];
        maxdep[u] = max(maxdep[u],maxdep[v]);
    }
}

void get_B(int u,int fa){
    B[dep[u]] = (B[dep[u]] + a[u]) % MOD;//B[i] 记录到达深度为i的节点的概率
    int tot = g[u].size();
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(v == fa || vis[v]) continue;
        get_B(v,u);
    }
}

void divide(int u){
    vis[u] = 1;
    int L = 0;//A数组的长度
    //maxson[u] = 0;
    int tot = g[u].size();
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(vis[v]) continue;
        dep[v] = 1;//当前子节点的深度为1
        get_dep(v,u);
    }
    A[0] = a[u];
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(vis[v]) continue;
        get_B(v,u);
        MUL(A,B,C,L+1,maxdep[v]+1);//这个慢板的长度要加一
        int Len = L + maxdep[v] + 2;
        for(int i = 1;i <= Len;i ++) ans[i] = (ans[i] + C[i]) % MOD;
        for(int i = 0;i <= maxdep[v];i ++) A[i] = (A[i] + B[i]) % MOD,B[i] = 0;
        for(int i = 0;i <= len;i ++) C[i] = 0;
        L = max(L,maxdep[v]);//跟新一下A数组的长度
    }
    for(int i = 0;i <= L;i ++) A[i] = 0;//准备换根 所以A数组也需要相应的清空
    for(int i = 0;i < tot;i ++){
        int v = g[u][i];
        if(vis[v]) continue;
        root = 0;
        maxx = inf;
        SIZE = siz[v];
        get_root(v,u);
        divide(root);
    }
}

int main()
{
    scanf("%d",&n);
    ll sum = 0;
    ll res = 0;
    for(int i = 1;i <= n;i ++) scanf("%lld",&a[i]),sum += a[i];
    sum = q_pow(sum,MOD-2);//逆元
    for(int i = 1;i <= n;i ++){
        a[i] = a[i] * sum % MOD;//每个点的概率
        res = (res + a[i] * a[i] % MOD) % MOD;//先把一个点的概率算上
    }
    for(int i = 0;i < n;i ++) scanf("%lld",&cost[i]);
    res = res * cost[0] % MOD;//落在同一个点上的期望
    int a,b;
    for(int i = 1;i < n;i ++){
        scanf("%d%d",&a,&b);
        g[a].push_back(b);
        g[b].push_back(a);
    }
    SIZE = n;
    maxx = inf;
    //maxson[0] = 0;
    get_root(1,0);
    divide(root);
    for(int i = 1;i < n;i ++){
        res = (res + 2ll * ans[i] * cost[i] % MOD) % MOD;//对于不在同一点上的a和b 先选a再选b 和先选b再选a 要计算两次的答案
    }
    printf("%lld\n",res);
    return 0;
}
相关标签: 牛客练习赛