社区讨论
Treap样例都不过
P3369【模板】普通平衡树参与者 3已保存回复 3
讨论操作
快速查看讨论及其快照的属性,并进行相关操作。
- 当前回复
- 3 条
- 当前快照
- 1 份
- 快照标识符
- @lo12yz1w
- 此快照首次捕获于
- 2023/10/22 14:21 2 年前
- 此快照最后确认于
- 2023/11/02 13:50 2 年前
rt
CPP#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 1e6+10, INF = 1e8;
struct Node {
int ls, rs;
int key, val;
int siz, cnt;
} tr[N];
int root, idx;
int get_node (int key) {
tr[++ idx].key = key;
tr[idx].val = rand();
tr[idx].cnt = tr[idx].siz = 1;
return idx;
}
void push_up (int &p) {
tr[p].siz = tr[tr[p].ls].siz + tr[tr[p].rs].siz + tr[p].cnt;
}
void build () {
get_node(-INF), get_node(INF);
tr[1].rs = 2, root = 1;
push_up(root);
}
void lt (int &p) {
int q = tr[p].rs;
tr[p].rs = tr[q].ls, tr[q].ls = p, p = q;
push_up(tr[p].ls), push_up(p);
}
void rt (int &p) {
int q = tr[p].ls;
tr[p].ls = tr[q].rs, tr[q].rs = p, p = q;
push_up(tr[p].rs), push_up(p);
}
void insert (int &p, int key) {
if (!p) get_node (key);
else if (tr[p].key == key) tr[p].cnt ++;
else if (key < tr[p].key){
insert (tr[p].ls, key);
if (tr[tr[p].ls].val > tr[p].val) rt (p);
} else {
insert (tr[p].rs, key);
if (tr[tr[p].rs].val > tr[p].val) lt (p);
}
push_up (p);
}
void remove (int &p, int key) {
if(!p) return;
else if(tr[p].key == key) {
if(tr[p].cnt > 1) tr[p].cnt --;
else if(tr[p].ls || tr[p].rs) {
if(!tr[p].rs || tr[tr[p].ls].val > tr[tr[p].rs].val) {
rt (p);
remove (tr[p].rs, key);
} else {
lt (p);
remove (tr[p].ls, key);
}
} else p = 0;
} else if(key < tr[p].key) remove (tr[p].ls, key);
else remove (tr[p].rs, key);
push_up (p);
}
int rk (int p, int key) {
if(!p) return 0;
else if(tr[p].key == key) return tr[tr[p].ls].siz + 1;
else if(key < tr[p].key) return rk (tr[p].ls, key);
return tr[tr[p].ls].siz + tr[p].cnt + rk (tr[p].rs, key);
}
int kr (int p, int rank) {
if(!p) return INF;
else if(tr[tr[p].ls].siz >= rank) return kr(tr[p].ls, rank);
else if(tr[tr[p].ls].siz + tr[p].cnt >= rank) return tr[p].key;
return kr(tr[p].rs, rank - tr[tr[p].ls].siz - tr[p].siz);
}
int get_pre (int p, int key) {
if(!p) return -INF;
if(key < tr[p].key) return get_pre(tr[p].ls, key);
return max(tr[p].key, get_pre(tr[p].rs, key));
}
int get_nxt (int p, int key) {
if(!p) return INF;
if(tr[p].key <= key) return get_nxt(tr[p].rs, key);
return min(tr[p].key, get_nxt(tr[p].ls, key));
}
int main() {
freopen("in", "r", stdin);
freopen("out", "w", stdout);
build();
int n; scanf("%d", &n);
while(n --) {
int opt, x;
scanf("%d%d", &opt, &x);
if(opt == 1) insert(root, x);
else if(opt == 2) remove(root, x);
else if(opt == 3) printf("%d\n", rk(root, x) - 1);
else if(opt == 4) printf("%d\n", kr(root, x + 1));
else if(opt == 5) printf("%d\n", get_pre(root, x));
else printf("%d\n", get_nxt(root, x));
}
fclose(stdin);
fclose(stdout);
return 0;
}
回复
共 3 条回复,欢迎继续交流。
正在加载回复...