欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

2018.9.2zr模拟T3(倍增,计数)

程序员文章站 2023-12-25 12:54:09
...

描述
给出一棵 N 个节点的树,树上的每个节点都有一个权值 ai。
有 Q 次询问,每次在树上选中两个点 u,v,考虑所有在简单路径 u,vu,v 上(包括 u,v)的点构成的集合 S。

wSaw or dis(u,w)

其中 dist(u,w) 为简单路径 u,w 上的边数,or是按位或。

输入格式
第一行两个整数 N,Q。
接下来一行 N 个整数,第 i 个为 ai。
接下来的 N−1 行,每行两个整数 u,v。表示 u,v 之间有一条边。
接下来的 Q 行,每行两个整数 u,v。表示一组询问。

输出格式
对于每个询问,输出一行一个整数表示答案。

样例一
input
5 2
4 3 2 5 3
1 2
1 3
3 4
3 5
2 5
3 4
output
13
7
限制与约定
每个测试点 10 分,共 10 个测试点:
2018.9.2zr模拟T3(倍增,计数)
对于所有的数据,有:1≤N,Q,0≤ai<323232323
时间限制:3s
空间限制:256MB


比较du的计数题
我们考虑ai<2的情况,意味着只有距离为奇数且ai==1才会有重复
倍增统计即可。
于是我们可以类比出一种没有特殊限制的做法
对于ai的每一位维护一个倍增统计答案
复杂度O(nlogn logai)
考虑优化。
我们大可不必对于每一位维护
我们只需要对于一个点,维护他的点权和距离都为1的位即可

#include<bits/stdc++.h>
using namespace std;
#define rep(i,j,k) for(int i = j;i <= k;++i)
#define repp(i,j,k) for(int i = j;i >= k;--i)
#define rept(i,x) for(int i = linkk[x];i;i = e[i].n)
#define P pair<int,int>
#define Pil pair<int,ll>
#define Pli pair<ll,int>
#define Pll pair<ll,ll>
#define pb push_back 
#define pc putchar
#define mp make_pair
#define file(k) memset(k,0,sizeof(k))
#define ll long long
namespace fastIO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    bool IOerror = 0;
    inline char nc(){ 
        static char buf[BUF_SIZE],*p1 = buf+BUF_SIZE, *pend = buf+BUF_SIZE;
        if(p1 == pend){
            p1 = buf; pend = buf+fread(buf, 1, BUF_SIZE, stdin);
            if(pend == p1){ IOerror = 1; return -1;}
        }
        return *p1++;
    }
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
    inline int read(){
        bool sign = 0; char ch = nc();int x = 0;
        for(; blank(ch); ch = nc());
        if(IOerror)return 0;
        if(ch == '-') sign = 1, ch = nc();
        for(; ch >= '0' && ch <= '9'; ch = nc()) x = x*10+ch-'0';
        if(sign) x = -x;
        return x;
    }
    #undef OUT_SIZE
    #undef BUF_SIZE
};
using namespace fastIO;
int n , q;
int cnt[300100][21];
int fa[300100][21] , a[300100] , dep[300100];
ll dpu[300100][21] , dpd[300100][21];
ll ans;
int son[300100],top[300100],sz[300100];//树剖
vector<int>G[300100];
void dfs(int x,int f)
{
    fa[x][0] = f;dep[x] = dep[f]+1;
    rep(i,0,20)
        cnt[x][i] = cnt[f][i] + (!(a[x] & (1<<i)));
    int Max = 0;
    rep(i,1,G[x].size())
        if(G[x][i-1] != f)
        {
            dfs(G[x][i-1],x);
            sz[x] += sz[G[x][i-1]];
            if(sz[G[x][i-1]] > Max) son[x] = G[x][i-1],Max = sz[G[x][i-1]];
        }
    sz[x]++;
}
void dfs2(int x)
{
    top[x] = x == son[fa[x][0]]?top[fa[x][0]]:x;
    rep(i,1,G[x].size())
        if(G[x][i-1] != fa[x][0])
            dfs2(G[x][i-1]);
}
int Lca(int x,int y)
{
    for(;top[x] != top[y];)
        dep[top[x]] > dep[top[y]] ? x = fa[top[x]][0] : y = fa[top[y]][0];
    return dep[x] < dep[y] ? x : y;
}
void init()
{
    n = read();q = read();
    rep(i,1,n) a[i] = read();
    rep(i,1,n-1)
    {
        int x = read() , y = read();
        G[x].pb(y);G[y].pb(x);
    }
    dfs(1,0);
    dfs2(1);
}
ll query1(int x,int y)
{
    int len = dep[x] - dep[y];
    ll sum = 0;
    repp(i,20,0)
        if(len & (1<<i))
        {
            sum += dpu[x][i];
            x = fa[x][i];
            sum += 1ll*(1<<i)*(cnt[x][i] - cnt[y][i]);
        }
    return sum;
}
ll query2(int x,int len)
{
    ll sum = 0;
    int now = x;

    rep(i,0,20)
        if(len & (1<<i))
        {
            sum += dpd[now][i];
            sum += 1ll*(1<<i)*(cnt[x][i]-cnt[now][i]);
            now = fa[now][i];
        }
    return sum;
}
int main()
{
    init();
    rep(i,1,n) dpu[i][0] = dpd[i][0] = a[i];
    rep(j,1,20)
        rep(i,1,n)
        {
            fa[i][j] = fa[fa[i][j-1]][j-1];
            int x = fa[i][j-1] , y = fa[i][j];
            dpu[i][j] = dpu[i][j-1] + dpu[x][j-1] + 1ll*(1<<(j-1))*(cnt[x][j-1]-cnt[y][j-1]);
            dpd[i][j] = dpd[i][j-1] + dpd[x][j-1] + 1ll*(1<<(j-1))*(cnt[i][j-1]-cnt[x][j-1]);
        }
    rep(i,1,q)
    {
        int x = read() , y = read();
        int lca = Lca(x,y) , len = dep[x] + dep[y] - 2*dep[lca];
        ans = query1(x,fa[lca][0]) + query2(y,len+1) - query2(lca,len-dep[y]+dep[lca]+1);
        printf("%lld\n",ans);
    }
    return 0;
}

上一篇:

下一篇: