专栏文章

Splay树(伸展树)实现详解

算法·理论参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mina6z21
此快照首次捕获于
2025/12/01 23:05
3 个月前
此快照最后确认于
2025/12/01 23:05
3 个月前
查看原文

Splay树(伸展树)实现详解

1. 数据结构定义

CPP
const int N=1e5;
int rt;//根节点
int val[N],cnt[N],siz[N];//数据域:值、计数、子树大小
int fa[N],son[N][2];//指针域:父节点、左右儿子

实现原理
  • val[]: 存储节点的值
  • cnt[]: 存储相同值的出现次数(支持重复值)
  • siz[]: 存储以该节点为根的子树大小(用于排名查询)
  • fa[]: 存储父节点指针
  • son[][]: 存储左右子节点指针
  • rt: 整棵树的根节点

2. 辅助函数

2.1 方向判断函数

CPP
//判断节点x的方位,返回0为左,1为右
int dir(int x){
	return son[fa[x]][1]==x;
}

实现原理
  • 通过比较节点x与其父节点的子节点关系来确定方位
  • 如果是父节点的右子节点返回1,左子节点返回0

2.2 更新函数

CPP
//更新节点x的大小
void up(int x){
	siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}

实现原理
  • 节点的大小 = 左子树大小 + 右子树大小 + 当前节点计数
  • 在旋转和插入删除操作后必须调用此函数维护正确的大小信息

3. 核心操作

3.1 旋转操作

CPP
//旋转操作
void rotate(int x){
	int y=fa[x],z=fa[y],s=dir(x);//y是x的父节点,z是y的父节点,s是x的方向

	//处理x的与旋转方向相反的子节点
	if(son[x][!s]){
		fa[son[x][!s]]=y;//将x的反向子节点连接到y上
	}
	son[y][s]=son[x][!s];//y继承x的反向子节点

	//处理x与y的关系
	son[x][!s]=y;//x成为y的父节点
	fa[y]=x;//y的父节点指向x

	//处理x与z的关系
	fa[x]=z;//x的父节点指向z
	if(z){
		son[z][son[z][1]==y]=x;//z的对应子节点指向x
	}
	up(y);//先更新y,因为y在x的下层
	up(x);//再更新x
}

实现原理
  • 单旋操作:将节点x上移一层,保持BST性质
  • 三种情况处理
    1. x的反向子节点连接到y
    2. 调整x与y的父子关系
    3. 调整x与z(祖父节点)的关系
  • 更新顺序:先更新下层节点y,再更新上层节点x

4. 伸展操作

4.1 伸展到根节点

CPP
//伸展操作:将节点x伸展到根节点
void splay_small(int x){
	while(fa[x]){//一直旋转直到x成为根节点
		int y=fa[x];//x的父节点
		if(fa[y]){//如果y不是根节点,考虑双旋
			if(dir(x)==dir(y)){
				rotate(y);//一字型情况,先旋转y
			}else{
				rotate(x);//之字形情况,先旋转x
			}
		}
		rotate(x);//最后旋转x
	}
	rt=x;//更新根节点
}

4.2 伸展到指定位置

CPP
//伸展操作:将节点x伸展到p的位置
void splay(int x,int &p){
	int z=fa[p];//目标位置的父节点
	while(fa[x]!=z){//一直旋转直到x的父节点是z
		int y=fa[x];//x的父节点
		if(fa[y]!=z){//如果y的父节点不是z,考虑双旋
			if(dir(x)==dir(y)){
				rotate(y);//一字型情况,先旋转y
			}else{
				rotate(x);//之字形情况,先旋转x
			}
		}
		rotate(x);//最后旋转x
	}
	p=x;//更新p指向x
}

实现原理
  • 伸展策略:通过旋转将被访问节点移动到根节点附近
  • 双旋优化
    • 一字型(x和y同方向):先旋转父节点y
    • 之字形(x和y不同方向):先旋转自己x
  • 摊还分析:多次访问后树会趋于平衡,均摊时间复杂度O(log n)

5. 基本BST操作

5.1 插入操作

CPP
//插入操作
void insert(int v){
    static int idx=0;//静态变量,节点计数器
	int x=rt,y=0;//x从根开始搜索,y记录父节点
	while(x && val[x]!=v){//查找插入位置
		x=son[y=x][v>val[x]];//根据大小关系选择左右子树
	}
	if(x){//如果节点已存在
		cnt[x]++;//增加计数
		siz[x]++;//更新大小
	}else{//如果节点不存在,创建新节点
		x=++idx;//分配新节点编号
		val[x]=v;//设置节点值
		cnt[x]=siz[x]=1;//初始化计数和大小
		fa[x]=y;//设置父节点
		if(y){//如果父节点存在
			son[y][v>val[y]]=x;//将新节点连接到父节点的对应位置
		}
	}
	splay(x,rt);//将新插入的节点伸展到根位置
}

