欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tokitsukaze and Colorful Tree【树状数组+离线+dfs】

程序员文章站 2022-03-21 17:32:07
...

题目链接 HDU-6793


Tokitsukaze and Colorful Tree【树状数组+离线+dfs】

题意:有N个点的树,每个点有颜色和权值,现在有两种操作,要求的是树上的同种颜色的非祖先与子孙节点的两点的异或和。

  1. 更改某个点权值为v
  2. 将某个点的颜色更改为c

  于是我们可以这样考虑,现在将所有的颜色离线下来,每次我们先对一种颜色求贡献,因为有N个点,每个点的颜色都是固定的在几个中的一个的,并且操作数是只有Q次,所以这样的总的点复杂度在Tokitsukaze and Colorful Tree【树状数组+离线+dfs】级别。

  将操作分成增加一个节点、删除一个节点、以及修改一个节点的操作,于是,我们可以对这棵树结构固定的树来进行操作了,因为树结构固定,所以我们可以使用dfs序来进行简单的处理,如果使用树链剖分这里会被卡TLE,出题人故意卡了树链剖分。以及一些暴力的解决办法。

  所以,我们将操作离线之后,我们直接维护一棵类似于虚树一样的树就可以了,分别求出加入点的增加贡献,到对应的操作;删除点对应的减少操作,也是对应的操作;以及修改的操作。

  然后就是码码码,就是了。

  小心auto搭配vector的pair类型,会RE的。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const int maxN = 1e5 + 7;
