欢迎您访问程序员文章站，本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

SplayTree实现_简化版

程序员文章站 2022-03-23 13:53:26
...

(1)

template<class DataType>
class SplayTree{
// NOTE: MAXN is a preprocessor macro, so it stays defined for the rest of the
// translation unit after this point (it is not scoped to the class).
#define MAXN 1000010
private:
    // Array-backed node storage. Index 0 is the "null" node; real nodes use
    // indices 1..MAXN-1. For a node x: ch[x][0]/ch[x][1] are its left/right
    // children, pre[x] is its parent (0 for the root) and key[x] its value.
    // pool[0..top) holds indices of removed nodes waiting to be reused.
    int ch[MAXN][2],pre[MAXN],pool[MAXN];
    DataType key[MAXN];
    int top,root,tot;  // free-list size, current root (0 = empty), highest index handed out so far

    // Obtain an unlinked node holding dt, recycling a freed index when possible.
    // Capacity is not checked: at most MAXN-1 nodes may be live at once.
    int new_node(DataType dt){
        int x = top ? pool[--top] : ++tot;
        key[x] = dt;
        // A never-used index holds indeterminate links unless the object was
        // zero-initialised (static storage), so always clear them here.
        ch[x][0] = ch[x][1] = pre[x] = 0;
        return x;
    }

    // Put node x back on the free list.
    void delete_node(int x){
        ch[x][0] = ch[x][1] = pre[x] = 0;
        pool[top++] = x;
    }

    // Rotate x above its parent y.
    // f == 1 is a zig (right rotation) and requires x == ch[y][0];
    // f == 0 is a zag (left rotation) and requires x == ch[y][1].
    void rotate(int x,int f){
        int y = pre[x], z = pre[y];
        ch[y][!f] = ch[x][f];              // x's inner subtree moves under y
        if(ch[y][!f]) pre[ch[y][!f]] = y;
        ch[x][f] = y;
        pre[y] = x;
        if(z) ch[z][ch[z][1]==y] = x;      // x takes y's place under z
        pre[x] = z;
    }

    // Splay x until it has no parent, then store it in rt.
    void splay(int x,int &rt){
        while(pre[x]){
            int y = pre[x], z = pre[y];
            if(!z){
                rotate(x,ch[y][0]==x);                            // single zig / zag
            }else{
                int f = ch[z][0]==y;                              // y == ch[z][!f]
                if(ch[y][!f]==x){ rotate(y,f); rotate(x,f); }     // zig-zig / zag-zag
                else            { rotate(x,!f); rotate(x,f); }    // zig-zag / zag-zig
            }
        }
        rt = x;
    }

    // Iterative search in the subtree rooted at rt; returns a node whose key
    // equals k (the first one met on the search path) or 0. Does not splay.
    int find_help(DataType k,int rt){
        int x = rt;
        while(x && !(k==key[x])) x = ch[x][!(k<key[x])];
        return x;
    }

    // Right-most (maximum) node of the subtree rooted at rt, or 0 if rt == 0.
    int findmax_help(int rt){
        int x = rt;
        while(x&&ch[x][1]) x = ch[x][1];
        return x;
    }

    // Left-most (minimum) node of the subtree rooted at rt, or 0 if rt == 0.
    int findmin_help(int rt){
        int x = rt;
        while(x&&ch[x][0]) x = ch[x][0];
        return x;
    }

    // Join two detached trees where every key of rt1 is <= every key of rt2:
    // splay rt1's maximum to its root (leaving it without a right child) and
    // hang rt2 there. Returns the new root.
    int join(int rt1,int rt2){
        if(!rt1) return rt2;
        if(!rt2) return rt1;
        splay(findmax_help(rt1),rt1);
        ch[rt1][1] = rt2;
        pre[rt2] = rt1;
        return rt1;
    }

    // Splay a node with key k to the root of rt and detach its two subtrees
    // into rt1 / rt2 (each with no parent). Returns that node, or 0 if k is
    // absent (in which case rt1 / rt2 are left untouched).
    int split(DataType k,int &rt,int &rt1,int &rt2){
        int x = find_help(k,rt);
        if(!x) return 0;
        splay(x,rt);
        rt1 = ch[x][0];
        rt2 = ch[x][1];
        if(rt1) pre[rt1] = 0;
        if(rt2) pre[rt2] = 0;
        return x;
    }

    // Iterative BST insertion of k below root (equal keys go to the right).
    // Returns the new node; the caller is expected to splay it.
    int insert_help(DataType k){
        int father = 0, dir = 0;
        for(int x = root; x; x = ch[x][dir]){
            father = x;
            dir = !(k<key[x]);
        }
        int x = new_node(k);
        pre[x] = father;
        if(father) ch[father][dir] = x;
        else root = x;
        return x;
    }
public:
    // Starts empty. Node arrays need no initialisation because new_node clears
    // the links of every node it hands out.
    SplayTree():top(0),root(0),tot(0){}

    // Insert k (duplicates allowed) and splay the new node to the root.
    void insert(DataType k){
        splay(insert_help(k),root);
    }

    // Remove one node with key k; no-op if k is absent.
    void remove(DataType k){
        int rt1,rt2;
        int x = split(k,root,rt1,rt2);
        if(!x) return;
        delete_node(x);
        root = join(rt1,rt2);
    }

    // Write the maximum key to k and splay it to the root. Returns its node
    // index, or 0 if the tree is empty (k is then left unchanged).
    int findmax(DataType &k){
        int x = findmax_help(root);
        if(x){ splay(x,root); k = key[x]; }
        return x;
    }