实现原理
  1. 查找位置:从根开始,根据BST性质找到插入位置
  2. 处理重复:如果值已存在,增加计数
  3. 创建节点:如果值不存在,创建新节点并连接到树中
  4. 伸展优化:将新节点伸展到根,提高后续访问效率

5.2 查找操作

CPP
//查找值为v的节点并伸展到根
void find(int v){
	int x=rt;
	while(x && val[x]!=v){
		x=son[x][v>val[x]];
	}
	if(x) splay(x,rt);
}

实现原理
  • 标准BST查找,找到后通过伸展操作将节点移动到根
  • 即使没找到,也会将最后一个访问的节点伸展到根

6. 查询操作

6.1 排名查询

CPP
//查询值v的排名(比v小的数的个数+1)
int rnk(int v){
	find(v);//先查找v并伸展到根
	if(val[rt]>=v){//如果根节点值>=v,排名在左子树中
		return siz[son[rt][0]]+1;
	}else{//如果根节点值<v,排名在右子树中
		return siz[son[rt][0]]+cnt[rt]+1;
	}
}

实现原理
  • 排名定义:比v小的元素个数 + 1
  • 通过伸展将相关节点移动到根,利用子树大小信息计算排名
  • 时间复杂度:O(log n)

6.2 第k小查询

CPP
//查询第k小的值
int kth(int k){
	int x=rt;
	while(x){
		if(k<=siz[son[x][0]]){//k在左子树中
			x=son[x][0];
		}else if(k<=siz[son[x][0]]+cnt[x]){//k在当前节点中
			splay(x,rt);//将找到的节点伸展到根
			return val[x];
		}else{//k在右子树中
			k-=siz[son[x][0]]+cnt[x];
			x=son[x][1];
		}
	}
	return -1;//未找到
}

实现原理
  • 利用子树大小信息在BST中二分查找
  • 三种情况:
    1. k在左子树:继续在左子树中查找
    2. k在当前节点:返回当前节点值
    3. k在右子树:调整k值后在右子树中查找

7. 前驱后继操作

7.1 前驱查询

CPP
//查找前驱(小于v的最大值)
int pre(int v){
	find(v);//先查找v并伸展到根
	if(val[rt]<v) return val[rt];//如果根节点值小于v,直接返回
	int x=son[rt][0];//否则在左子树中找最大值
	if(!x) return -1;//不存在前驱
	while(son[x][1]) x=son[x][1];//一直往右走
	splay(x,rt);//将前驱伸展到根
	return val[x];
}

7.2 后继查询

CPP
//查找后继(大于v的最小值)
int nxt(int v){
	find(v);//先查找v并伸展到根
	if(val[rt]>v) return val[rt];//如果根节点值大于v,直接返回
	int x=son[rt][1];//否则在右子树中找最小值
	if(!x) return -1;//不存在后继
	while(son[x][0]) x=son[x][0];//一直往左走
	splay(x,rt);//将后继伸展到根
	return val[x];
}

实现原理
  • 前驱:左子树中的最大值
  • 后继:右子树中的最小值
  • 通过伸展操作优化后续访问效率

8. 删除操作

CPP
//删除值为v的节点
void del(int v){
	find(v);//先查找v并伸展到根
	if(val[rt]!=v) return;//不存在该值

	if(cnt[rt]>1){//如果有多个相同值
		cnt[rt]--;
		siz[rt]--;
		return;
	}

	//只有一个节点的情况
	if(!son[rt][0] && !son[rt][1]){//没有子节点
		rt=0;//树为空
	}else if(!son[rt][0]){//只有右子树
		fa[son[rt][1]]=0;
		rt=son[rt][1];
	}else if(!son[rt][1]){//只有左子树
		fa[son[rt][0]]=0;
		rt=son[rt][0];
	}else{//有两个子节点
		int x=son[rt][0];
		while(son[x][1]) x=son[x][1];//在左子树中找最大值
		splay(x,son[rt][0]);//将x伸展到左子树的根

		//连接右子树
		son[x][1]=son[rt][1];
		fa[son[rt][1]]=x;

		//更新根
		fa[x]=0;
		rt=x;
		up(rt);//更新根节点大小
	}
}

实现原理
  • 四种删除情况
    1. 重复值:减少计数
    2. 叶节点:直接删除
    3. 单子树:用子节点替代
    4. 双子树:用前驱或后继替代(这里用前驱)
  • 替代策略:用左子树的最大值替代被删除节点,保持BST性质

9. 测试示例

