欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Colorful Tree

程序员文章站 2022-07-13 17:37:06
...

HDU - 6035
Colorful Tree

给定一棵树,每个点有一个颜色1~n,定义两点的路径是两点间不同的颜色数量,求n*(n-1)/2条路径的权值总和。
一开始想主席树,但是后来发现并不用那么麻烦。
假设考虑颜色1的贡献值,可以把所有颜色为1的点割掉,用n*(n-1)/2 减去剩下的所有子树的路径和,就是所有经过了颜色1的路径数,也就是路径1的贡献。那么怎么快速求每一颗子树的大小,考虑肯定是从下往上进行分割,我用fa[t]代表当前dfs到x的节点,往上走第一个颜色为t的节点是fa[t],然后我们可以把x节点压如到fa[t]节点所建立的vector中,当回溯到fa[t]时,减去所有vector中的子树大小,就是割剩的树的大小m,路径总数就是m*(m-1)/2,当fa[t]处理另一颗子树时vector要清空。注意就是当fa[t]是0,即是从上往下第一个颜色是t的节点时要特殊处理。

Colorful Tree

#include <bits/stdc++.h>
#include <vector>
using namespace std;

#define ll long long

const int maxn = 200010;
vector<int> G[maxn];
vector<int> s[2*maxn], c[maxn];
int a[maxn], n, x, y, r[maxn], fa[maxn];
ll f[maxn], all[maxn];
ll ans, t;
bool flag[maxn], p[maxn];

void dfs1(int x) {
    f[x] = 0;
    flag[x] = true;
    ll temp = 0, son = 0;
    if (fa[a[x]] != 0) s[fa[a[x]]].push_back(x);
    else c[a[x]].push_back(x);
    int k = fa[a[x]];
    fa[a[x]] = x;
    for (int i = 0; i < G[x].size(); i++) if (!flag[G[x][i]]) {
        s[x].clear();
        dfs1(G[x][i]);
        f[x] += f[G[x][i]];
        t = f[G[x][i]];
        for (int j = 0; j < s[x].size(); j++) t -= f[s[x][j]];
        all[a[x]] += 1LL*t*(t-1)/2;
    }
    f[x]++;
    fa[a[x]] = k;
}


int main() {
    //freopen("input.txt","r",stdin);
    int cases = 0;
    while (scanf("%d", &n) != EOF) {
        cases++;
        memset(p, 0, sizeof(p));
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            c[a[i]].clear();
            all[a[i]] = 0;
            p[a[i]] = true;
        }
        for (int i = 1; i <= n; i++) G[i].clear();
        for (int i = 1; i < n; i++) {
            scanf("%d %d", &x, &y);
            G[x].push_back(y);
            G[y].push_back(x);
        }
        memset(flag, 0, sizeof(flag));
        ans = 0;
        dfs1(1);
        for (int i = 1; i <= n; i++) if (c[i].size() != 0) {
            t = n;
            for (int j = 0; j < c[i].size(); j++) t -= f[c[i][j]];
            if (t > 0) all[i] += 1LL*t*(t-1)/2;
        }
        for (int i = 1; i <= n; i++) if (p[i]) {
            ans += 1LL*n*(n-1)/2-all[i];
        }
        printf("Case #%d: %lld\n", cases, ans);
    }
}