int N, Q, col[maxN], val[maxN];
namespace Graph
{
    int head[maxN], cnt;
    struct Eddge
    {
        int nex, to;
        Eddge(int a=-1, int b=0):nex(a), to(b) {}
    } edge[maxN << 1];
    inline void addEddge(int u, int v)
    {
        edge[cnt] = Eddge(head[u], v);
        head[u] = cnt++;
    }
    inline void _add(int u, int v) { addEddge(u, v); addEddge(v, u); }
    inline void init()
    {
        cnt = 0;
        for(int i=1; i<=N; i++) head[i] = -1;
    }
};
using namespace Graph;
int dfn[maxN], tot, rid[maxN], end_tim[maxN];
void dfs(int u, int fa)
{
    dfn[u] = ++tot; rid[tot] = u;
    for(int i=head[u], v; ~i; i=edge[i].nex)
    {
        v = edge[i].to;
        if(v == fa) continue;
        dfs(v, u);
    }
    end_tim[u] = tot;
}
struct Question
{
    int add_del_set, qith, u, old_val, nex_val;   //add or delt or set, ith optionertion, point num
    Question(int a=0, int b=0, int c=0, int d=0, int f=0):add_del_set(a), qith(b), u(c), old_val(d), nex_val(f) {}
    friend bool operator < (Question e1, Question e2) { return e1.qith == e2.qith ? dfn[e1.u] < dfn[e2.u] : e1.qith < e2.qith; }
};
ll ans[maxN];
vector<Question> each_col[maxN];    //each color point
namespace BIT
{
    struct Save
    {
        int t[maxN];
        vector<pair<int, int>> Stap;
        void update(int x, int v)
        {
            Stap.push_back(make_pair(x, v));
            while(x <= N)
            {
                t[x] += v;
                x += lowbit(x);
            }
        }
        int query(int x)
        {
            int sum = 0;
            while(x)
            {
                sum += t[x];
                x -= lowbit(x);
            }
            return sum;
        }
        void Bit_clear()
        {
            int len = (int)Stap.size();
            for(int i=0; i<len; i++) update(Stap[i].first, -Stap[i].second);
            Stap.clear();
        }
    } deep_value[21], siz_value[21], siz, deep;
};
using namespace BIT;
int all_point, all_dig[21];
void add_Point(int col_id, int qid, int u, int val)
{
    for(int i=0, id; i<20; i++)
    {
        id = (val >> i) & 1;
        int sum = siz.query(end_tim[u]) - siz.query(dfn[u] - 1);
        sum += deep.query(dfn[u]);
        sum = all_point - sum;
        ll tmp = deep_value[i].query(dfn[u]) + siz_value[i].query(end_tim[u]) - siz_value[i].query(dfn[u] - 1);
        tmp = all_dig[i] - tmp;
        if(id) ans[qid] += (1LL << i) * (sum - tmp);
        else ans[qid] += (1LL << i) * tmp;
        if(id)
        {
            deep_value[i].update(dfn[u] + 1, 1);
            deep_value[i].update(end_tim[u] + 1, -1);
            siz_value[i].update(dfn[u], 1);
            all_dig[i]++;
        }
    }
    all_point++;
    deep.update(dfn[u] + 1, 1);
    deep.update(end_tim[u] + 1, -1);
    siz.update(dfn[u], 1);
}
void del_Point(int col_id, int qid, int u, int val)
{
    all_point--;
    deep.update(dfn[u] + 1, -1);
    deep.update(end_tim[u] + 1, 1);
    siz.update(dfn[u], -1);
    for(int i=0, id; i<20; i++)
    {
        id = (val >> i) & 1;
        int sum = siz.query(end_tim[u]) - siz.query(dfn[u] - 1);
        sum += deep.query(dfn[u]);
        sum = all_point - sum;
        ll tmp = deep_value[i].query(dfn[u]) + siz_value[i].query(end_tim[u]) - siz_value[i].query(dfn[u] - 1);
        tmp = all_dig[i] - tmp;
        if(id) ans[qid] -= (1LL << i) * (sum - tmp);
        else ans[qid] -= (1LL << i) * tmp;
        if(id)
        {
            deep_value[i].update(dfn[u] + 1, -1);
            deep_value[i].update(end_tim[u] + 1, 1);
            siz_value[i].update(dfn[u], -1);
            all_dig[i]--;
        }
    }
}
void solve_its_col(int col_id)
{
    all_point = 0; memset(all_dig, 0, sizeof(all_dig));
    vector<Question> it = each_col[col_id];
    int len = (int)it.size();
    if(len <= 1) return;
    sort(it.begin(), it.end());
    for(int i=0, op, u; i<len; i++)
    {
        op = it[i].add_del_set;
        u = it[i].u;
        switch (op)
        {
            case 0:
            {
                add_Point(col_id, it[i].qith, u, it[i].nex_val);
                break;
            }
            case 1:
            {
                del_Point(col_id, it[i].qith, u, it[i].nex_val);
                break;
            }
            default:
            {
                del_Point(col_id, it[i].qith, u, it[i].old_val);
                add_Point(col_id, it[i].qith, u, it[i].nex_val);
                break;
            }
        }
    }
    for(int i=0; i<20; i++) { deep_value[i].Bit_clear(); siz_value[i].Bit_clear(); }
    siz.Bit_clear();
    deep.Bit_clear();
}
int main()
{
    int T; scanf("%d", &T);
    while(T--)
    {
        scanf("%d", &N);
        init();
        for(int i=1; i<=N; i++) scanf("%d", &col[i]);
        for(int i=1; i<=N; i++) scanf("%d", &val[i]);
        for(int i=1, u, v; i<N; i++)
        {
            scanf("%d%d", &u, &v);
            _add(u, v);
        }
        tot = 0;
        dfs(1, 0);
        for(int i=1; i<=N; i++)
        {
            each_col[col[i]].push_back(Question(0, 0, i, val[i], val[i]));
        }
        scanf("%d", &Q);
        for(int i=1, op, x, y; i<=Q; i++)
        {
            scanf("%d%d%d", &op, &x, &y);
            switch (op)
            {
                case 1:
                {
                    each_col[col[x]].push_back(Question(2, i, x, val[x], y));
                    val[x] = y;
                    break;
                }
                default:
                {
                    each_col[col[x]].push_back(Question(1, i, x, val[x], val[x]));
                    col[x] = y;
                    each_col[y].push_back(Question(0, i, x, val[x], val[x]));
                    break;
                }
            }
        }
        for(int i=0; i<=Q; i++) ans[i] = 0;
        for(int col_id=1; col_id<=N; col_id++)
        {
            solve_its_col(col_id);
        }
        for(int i=0; i<=Q; i++)
        {
            if(i) ans[i] += ans[i - 1];
            printf("%lld\n", ans[i]);
        }
        for(int i=1; i<=N; i++) each_col[i].clear();
    }
    return 0;
}