    // Write the minimum key to k and splay it to the root. Returns its node
    // index, or 0 if the tree is empty (k is then left unchanged).
    int findmin(DataType &k){
        int x = findmin_help(root);
        if(x){ splay(x,root); k = key[x]; }
        return x;
    }
};

(2)

template<class DataType>
class SplayTree{
// NOTE: MAXN is a preprocessor macro, so it stays defined for the rest of the
// translation unit after this point (it is not scoped to the class).
#define MAXN 1000010
private:
    // Array-backed node storage. Index 0 is the "null" node; real nodes use
    // indices 1..MAXN-1. For a node x: ch[x][0]/ch[x][1] are its left/right
    // children, pre[x] is its parent (0 for the root) and key[x] its value.
    // pool[0..top) holds indices of removed nodes waiting to be reused.
    int ch[MAXN][2],pre[MAXN],pool[MAXN];
    DataType key[MAXN];
    int top,root,tot;  // free-list size, current root (0 = empty), highest index handed out so far

    // Obtain an unlinked node holding dt, recycling a freed index when possible.
    // Capacity is not checked: at most MAXN-1 nodes may be live at once.
    int new_node(DataType dt){
        int x = top ? pool[--top] : ++tot;
        key[x] = dt;
        // A never-used index holds indeterminate links unless the object was
        // zero-initialised (static storage), so always clear them here.
        ch[x][0] = ch[x][1] = pre[x] = 0;
        return x;
    }

    // Put node x back on the free list.
    void delete_node(int x){
        ch[x][0] = ch[x][1] = pre[x] = 0;
        pool[top++] = x;
    }

    // Rotate x above its parent y.
    // f == 1 is a zig (right rotation) and requires x == ch[y][0];
    // f == 0 is a zag (left rotation) and requires x == ch[y][1].
    void rotate(int x,int f){
        int y = pre[x], z = pre[y];
        ch[y][!f] = ch[x][f];              // x's inner subtree moves under y
        if(ch[y][!f]) pre[ch[y][!f]] = y;
        ch[x][f] = y;
        pre[y] = x;
        if(z) ch[z][ch[z][1]==y] = x;      // x takes y's place under z
        pre[x] = z;
    }

    // Splay x until it has no parent, then store it in rt.
    void splay(int x,int &rt){
        while(pre[x]){
            int y = pre[x], z = pre[y];
            if(!z){
                rotate(x,ch[y][0]==x);                            // single zig / zag
            }else{
                int f = ch[z][0]==y;                              // y == ch[z][!f]
                if(ch[y][!f]==x){ rotate(y,f); rotate(x,f); }     // zig-zig / zag-zag
                else            { rotate(x,!f); rotate(x,f); }    // zig-zag / zag-zig
            }
        }
        rt = x;
    }

    // Iterative search in the subtree rooted at rt; returns a node whose key
    // equals k (the first one met on the search path) or 0. Does not splay.
    int find_help(DataType k,int rt){
        int x = rt;
        while(x && !(k==key[x])) x = ch[x][!(k<key[x])];
        return x;
    }

    // Right-most (maximum) node of the subtree rooted at rt, or 0 if rt == 0.
    int findmax_help(int rt){
        int x = rt;
        while(x&&ch[x][1]) x = ch[x][1];
        return x;
    }

    // Left-most (minimum) node of the subtree rooted at rt, or 0 if rt == 0.
    int findmin_help(int rt){
        int x = rt;
        while(x&&ch[x][0]) x = ch[x][0];
        return x;
    }

    // Iterative BST insertion of k below root (equal keys go to the right).
    // Returns the new node; the caller is expected to splay it.
    int insert_help(DataType k){
        int father = 0, dir = 0;
        for(int x = root; x; x = ch[x][dir]){
            father = x;
            dir = !(k<key[x]);
        }
        int x = new_node(k);
        pre[x] = father;
        if(father) ch[father][dir] = x;
        else root = x;
        return x;
    }
public:
    // Starts empty. Node arrays need no initialisation because new_node clears
    // the links of every node it hands out.
    SplayTree():top(0),root(0),tot(0){}

    // Insert k (duplicates allowed) and splay the new node to the root.
    void insert(DataType k){
        splay(insert_help(k),root);
    }

    // Remove one node with key k (the first one met on the search path);
    // no-op if k is absent. BST-style deletion: a node with two children takes
    // its in-order successor's key and the successor's node is unlinked
    // instead. The parent of the unlinked node is then splayed to the root.
    void remove(DataType k){
        int x = find_help(k,root);
        if(!x) return;
        if(ch[x][0] && ch[x][1]){
            int s = findmin_help(ch[x][1]);   // successor: has no left child
            key[x] = key[s];
            x = s;
        }
        // x now has at most one child; splice it out of the tree.
        int child = ch[x][0] ? ch[x][0] : ch[x][1];
        int father = pre[x];
        if(child) pre[child] = father;
        if(father) ch[father][ch[father][1]==x] = child;
        else root = child;
        delete_node(x);
        if(father) splay(father,root);
    }

    // Write the maximum key to k and splay it to the root. Returns its node
    // index, or 0 if the tree is empty (k is then left unchanged).
    int findmax(DataType &k){
        int x = findmax_help(root);
        if(x){ splay(x,root); k = key[x]; }
        return x;
    }

    // Write the minimum key to k and splay it to the root. Returns its node
    // index, or 0 if the tree is empty (k is then left unchanged).
    int findmin(DataType &k){
        int x = findmin_help(root);
        if(x){ splay(x,root); k = key[x]; }
        return x;
    }
};