主席树专题(区间第k小,可持久化数组)
程序员文章站
2022-03-11 19:22:54
...
相关题目
程序实现
以下是区间第k小(大)的封装.
struct chairmantree{
private :
int lc[maxn<<5],rc[maxn<<5],sum[maxn<<5];
public :
int build(int ll,int rr){
int rt=++cnt;
if(ll<rr){
lc[rt]=build(ll,mid);
rc[rt]=build(mid+1,rr);
}
return rt;
}
int update(int pre,int ll,int rr,int pos){
int rt=++cnt;
lc[rt]=lc[pre],rc[rt]=rc[pre],sum[rt]=sum[pre]+1;
if(ll<rr){
if(pos<=mid)lc[rt]=update(lc[pre],ll,mid,pos);
else rc[rt]=update(rc[pre],mid+1,rr,pos);
}
return rt;
}
int query(int u,int v,int ll,int rr,int rank){
if(ll==rr)return ll;
int x=sum[lc[v]]-sum[lc[u]];
if(rank<=x)return query(lc[u],lc[v],ll,mid,rank);
else return query(rc[u],rc[v],mid+1,rr,rank-x);
}//以上为区间第k小
int query(int u,int v,int ll,int rr,int rank){
if(ll==rr)return ll;
int x=sum[rc[v]]-sum[rc[u]];
if(rank<=x)return query(rc[u],rc[v],mid+1,rr,rank);
return query(lc[u],lc[v],ll,mid,rank-x);
}//以上为区间第k大
}t;
以下是区间第k小(大)的主函数部分.
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
b[i]=a[i];
}
sort(b+1,b+n+1);
m=unique(b+1,b+n+1)-(b+1);
tree[0]=build(1,m);
for(int i=1;i<=n;i++){
int op=lower_bound(b+1,b+m+1,a[i])-b;
tree[i]=update(tree[i-1],1,m,op);//离散化
}
int x,y,k;
for(int i=1;i<=q;i++){
scanf("%d%d%d",&x,&y,&k);
int op=query(tree[x-1],tree[y],1,m,k);
printf("%d\n",b[op]);
}
return 0;
}
以下是可持久化数组的封装.
struct chairmantree{
private:
int val[maxn<<5],l[maxn<<5],r[maxn<<5],deep[maxn<<5];
public:
int build(int ll,int rr){
int rt=++cnt;
if(ll==rr){
val[rt]=fa[ll];
deep[rt]+=1;
return rt;
}
l[rt]=build(ll,mid);
r[rt]=build(mid+1,rr);
return rt;
}
int update(int pre,int ll,int rr,int pos,int value){
int rt=++cnt;
val[rt]=val[pre];l[rt]=l[pre];r[rt]=r[pre];
if(ll==rr){val[rt]=value;deep[rt]+=1;return rt;}
if(pos<=mid)l[rt]=update(l[pre],ll,mid,pos,value);
else r[rt]=update(r[pre],mid+1,rr,pos,value);
return rt;
}
int query(int u,int ll,int rr,int pos){
if(ll==rr)return val[u];
if(pos<=mid)return query(l[u],ll,mid,pos);
else return query(r[u],mid+1,rr,pos);
}
}t;
以下是可持久化并查集的程序实现.
#include<bits/stdc++.h>
#define mid ((ll+rr)>>1)
#define maxn 100010
using namespace std;
int n,q,cnt,root[maxn],fa[maxn];
struct chairmantree{
private:
int val[maxn<<5],l[maxn<<5],r[maxn<<5],deep[maxn<<5];
public:
int build(int ll,int rr){
int rt=++cnt;
if(ll==rr){
val[rt]=fa[ll];
deep[rt]+=1;
return rt;
}
l[rt]=build(ll,mid);
r[rt]=build(mid+1,rr);
return rt;
}
int update(int pre,int ll,int rr,int pos,int value){
int rt=++cnt;
val[rt]=val[pre];l[rt]=l[pre];r[rt]=r[pre];
if(ll==rr){val[rt]=value;deep[rt]+=1;return rt;}
if(pos<=mid)l[rt]=update(l[pre],ll,mid,pos,value);
else r[rt]=update(r[pre],mid+1,rr,pos,value);
return rt;
}
int query(int u,int ll,int rr,int pos){
if(ll==rr)return val[u];
if(pos<=mid)return query(l[u],ll,mid,pos);
else return query(r[u],mid+1,rr,pos);
}
int find(int edit,int pos){
int x=query(edit,1,n,pos);
if(x==pos)return x;
else return find(edit,x);
}
void update_depth(int u,int ll,int rr,int pos){
if(ll==rr){deep[u]+=1;return ;}
if(pos<=mid)update_depth(l[u],ll,mid,pos);
else update_depth(r[u],mid+1,rr,pos);
}
int query_depth(int u,int ll,int rr,int pos){
if(ll==rr)return deep[u];
if(pos<=mid)return query_depth(l[u],ll,mid,pos);
else return query_depth(r[u],mid+1,rr,pos);
}
}t;
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++)fa[i]=i;
root[0]=t.build(1,n);
int a1,b,c;
for(int i=1;i<=q;i++){
scanf("%d",&a1);
if(a1==1){
scanf("%d%d",&b,&c);
root[i]=root[i-1];
int fx=t.find(root[i-1],b);
int fy=t.find(root[i-1],c);
if(fx==fy)continue;
int dx=t.query_depth(root[i-1],1,n,fx) ;
int dy=t.query_depth(root[i-1],1,n,fy) ;//查询深度,按秩合并
if(dx<dy)swap(fx,fy); //保证fx的深度更大
root[i]=t.update(root[i-1],1,n,fy,fx);
if(dx==dy) t.update_depth (root[i],1,n,fx);//更新深度,保证fx的深度更大
}
else if(a1==2){
scanf("%d",&b);
root[i]=root[b];
}
else {
scanf("%d%d",&b,&c);
root[i]=root[i-1];
int fx=t.find(root[i],b);
int fy=t.find(root[i],c);
if(fx==fy)printf("1\n");
else printf("0\n");
}
}
return 0;
}
附按秩合并的普通并查集.
#include<bits/stdc++.h>
#define maxn 200010
using namespace std;
int n,m;
int fa[maxn],deep[maxn];
int find(int x){return fa[x]==x?x:find(fa[x]);}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
fa[i]=i;deep[i]=1;
}
int x,y,z;
for(int i=1;i<=m;i++){
scanf("%d%d%d",&z,&x,&y);
int fx=find(x);
int fy=find(y);
if(z==1){
if(fx!=fy){
if(deep[fx]>deep[fy])swap(fx,fy);
fa[fx]=fy;
if(deep[fx]==deep[fy])deep[fy]++;
}
else continue;
}
else {
if(fx!=fy)printf("N\n");
else printf("Y\n");
}
}
return 0;
}