欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

[树套树][学习笔记]

程序员文章站 2022-03-25 19:12:18
思想 树套树像他的名字一样,就是一棵树套另一棵树。用一棵外层树来维护一些区间之类的东西。然后外层树的每个节点都是一棵内层树。就这样 一道模板题 bzoj3196 思路 ......

思想

树套树像他的名字一样,就是一棵树套另一棵树。用一棵外层树来维护一些区间之类的东西。然后外层树的每个节点都是一棵内层树。就这样

一道模板题

bzoj3196

思路

这是一道线段树套平衡树的模板题。外层用一棵线段树来维护区间操作。然后线段树的每个节点都是一棵平衡树

操作1:查询从l到r中比k小的数的个数,然后+1输出即可

操作2:二分一下答案,找排名小于等于k的最大值就行了

操作3:将原来的值先删去,然后加入新的值.

操作4:查询每个子区间中的前驱,然后最大的那个就是当前区间中的前驱

操作5:与操作4类似,查询每个子区间中的后继,然后最小的那个就是当前区间中的后继。

ps:在进行操作3的时候不要忘记将原来数组中的值也进行更改,不然以后再删除的时候会出错。在这个地方调了2h 2333

代码

/*
* @author: wxyww
* @date:   2018-12-11 08:29:48
* @last modified time: 2018-12-11 10:44:01
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<bitset>
using namespace std;
typedef long long ll;
#define ls tr[cur].ch[0]
#define rs tr[cur].ch[1]
const int n = 100000 + 100,inf = 2147483647;
ll read() {
    ll x=0,f=1;char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
namespace treap {
    struct node {
        int val,siz,ch[2],id,cnt;
    }tr[n * 20];
    void up(int cur) {
        tr[cur].siz = tr[ls].siz + tr[rs].siz + tr[cur].cnt;
    }
    int tot = 0;
    void rotate(int &cur,int f) {
        int son = tr[cur].ch[f];
        tr[cur].ch[f] = tr[son].ch[f ^ 1];
        tr[son].ch[f ^ 1] = cur;
        up(cur);
        cur = son;
        up(cur);
    }
    void insert(int &cur,int val) {
        if(!cur) {
            cur = ++tot;
            tr[cur].val = val;
            tr[cur].siz = tr[cur].cnt = 1;
            tr[cur].id = rand();
            return;
        }
        tr[cur].siz++;
        if(val == tr[cur].val) {tr[cur].cnt++;return;}
        int d = val > tr[cur].val;
        insert(tr[cur].ch[d],val);
        if(tr[tr[cur].ch[d]].id < tr[cur].id) rotate(cur,d);
    }
    void del(int &cur,int val) {
        if(!cur) return;
        if(tr[cur].val == val) {
            if(tr[cur].cnt > 1) {tr[cur].cnt--;tr[cur].siz--;return;}
            if(!ls || !rs) {cur = ls + rs;return;}
            rotate(cur,tr[rs].id < tr[ls].id);
            del(cur,val);
            return;
        }
        tr[cur].siz--;
        del(tr[cur].ch[val > tr[cur].val],val);
    }
    int rank(int cur,int val) {
        int ans = 0;
        while(cur) {
            if(val < tr[cur].val) cur = ls;
            else if(val == tr[cur].val) return ans + tr[ls].siz;
            else ans += tr[ls].siz + tr[cur].cnt,cur = rs;
        }
        return ans;
    }
    int pred(int cur,int val) {
        if(!cur) return -inf;
        if(val > tr[cur].val) return max(pred(rs,val),tr[cur].val);
        else return pred(ls,val);
    }
    int nex(int cur,int val) {
        if(!cur) return inf;
        if(val < tr[cur].val) return min(nex(ls,val),tr[cur].val);
        else return nex(rs,val);
    }
}
using namespace treap;
int tree[n << 2];
int a[n];
int n;
void build(int rt,int l,int r) {
    if(l == r) {
        insert(tree[rt],a[l]);
        return;
    }
    int mid = (l + r) >> 1;
    for(int i = l;i <= r;++i) insert(tree[rt],a[i]);
    build(rt << 1,l,mid);
    build(rt << 1 | 1,mid + 1,r);
}
void delet(int rt,int l,int r,int pos,int c) {
    if(l == r) {
        insert(tree[rt],c);
        del(tree[rt],a[pos]);
        return;
    }
    insert(tree[rt],c);
    del(tree[rt],a[pos]);
    int mid = (l + r) >> 1;
    if(pos <= mid) delet(rt << 1,l,mid,pos,c);
    else delet(rt << 1 | 1,mid + 1,r,pos,c);
}
int getrank(int rt,int l,int r,int l,int r,int val) {
    if(l <= l && r >= r) return rank(tree[rt],val);
    int mid = (l + r) >> 1;
    int ans = 0;
    if(l <= mid) ans += getrank(rt << 1,l,mid,l,r,val);
    if(r > mid) ans += getrank(rt << 1 | 1,mid + 1,r,l, r,val);
    return ans;
}
int getpred(int rt,int l,int r,int l,int r,int val) {
    if(l <= l && r >= r) return pred(tree[rt],val);
    int mid = (l + r) >> 1;
    int ans = -inf;
    if(l <= mid) ans = max(ans,getpred(rt << 1,l,mid,l,r,val));
    if(r > mid) ans = max(ans,getpred(rt << 1 | 1,mid + 1,r,l, r,val));
    return ans;
}
int getnex(int rt,int l,int r,int l,int r,int val) {
    if(l <= l && r >= r) return nex(tree[rt],val);
    int mid = (l + r) >> 1;
    int ans = inf;
    if(l <= mid) ans = min(ans,getnex(rt << 1,l,mid,l,r,val));
    if(r > mid) ans = min(ans,getnex(rt << 1 | 1,mid + 1,r,l,r,val));
    return ans;
}
int max = -inf;
int getkth(int l,int r,int x) {
    int l = 0,r = inf;
    int ans = 0;
    while(l <= r) {
        int mid = (l + r) >> 1;
        if(getrank(1,1,n,l,r,mid) + 1<= x) ans = mid,l = mid + 1;
        else r = mid - 1;
    }
    return ans;
}
int main() {

    n = read();
    int m = read();
    for(int i = 1;i <= n;++i) a[i] = read();
    build(1,1,n);
    while(m--) {
        int opt = read();
        if(opt == 1) {
            int l = read(),r = read(),k = read();
            printf("%d\n",getrank(1,1,n,l,r,k) + 1);
        }
        if(opt == 2) {
            int l = read(),r = read(),k = read();
            printf("%d\n",getkth(l,r,k));
        }
        if(opt == 3) {
            int pos = read(),k = read();
            delet(1,1,n,pos,k);
            a[pos] = k;//!!!
        }
        if(opt == 4) {
            int l = read(),r = read(),k = read();
            printf("%d\n",getpred(1,1,n,l,r,k));
        }
        if(opt == 5) {
            int l = read(),r = read(),k = read();
            printf("%d\n",getnex(1,1,n,l,r,k));
        }
    }
    return 0;
}