社区讨论

蒟蒻马蜂优美初学Splay,最后一个点TLE求跳

P3369【模板】普通平衡树参与者 1已保存回复 0

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
0 条
当前快照
1 份
快照标识符
@lo13vkm5
此快照首次捕获于
2023/10/22 14:46
2 年前
此快照最后确认于
2023/11/02 14:17
2 年前
查看原帖
rt,感谢神犇!!!
CPP
#include <bits/stdc++.h>
#define in inline
#define rint register int
#define r(a) runtimerror(a)
#define w(a) wronganswer(a)
#define wl(a) wronganswer(a),putchar('\n')
#define ws(a) wronganswer(a),putchar(' ')
#define ls s[0]
#define rs s[1]
using namespace std;
typedef long long ll;
int n,opt,x;
template <typename t> in void runtimerror(t &a){
	bool flag=false;char ch=getchar();a=0;
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') flag=true;
	for(;isdigit(ch);ch=getchar()) a=(a<<3)+(a<<1)+(ch^48);
	a=flag?~a+1:a;
}
template <typename t> void compilerror(t a){
	if(a>9) compilerror(a/10);
	putchar(a%10^48);
}
template <typename t> void wronganswer(t a){
	if(a<0) a=-a,putchar('-');
	compilerror(a);
}
struct splay_tree{
	int root,cnt;
	struct node{
		int s[2],val,fa,cnt,size;
		in void init(int v,int f){
			val=v,fa=f,size=cnt=1;
		}
	}tr[1100010];
	in void pushup(int id){
		tr[id].size=tr[tr[id].ls].size+tr[tr[id].rs].size+tr[id].cnt;
	}
	in void rotate(int id){
		int fa=tr[id].fa,gr=tr[fa].fa,tos=tr[fa].rs==id;
		tr[gr].s[tr[gr].rs==fa]=id,tr[id].fa=gr;
		tr[fa].s[tos]=tr[id].s[!tos],tr[tr[id].s[!tos]].fa=fa;
		tr[id].s[!tos]=fa,tr[fa].fa=id;
		pushup(fa),pushup(id);
	}
	in void splay(int id,int aim){
		while(tr[id].fa!=aim){
			int fa=tr[id].fa,gr=tr[fa].fa;
			if(gr!=aim){
				if((tr[fa].rs==id)^(tr[gr].rs==fa)) rotate(id);
				else rotate(fa);
			}
			rotate(id);
		}
		if(!aim) root=id;
	}
	in void insert(int val){
		int id=root,fa=0;
		while(id&&tr[id].val!=val) fa=id,id=tr[id].s[val>tr[id].val];
		if(id){
			tr[id].cnt++;
		}else{
			id=++cnt;
			if(fa) tr[fa].s[val>tr[fa].val]=id;
			tr[id].init(val,fa);
		}
		splay(id,0);
	}
	int query_id(int id,int val){
		if(!id) return 1;
		if(tr[id].val>val) return query_id(tr[id].ls,val);
		if(tr[id].val<val) return query_id(tr[id].rs,val);
		splay(id,0);
		return id;
	}
	int query_rank(int id,int val){
		if(!id) return 1;
		if(tr[id].val>val) return query_rank(tr[id].ls,val);
		if(tr[id].val<val) return query_rank(tr[id].rs,val)+tr[tr[id].ls].size+tr[id].cnt;
		return tr[tr[id].ls].size+1;
	}
	int query_val(int id,int rank){
		if(!id) return INT_MAX;
		if(tr[tr[id].ls].size>=rank) return query_val(tr[id].ls,rank);
		if(tr[tr[id].ls].size+tr[id].cnt>=rank){
		    splay(id,0);
		    return tr[id].val;
		}
		return query_val(tr[id].rs,rank-tr[tr[id].ls].size-tr[id].cnt);
	}
	int query_pre(int id,int val){
		if(!id) return -INT_MAX;
		if(tr[id].val>=val) return query_pre(tr[id].ls,val);
		return max(tr[id].val,query_pre(tr[id].rs,val));
	}
	int query_nex(int id,int val){
		if(!id) return INT_MAX;
		if(tr[id].val<=val) return query_nex(tr[id].rs,val);
		return min(tr[id].val,query_nex(tr[id].ls,val));
	}
	in void remove(int x){
		int l=query_id(root,query_pre(root,x));
		int r=query_id(root,query_nex(root,x));
		splay(r,0),splay(l,r);
		if(tr[tr[l].rs].cnt>1) tr[tr[l].rs].cnt--,tr[tr[l].rs].size--;
		else tr[l].rs=0;
		pushup(l),pushup(r);
	}
}t;
int main(){
	r(n);
	t.insert(-INT_MAX);
	t.insert(INT_MAX);
	while(n--){
		r(opt),r(x);
		switch(opt){
			case 1:
				t.insert(x);
				break;
			case 2:
				t.remove(x);
				break;
			case 3:
				wl(t.query_rank(t.root,x)-1);
				break;
			case 4:
				wl(t.query_val(t.root,x+1));
				break;
			case 5:
				wl(t.query_pre(t.root,x));
				break;
			case 6:
				wl(t.query_nex(t.root,x));
				break;
		}
	}
	return 0;
}

回复

0 条回复,欢迎继续交流。

正在加载回复...