CPP
int main(){
	// 示例操作
	cout << "Splay树操作示例:" << endl;

	// 插入操作
	insert(5);
	insert(3);
	insert(7);
	insert(1);
	insert(9);
	insert(4);
	insert(6);

	cout << "插入 5, 3, 7, 1, 9, 4, 6 后:" << endl;

	// 查询排名
	cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4
	cout << "数字1的排名: " << rnk(1) << endl; // 应该输出1

	// 查询第k小
	cout << "第3小的数: " << kth(3) << endl; // 应该输出4
	cout << "第5小的数: " << kth(5) << endl; // 应该输出6

	// 查询前驱和后继
	cout << "数字4的前驱: " << pre(4) << endl; // 应该输出3
	cout << "数字4的后继: " << nxt(4) << endl; // 应该输出5
	cout << "数字7的前驱: " << pre(7) << endl; // 应该输出6
	cout << "数字7的后继: " << nxt(7) << endl; // 应该输出9

	// 删除操作
	cout << "\\\\n删除数字4后:" << endl;
	del(4);

	cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4(因为4被删除了,现在排名4的是5)
	cout << "第3小的数: " << kth(3) << endl;   // 应该输出5

	// 再次插入测试重复值
	cout << "\\\\n再次插入数字5(重复值):" << endl;
	insert(5);

	cout << "数字5的排名: " << rnk(5) << endl; // 应该输出4或5,取决于实现

	// 边界测试
	cout << "\\\\n边界测试:" << endl;
	cout << "最小值: " << kth(1) << endl;      // 应该输出1
	cout << "最大值: " << kth(6) << endl;      // 应该输出9(因为现在有6个节点)
	cout << "数字0的前驱: " << pre(0) << endl; // 应该输出-1(不存在)
	cout << "数字10的后继: " << nxt(10) << endl; // 应该输出-1(不存在)

	return 0;
}

10.无注释版本

CPP
#include<bits/stdc++.h>
using namespace std;

const int N=1e5;
int rt;
int val[N],cnt[N],siz[N];
int fa[N],son[N][2];

int dir(int x){
    return son[fa[x]][1]==x;
}

void up(int x){
    siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}

void rotate(int x){
    int y=fa[x],z=fa[y],s=dir(x);
    if(son[x][!s]){
        fa[son[x][!s]]=y;
    }
    son[y][s]=son[x][!s];
    son[x][!s]=y;
    fa[y]=x;
    fa[x]=z;
    if(z){
        son[z][son[z][1]==y]=x;
    }
    up(y);
    up(x);
}

void splay_small(int x){
    while(fa[x]){
        int y=fa[x];
        if(fa[y]){
            if(dir(x)==dir(y)){
                rotate(y);
            }else{
                rotate(x);
            }
        }
        rotate(x);
    }
    rt=x;
}

void splay(int x,int &p){
    int z=fa[p];
    while(fa[x]!=z){
        int y=fa[x];
        if(fa[y]!=z){
            if(dir(x)==dir(y)){
                rotate(y);
            }else{
                rotate(x);
            }
        }
        rotate(x);
    }
    p=x;
}

void insert(int v){
    static int idx=0;
    int x=rt,y=0;
    while(x && val[x]!=v){
        x=son[y=x][v>val[x]];
    }
    if(x){
        cnt[x]++;
        siz[x]++;
    }else{
        x=++idx;
        val[x]=v;
        cnt[x]=siz[x]=1;
        fa[x]=y;
        if(y){
            son[y][v>val[y]]=x;
        }
    }
    splay(x,rt);
}

void find(int v){
    int x=rt;
    while(x && val[x]!=v){
        x=son[x][v>val[x]];
    }
    if(x) splay(x,rt);
}

int rnk(int v){
    find(v);
    if(val[rt]>=v){
        return siz[son[rt][0]]+1;
    }else{
        return siz[son[rt][0]]+cnt[rt]+1;
    }
}

int kth(int k){
    int x=rt;
    while(x){
        if(k<=siz[son[x][0]]){
            x=son[x][0];
        }else if(k<=siz[son[x][0]]+cnt[x]){
            splay(x,rt);
            return val[x];
        }else{
            k-=siz[son[x][0]]+cnt[x];
            x=son[x][1];
        }
    }
    return -1;
}

int pre(int v){
    find(v);
    if(val[rt]<v) return val[rt];
    int x=son[rt][0];
    if(!x) return -1;
    while(son[x][1]) x=son[x][1];
    splay(x,rt);
    return val[x];
}

int nxt(int v){
    find(v);
    if(val[rt]>v) return val[rt];
    int x=son[rt][1];
    if(!x) return -1;
    while(son[x][0]) x=son[x][0];
    splay(x,rt);
    return val[x];
}

void del(int v){
    find(v);
    if(val[rt]!=v) return;
    
    if(cnt[rt]>1){
        cnt[rt]--;
        siz[rt]--;
        return;
    }
    
    if(!son[rt][0] && !son[rt][1]){
        rt=0;
    }else if(!son[rt][0]){
        fa[son[rt][1]]=0;
        rt=son[rt][1];
    }else if(!son[rt][1]){
        fa[son[rt][0]]=0;
        rt=son[rt][0];
    }else{
        int x=son[rt][0];
        while(son[x][1]) x=son[x][1];
        splay(x,son[rt][0]);
        
        son[x][1]=son[rt][1];
        fa[son[rt][1]]=x;
        
        fa[x]=0;
        rt=x;
        up(rt);
    }
}

int main(){
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...