欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

cf1040E. Network Safety(并查集)

程序员文章站 2022-09-14 17:58:38
题意 题目链接 一张图,n个点,m条边,每个点有个权值x,x<=1e18。如果一条边的两个端点不一样,那么这条边是安全的,开始时所有边都是安全的。 现在有一个病毒y,病毒可以入侵任意的点,入侵一个点后的权值为(y^x)。 (S,y)表示病毒的权值为y,它只入侵了点集S中的点,整张图的边都是安全的。求 ......

题意

题目链接

一张图,n个点,m条边,每个点有个权值x,x<=1e18。如果一条边的两个端点不一样,那么这条边是安全的,开始时所有边都是安全的。

现在有一个病毒y,病毒可以入侵任意的点,入侵一个点后的权值为(y^x)。

(s,y)表示病毒的权值为y,它只入侵了点集s中的点,整张图的边都是安全的。求出所有的(s,y)。

sol

不算是特别难想的div2压轴题。。

题目中保证了任意两个节点互不相同,因此想让图不安全,对于$(a, x)$一定是存在 a内一点 ^ x = 与该点相邻点的权值

显然,对于每一条边我们可以求出对应的x。

刚开始我傻乎乎的以为任意两个节点之间对应的值都是不同的,但是这样肯定是错的

比如 101 ^ 1 = 100, 1011 ^ 1 = 1010

但就算有相同的也没关系。

考虑$x$的贡献,如果存在两个节点,假设其权值分别为a,b,满足a ^ x = b

那么该节点要么同时不出现,要么同时出现,也就相当于把这两个点看成了一个点

然后我在这里又掉了一次坑,题目中只是说了“刚开始是安全的”,我天真的以为是互不相同。。

这样的话,就有可能存在好多个点看成一个点的情况,直接用并查集维护即可

最终的答案 = $\sum_{k} 2^siz(k) + (2^k - cnt) * (2^n)$

cnt表示互不相同的x的数量

$siz(k)$表示把$x ^ k = y$的点全都缩起来后的总点数

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<map>
#include<vector>
#define pair pair<int, int>
#define mp(x, y) make_pair(x, y)
#define fi first
#define se second
#define ll long long 
using namespace std;
const int maxn = 1e6 + 10, mod = 1e9 + 7;
inline ll read() {
    char c = getchar(); ll x = 0, f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
int n, m, fa[maxn];
ll po[maxn], c[maxn], k;
map<ll, vector<pair> >mp;
int find(int x) {
    return fa[x] == x ? fa[x] : fa[x] = find(fa[x]);
}
main() {
    n = read(); m = read(); k = read();
    po[0] = 1;
    for(int i = 1; i <= n; i++) c[i] = read();
    for(int i = 1; i <= max((ll)n, k); i++) po[i] = (1ll * 2 * po[i - 1]) % mod;
    for(int i = 1; i <= m; i++) {
        int x = read(), y = read();
        mp[c[x] ^ c[y]].push_back(mp(x, y));
    }
    ll ans = 0;
    map<ll, vector<pair> >::iterator it;
    for(int i = 1; i <= n; i++) fa[i] = i;
    for(it = mp.begin(); it != mp.end(); it++) {    
        ll val = it -> first;
        vector<pair> now = it -> second;
        vector<int> res;
        for(int i = 0; i < now.size(); i++) {
            int fx = find(now[i].fi), fy = find(now[i].se);
            res.push_back(fx); res.push_back(fy);
            if(fx == fy) continue;
            fa[fx] = fy; 
        }
        for(int i = 0; i < res.size(); i++) fa[res[i]] = res[i];
        (ans += po[tot]) %= mod;
    }
    printf("%i64d", (ans + 1ll * (po[k] - mp.size()) * po[n] % mod) % mod);
    return 0;
}