专栏文章
Splay树(伸展树)实现详解
算法·理论参与者 1已保存评论 0
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 0 条
- 当前快照
- 1 份
- 快照标识符
- @mina6z21
- 此快照首次捕获于
- 2025/12/01 23:05 3 个月前
- 此快照最后确认于
- 2025/12/01 23:05 3 个月前
Splay树(伸展树)实现详解
1. 数据结构定义
CPPconst int N=1e5;
int rt;//根节点
int val[N],cnt[N],siz[N];//数据域:值、计数、子树大小
int fa[N],son[N][2];//指针域:父节点、左右儿子
实现原理:
-
val[]: 存储节点的值 -
cnt[]: 存储相同值的出现次数(支持重复值) -
siz[]: 存储以该节点为根的子树大小(用于排名查询) -
fa[]: 存储父节点指针 -
son[][]: 存储左右子节点指针 -
rt: 整棵树的根节点
2. 辅助函数
2.1 方向判断函数
CPP//判断节点x的方位,返回0为左,1为右
int dir(int x){
return son[fa[x]][1]==x;
}
实现原理:
-
通过比较节点x与其父节点的子节点关系来确定方位
-
如果是父节点的右子节点返回1,左子节点返回0
2.2 更新函数
CPP//更新节点x的大小
void up(int x){
siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}
实现原理:
-
节点的大小 = 左子树大小 + 右子树大小 + 当前节点计数
-
在旋转和插入删除操作后必须调用此函数维护正确的大小信息
3. 核心操作
3.1 旋转操作
CPP//旋转操作
void rotate(int x){
int y=fa[x],z=fa[y],s=dir(x);//y是x的父节点,z是y的父节点,s是x的方向
//处理x的与旋转方向相反的子节点
if(son[x][!s]){
fa[son[x][!s]]=y;//将x的反向子节点连接到y上
}
son[y][s]=son[x][!s];//y继承x的反向子节点
//处理x与y的关系
son[x][!s]=y;//x成为y的父节点
fa[y]=x;//y的父节点指向x
//处理x与z的关系
fa[x]=z;//x的父节点指向z
if(z){
son[z][son[z][1]==y]=x;//z的对应子节点指向x
}
up(y);//先更新y,因为y在x的下层
up(x);//再更新x
}
实现原理:
-
单旋操作:将节点x上移一层,保持BST性质
-
三种情况处理:
-
x的反向子节点连接到y
-
调整x与y的父子关系
-
调整x与z(祖父节点)的关系
-
-
更新顺序:先更新下层节点y,再更新上层节点x
4. 伸展操作
4.1 伸展到根节点
CPP//伸展操作:将节点x伸展到根节点
void splay_small(int x){
while(fa[x]){//一直旋转直到x成为根节点
int y=fa[x];//x的父节点
if(fa[y]){//如果y不是根节点,考虑双旋
if(dir(x)==dir(y)){
rotate(y);//一字型情况,先旋转y
}else{
rotate(x);//之字形情况,先旋转x
}
}
rotate(x);//最后旋转x
}
rt=x;//更新根节点
}
4.2 伸展到指定位置
CPP//伸展操作:将节点x伸展到p的位置
void splay(int x,int &p){
int z=fa[p];//目标位置的父节点
while(fa[x]!=z){//一直旋转直到x的父节点是z
int y=fa[x];//x的父节点
if(fa[y]!=z){//如果y的父节点不是z,考虑双旋
if(dir(x)==dir(y)){
rotate(y);//一字型情况,先旋转y
}else{
rotate(x);//之字形情况,先旋转x
}
}
rotate(x);//最后旋转x
}
p=x;//更新p指向x
}
实现原理:
-
伸展策略:通过旋转将被访问节点移动到根节点附近
-
双旋优化:
-
一字型(x和y同方向):先旋转父节点y
-
之字形(x和y不同方向):先旋转自己x
-
-
摊还分析:多次访问后树会趋于平衡,均摊时间复杂度O(log n)
5. 基本BST操作
5.1 插入操作
CPP//插入操作
void insert(int v){
static int idx=0;//静态变量,节点计数器
int x=rt,y=0;//x从根开始搜索,y记录父节点
while(x && val[x]!=v){//查找插入位置
x=son[y=x][v>val[x]];//根据大小关系选择左右子树
}
if(x){//如果节点已存在
cnt[x]++;//增加计数
siz[x]++;//更新大小
}else{//如果节点不存在,创建新节点
x=++idx;//分配新节点编号
val[x]=v;//设置节点值
cnt[x]=siz[x]=1;//初始化计数和大小
fa[x]=y;//设置父节点
if(y){//如果父节点存在
son[y][v>val[y]]=x;//将新节点连接到父节点的对应位置
}
}
splay(x,rt);//将新插入的节点伸展到根位置
}
实现原理:
-
查找位置:从根开始,根据BST性质找到插入位置
-
处理重复:如果值已存在,增加计数
-
创建节点:如果值不存在,创建新节点并连接到树中
-
伸展优化:将新节点伸展到根,提高后续访问效率
5.2 查找操作
CPP//查找值为v的节点并伸展到根
void find(int v){
int x=rt;
while(x && val[x]!=v){
x=son[x][v>val[x]];
}
if(x) splay(x,rt);
}
实现原理:
-
标准BST查找,找到后通过伸展操作将节点移动到根
-
即使没找到,也会将最后一个访问的节点伸展到根
6. 查询操作
6.1 排名查询
CPP//查询值v的排名(比v小的数的个数+1)
int rnk(int v){
find(v);//先查找v并伸展到根
if(val[rt]>=v){//如果根节点值>=v,排名在左子树中
return siz[son[rt][0]]+1;
}else{//如果根节点值<v,排名在右子树中
return siz[son[rt][0]]+cnt[rt]+1;
}
}
实现原理:
-
排名定义:比v小的元素个数 + 1
-
通过伸展将相关节点移动到根,利用子树大小信息计算排名
-
时间复杂度:O(log n)
6.2 第k小查询
CPP//查询第k小的值
int kth(int k){
int x=rt;
while(x){
if(k<=siz[son[x][0]]){//k在左子树中
x=son[x][0];
}else if(k<=siz[son[x][0]]+cnt[x]){//k在当前节点中
splay(x,rt);//将找到的节点伸展到根
return val[x];
}else{//k在右子树中
k-=siz[son[x][0]]+cnt[x];
x=son[x][1];
}
}
return -1;//未找到
}
实现原理:
-
利用子树大小信息在BST中二分查找
-
三种情况:
-
k在左子树:继续在左子树中查找
-
k在当前节点:返回当前节点值
-
k在右子树:调整k值后在右子树中查找
-
7. 前驱后继操作
7.1 前驱查询
CPP//查找前驱(小于v的最大值)
int pre(int v){
find(v);//先查找v并伸展到根
if(val[rt]<v) return val[rt];//如果根节点值小于v,直接返回
int x=son[rt][0];//否则在左子树中找最大值
if(!x) return -1;//不存在前驱
while(son[x][1]) x=son[x][1];//一直往右走
splay(x,rt);//将前驱伸展到根
return val[x];
}
7.2 后继查询
CPP//查找后继(大于v的最小值)
int nxt(int v){
find(v);//先查找v并伸展到根
if(val[rt]>v) return val[rt];//如果根节点值大于v,直接返回
int x=son[rt][1];//否则在右子树中找最小值
if(!x) return -1;//不存在后继
while(son[x][0]) x=son[x][0];//一直往左走
splay(x,rt);//将后继伸展到根
return val[x];
}
实现原理:
-
前驱:左子树中的最大值
-
后继:右子树中的最小值
-
通过伸展操作优化后续访问效率
8. 删除操作
CPP//删除值为v的节点
void del(int v){
find(v);//先查找v并伸展到根
if(val[rt]!=v) return;//不存在该值
if(cnt[rt]>1){//如果有多个相同值
cnt[rt]--;
siz[rt]--;
return;
}
//只有一个节点的情况
if(!son[rt][0] && !son[rt][1]){//没有子节点
rt=0;//树为空
}else if(!son[rt][0]){//只有右子树
fa[son[rt][1]]=0;
rt=son[rt][1];
}else if(!son[rt][1]){//只有左子树
fa[son[rt][0]]=0;
rt=son[rt][0];
}else{//有两个子节点
int x=son[rt][0];
while(son[x][1]) x=son[x][1];//在左子树中找最大值
splay(x,son[rt][0]);//将x伸展到左子树的根
//连接右子树
son[x][1]=son[rt][1];
fa[son[rt][1]]=x;
//更新根
fa[x]=0;
rt=x;
up(rt);//更新根节点大小
}
}
实现原理:
-
四种删除情况:
-
重复值:减少计数
-
叶节点:直接删除
-
单子树:用子节点替代
-
双子树:用前驱或后继替代(这里用前驱)
-
-
替代策略:用左子树的最大值替代被删除节点,保持BST性质
9. 测试示例
CPPint main(){
// 示例操作
cout << "Splay树操作示例:" << endl;
// 插入操作
insert(5);
insert(3);
insert(7);
insert(1);
insert(9);
insert(4);
insert(6);
cout << "插入 5, 3, 7, 1, 9, 4, 6 后:" << endl;
// 查询排名
cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4
cout << "数字1的排名: " << rnk(1) << endl; // 应该输出1
// 查询第k小
cout << "第3小的数: " << kth(3) << endl; // 应该输出4
cout << "第5小的数: " << kth(5) << endl; // 应该输出6
// 查询前驱和后继
cout << "数字4的前驱: " << pre(4) << endl; // 应该输出3
cout << "数字4的后继: " << nxt(4) << endl; // 应该输出5
cout << "数字7的前驱: " << pre(7) << endl; // 应该输出6
cout << "数字7的后继: " << nxt(7) << endl; // 应该输出9
// 删除操作
cout << "\\\\n删除数字4后:" << endl;
del(4);
cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4(因为4被删除了,现在排名4的是5)
cout << "第3小的数: " << kth(3) << endl; // 应该输出5
// 再次插入测试重复值
cout << "\\\\n再次插入数字5(重复值):" << endl;
insert(5);
cout << "数字5的排名: " << rnk(5) << endl; // 应该输出4或5,取决于实现
// 边界测试
cout << "\\\\n边界测试:" << endl;
cout << "最小值: " << kth(1) << endl; // 应该输出1
cout << "最大值: " << kth(6) << endl; // 应该输出9(因为现在有6个节点)
cout << "数字0的前驱: " << pre(0) << endl; // 应该输出-1(不存在)
cout << "数字10的后继: " << nxt(10) << endl; // 应该输出-1(不存在)
return 0;
}
10.无注释版本
CPP#include<bits/stdc++.h>
using namespace std;
const int N=1e5;
int rt;
int val[N],cnt[N],siz[N];
int fa[N],son[N][2];
int dir(int x){
return son[fa[x]][1]==x;
}
void up(int x){
siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}
void rotate(int x){
int y=fa[x],z=fa[y],s=dir(x);
if(son[x][!s]){
fa[son[x][!s]]=y;
}
son[y][s]=son[x][!s];
son[x][!s]=y;
fa[y]=x;
fa[x]=z;
if(z){
son[z][son[z][1]==y]=x;
}
up(y);
up(x);
}
void splay_small(int x){
while(fa[x]){
int y=fa[x];
if(fa[y]){
if(dir(x)==dir(y)){
rotate(y);
}else{
rotate(x);
}
}
rotate(x);
}
rt=x;
}
void splay(int x,int &p){
int z=fa[p];
while(fa[x]!=z){
int y=fa[x];
if(fa[y]!=z){
if(dir(x)==dir(y)){
rotate(y);
}else{
rotate(x);
}
}
rotate(x);
}
p=x;
}
void insert(int v){
static int idx=0;
int x=rt,y=0;
while(x && val[x]!=v){
x=son[y=x][v>val[x]];
}
if(x){
cnt[x]++;
siz[x]++;
}else{
x=++idx;
val[x]=v;
cnt[x]=siz[x]=1;
fa[x]=y;
if(y){
son[y][v>val[y]]=x;
}
}
splay(x,rt);
}
void find(int v){
int x=rt;
while(x && val[x]!=v){
x=son[x][v>val[x]];
}
if(x) splay(x,rt);
}
int rnk(int v){
find(v);
if(val[rt]>=v){
return siz[son[rt][0]]+1;
}else{
return siz[son[rt][0]]+cnt[rt]+1;
}
}
int kth(int k){
int x=rt;
while(x){
if(k<=siz[son[x][0]]){
x=son[x][0];
}else if(k<=siz[son[x][0]]+cnt[x]){
splay(x,rt);
return val[x];
}else{
k-=siz[son[x][0]]+cnt[x];
x=son[x][1];
}
}
return -1;
}
int pre(int v){
find(v);
if(val[rt]<v) return val[rt];
int x=son[rt][0];
if(!x) return -1;
while(son[x][1]) x=son[x][1];
splay(x,rt);
return val[x];
}
int nxt(int v){
find(v);
if(val[rt]>v) return val[rt];
int x=son[rt][1];
if(!x) return -1;
while(son[x][0]) x=son[x][0];
splay(x,rt);
return val[x];
}
void del(int v){
find(v);
if(val[rt]!=v) return;
if(cnt[rt]>1){
cnt[rt]--;
siz[rt]--;
return;
}
if(!son[rt][0] && !son[rt][1]){
rt=0;
}else if(!son[rt][0]){
fa[son[rt][1]]=0;
rt=son[rt][1];
}else if(!son[rt][1]){
fa[son[rt][0]]=0;
rt=son[rt][0];
}else{
int x=son[rt][0];
while(son[x][1]) x=son[x][1];
splay(x,son[rt][0]);
son[x][1]=son[rt][1];
fa[son[rt][1]]=x;
fa[x]=0;
rt=x;
up(rt);
}
}
int main(){
return 0;
}
相关推荐
评论
共 0 条评论,欢迎与作者交流。
正在加载评论...