社区讨论

30pts re+ac 求条

P3369【模板】普通平衡树参与者 1已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@mjto552o
此快照首次捕获于
2025/12/31 15:02
2 个月前
此快照最后确认于
2026/01/02 22:45
2 个月前
查看原帖
CPP
#include<bits/stdc++.h>
using namespace std;
int ch[10000010][2],sz[10000010],cnt[10000010],rt,val[10000010],fa[10000010],id;
bool dir(int x){
	return x==ch[fa[x]][1];
}
void push_up(int x){
	sz[x]=cnt[x]+sz[ch[x][0]]+sz[ch[x][1]];
}
void rotate(int x){//旋转操作 
	int y=fa[x],z=fa[y];
	bool r=dir(x);
	ch[y][r]=ch[x][!r];
	ch[x][!r]=y;
	if(z){
		ch[z][dir(y)]=x;
	}
	if(ch[y][r]){
		fa[ch[y][r]]=y;
	}
	fa[y]=x;
	fa[x]=z;
	push_up(y);
	push_up(z);
}
void splay(int &z,int x){//伸展 
	int w=fa[z];
	for(int y;(y=fa[x])!=w;rotate(x)){
		if(fa[y]!=w){
			rotate(dir(x)==dir(y)?y:x);
		}
	}
	z=x;
}
/*
伸展操作是 Splay 树的核心操作,
也是它的时间复杂度能够得到保证的关键步骤。
请务必保证每次向下访问节点后,
都进行一次伸展操作。
*/
void find(int &z,int v){
	int x=z,y=fa[x];
	for(;x&&val[x]!=v;x=ch[y=x][v>val[x]]){
		;
	}
	splay(z,x?x:y);
}
void loc(int &z,int k){
	int x=z;
	while(true){
		if(sz[ch[x][0]]>=k){
			x=ch[x][0];
		}else if(sz[ch[x][0]]+cnt[x]>=k){
			break;
		}else{
			k-=sz[ch[x][0]]+cnt[x];
			x=ch[x][1];
		}
	}
}
int find_kth(int k){
	if(k>sz[rt]){
		return -1;
	}
	loc(rt,k);
	return val[rt];
}
int merge(int x,int y){
	if(!x||!y){
		return x|y;
	}
	loc(y,1);
	ch[y][0]=x;
	push_up(y);
	return y;
}
void insert(int v){
	int x=rt,y=0;
	for(;x&&val[x]!=v;x=ch[y=x][v>val[x]]){
		;
	}
	if(x){
		++cnt[x];
		++sz[x];
	}else{
		x=++id;
		val[x]=v;
		cnt[x]=sz[x]=1;
		fa[x]=y;
		if(y){
			ch[y][v>val[y]]=x;
		}
	}
	splay(rt,x);
}
bool remove(int v){
	find(rt,v);
	if(!rt||val[rt]!=v){
		return false;
	}
	--cnt[rt];
	--sz[rt];
	if(!cnt[rt]){
		int x=ch[rt][0];
		int y=ch[rt][1];
		fa[x]=fa[y]=0;
		rt=merge(x,y);
	}
}
int find_rank(int v){
	find(rt,v);
	return sz[ch[rt][0]]+(val[rt]<v?cnt[rt]:0)+1;
}
int find_prev(int v){
	find(rt,v);
	if(rt&&val[rt]<v){
		return val[rt];
	}
	int x=ch[rt][0];
	if(!x){
		return -1;
	}
	for(;ch[x][1];x=ch[x][1]){
		;
	}
	splay(rt,x);
	return val[rt];
}
int find_next(int v){
    find(rt,v);
    if(rt&&val[rt]>v){
        return val[rt];
    }
    int x=ch[rt][1];
    for(;ch[x][0];x=ch[x][0]){
        ;
    }
    splay(rt,x);
    return val[rt];
}
int main(){
	int n;
	scanf("%d",&n);
	while(n--){
		int op,x;
		scanf("%d%d",&op,&x);
		switch(op){
			case 1:
				insert(x);
				break;
			case 2:
				remove(x);
				break;
			case 3:
				printf("%d\n",find_rank(x));
				break;
			case 4:
				printf("%d\n",find_kth(x));
				break;
			case 5:
				printf("%d\n",find_prev(x));
				break;
			case 6:
				printf("%d\n",find_next(x));
				break;
		}
	}
	return 0;
}

回复

1 条回复,欢迎继续交流。

正在加载回复...