专栏文章

树状数组小记

算法·理论参与者 13已保存评论 12

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
11 条
当前快照
1 份
快照标识符
@mips7ttk
此快照首次捕获于
2025/12/03 17:05
3 个月前
此快照最后确认于
2025/12/03 17:05
3 个月前
查看原文

树状数组(Binary Index Tree)\large\color{00aacd}\textbf{树状数组(Binary Index Tree)}

发现自己之前零零碎碎地写过一些树状数组的应用,所以写这一篇文章来整合一下。
树状数组是一种支持单点修改,区间查询的精巧的数据结构,通常用于维护满足结合律可差分的运算和信息。又称二叉索引树(Binary Index Tree)、Fenwick Tree。

单点修改,区间查询\color{00cd00}\text{单点修改,区间查询}

下面这张图展示了树状数组的原理(来源:OI-Wiki)。
其中 cxc_x 表示以 xx 为右端点,长度为 lowbit(x){\rm lowbit}(x) 的区间的和。
lowbit(x){\rm lowbit}(x) 表示的是 xx 在二进制表示下,最低位的 11 的权值。
例如,1010 在二进制表示下为 101010\underset{\blacktriangle}{\bf1}0,加粗的就是最低位的 11,它的权值是 22,因此 lowbit(10)=2\rm lowbit(10)=2
再例如,2424 在二进制表示下为 110001\underset{\blacktriangle}{\bf1}000,最低位的 11 的权值为 88,因此 lowbit(24)=8\rm lowbit(24)=8
根据位运算知识,可以得到 lowbit(x) = x & -x,其中 &按位与运算。
如果一个数减去自己的 lowbit\rm lowbit,得到的数再减去自己的 lowbit\rm lowbit,不断重复,最终这个数一定会变成 00
例如 7(111) ⁣16(110) ⁣24(100) ⁣407(111)\overset{\!-1}{\longrightarrow}6(110)\overset{\!-2}{\longrightarrow}4(100)\overset{\!-4}{\longrightarrow}0
那么我们要计算 a17a_{1\dots7} 的和,就只需要求 c7+c6+c4c_7+c_6+c_4 即可。观察上图,看看是不是这样。
由此我们可以得到查询 a1xa_{1\dots x} 的代码:
CPP
int query(int x)
{
	int ans = 0;
	while(x > 0)
	{
		ans += c[x];
		x -= lowbit(x);
	}
	return ans;
}
可以发现,树状数组通过将一段数划分成 O(logn)O(\log n) 段数的和,从而能够实现高效的查询操作。
如果要求任意一段区间 alra_{l\dots r} 的和,可以借助前缀和的思想,用 a1ra_{1\dots r} 的和减去 a1l1a_{1\dots l-1} 的和,即 query(r) - query(l-1)。这也说明树状数组可以当成一个支持修改的前缀和来用。
如果要将 a5a_5 加上一个数 kk 该如何处理?观察包含 a5a_5 的区间,只有 c5c_5c6c_6c8c_8。那么就只需要将 c5c_5c6c_6c8c_8 都加上 kk 即可。而 6=5+lowbit(5)6=5+\rm lowbit(5)8=6+lowbit(6)8=6+\rm lowbit(6)。也就是说,在树状数组中,一个结点 xx 的父亲是 x+lowbit(x)x+{\rm lowbit}(x)。由此我们可以得到将 axa_x 加上 kk 的代码:
CPP
void update(int x, int k)
{
	while(x <= n)
	{
		c[x] += k;
		x += lowbit(x);
	}
}
显然,修改操作的时间复杂度也为 O(logn)O(\log n)
以下是一份经过封装的极简树状数组代码,可以通过本题。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{ //树状数组
	int c[N], lowbit(int x){return x & -x;}
	void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
	int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
long long n, m;
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n >> m;
	for(int i=1, x; i<=n; i++) cin >> x, t.update(i, x);
	while(m --> 0){
		int op, x, y; cin >> op >> x >> y;
		if(op == 1) t.update(x, y);
		if(op == 2) cout << t.query(y) - t.query(x - 1) << "\n";
	}
	return 0;
}

区间修改,单点查询\color{00cd00}\text{区间修改,单点查询}

借助差分的思想。定义差分数组 di=aiai1d_i = a_i - a_{i-1}。于是有 ax=i=1xdia_x = \sum\limits_{i=1}^x d_i。如果要在 alra_{l\dots r} 加上 kk,只需要让 dldl+k,dr+1dr+1kd_l\gets d_l+k,d_{r+1}\gets d_{r+1}-k。使用树状数组维护这一过程即可。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{ //树状数组维护差分数组
	int c[N], lowbit(int x){return x & -x;}
	void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
	int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
int n, m, a[N];
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n >> m;
	for(int i=1; i<=n; i++) cin >> a[i], t.update(i, a[i] - a[i-1]);
	while(m --> 0){
		int op, x, y, k; cin >> op;
		if(op == 1) cin >> x >> y >> k, t.update(x, k), t.update(y + 1, -k);
		if(op == 2) cin >> x, cout << t.query(x) << "\n";
	}
	return 0;
}

区间修改,区间查询\color{00cd00}\text{区间修改,区间查询}

我们已经会了使用差分数组实现区间修改。接下来只要考虑如何区间查询。因为 ai=j=1idja_i=\sum\limits_{j=1}^i d_j,所以 a1xa_{1\dots x} 的和,即 i=1xai\sum\limits_{i=1}^x a_i 就等于 i=1xj=1idj\sum\limits_{i=1}^x \sum\limits_{j=1}^i d_j。可以发现对于每一个 djd_j 一共加了 xj+1x-j+1 次。那么原式等于 j=1xdj×(xj+1)\sum\limits_{j=1}^{x} d_j\times (x-j+1),也就是 j=1xdj×(x+1)dj×j\sum\limits_{j=1}^x d_j\times (x+1)-d_j\times j。于是发现我们需要 22 个树状数组 c0,c1c_0,c_1,分别维护 djd_jdj×jd_j\times j
CPP
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e5 + 5;
struct BIT{
	int c[2][N], lowbit(int x){return x & -x;}
	void update(int x, int k){
		for(int i=x; i<N; i+=lowbit(i)){
			c[0][i] += k, c[1][i] += k * x;
		}
	}
	int query(int x){
		int ans = 0;
		for(int i=x; i>0; i-=lowbit(i)){
			ans += c[0][i] * (x + 1) - c[1][i];
		}
		return ans;
	}
} t;
int n, m, a[N];
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n >> m;
	for(int i=1; i<=n; i++) cin >> a[i], t.update(i, a[i] - a[i-1]);
	while(m --> 0){
		int op, x, y, k; cin >> op >> x >> y;
		if(op == 1) cin >> k, t.update(x, k), t.update(y + 1, -k);
		if(op == 2) cout << t.query(y) - t.query(x - 1) << "\n";
	}
	return 0;
}

逆序对\color{00cd00}\text{逆序对}

逆序对,就是在一个序列 a1na_{1\dots n} 中,满足 1ijn1\le i \le j \le nai>aja_i>a_j 的有序对。
可以用 cntxcnt_x 表示当前 xx 出现的数量。从后往前遍历每一个 aia_i,当前能与 aia_i 匹配的逆序对数量就是 j<aicntj\sum\limits_{j<a_i} cnt_j,即小于 aia_i 的数的数量。使用树状数组维护 cntcnt 数组即可。
你说 ai109a_i\le 10^9,数组开不了那么大?注意到逆序对只关心数的相对大小,所以可以将数据离散化,这样值域就降到了 O(n)O(n)。本文不详细展开。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
struct BIT{
	int c[N], lowbit(int x){return x & -x;}
	void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
	int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
} t;
long long n, m, ans;
int a[N], rk[N];
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n;
	for(int i=1; i<=n; i++) cin >> a[i], rk[i] = a[i];
	sort(rk + 1, rk + 1 + n); int len = unique(rk + 1, rk + 1 + n) - (rk + 1);
	for(int i=1; i<=n; i++) a[i] = lower_bound(rk + 1, rk + 1 + len, a[i]) - rk;
	for(int i=n; i>=1; i--) ans += t.query(a[i] - 1), t.update(a[i], 1);
	cout << ans;
	return 0;
}

二维树状数组\color{00cd00}\text{二维树状数组}

和一维的树状数组类似,我们设 cx,yc_{x,y} 为右下角为 (x,y)(x,y),向上高为 lowbit(x){\rm lowbit}(x),向左长为 lowbit(y){\rm lowbit}(y) 的矩阵的和。
单点修改时,也和一维一样,i,ji,j 不断加上自己的 lowbit\rm lowbit,修改所有包含 ax,ya_{x,y}ci,jc_{i,j} 即可。
区间查询时,i,ji,j 不断减去自己的 lowbit\rm lowbit,累加沿路上的 ci,jc_{i,j}
如果要查询任意一个矩阵 ax1,y1ax2,y2a_{x_1,y_1}\sim a_{x_2,y_2} 的和,可以用二维前缀和的方法,即 sum(x2,y2)sum(x2,y1 ⁣ ⁣1)sum(x1 ⁣ ⁣1,y2)+sum(x1 ⁣ ⁣1,y1 ⁣ ⁣1){\rm sum}(x2,y2)-{\rm sum}(x2,y1\!-\!1)-{\rm sum}(x1\!-\!1,y2)+{\rm sum}(x1\!-\!1,y1\!-\!1)
这题的值域只有 100100,因此我们只需要用 100100 个二维树状数组分别统计每种权值的数量即可。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 3e2 + 5;
struct BIT{ //二维树状数组
	int c[N][N], lowbit(int x){return x & -x;}
	void update(int x, int y, int k){
		for(int i=x; i<N; i+=lowbit(i)){
			for(int j=y; j<N; j+=lowbit(j)){
				c[i][j] += k;
			}
		}
	}
	int query(int x, int y){
		int ans = 0;
		for(int i=x; i; i-=lowbit(i)){
			for(int j=y; j; j-=lowbit(j)){
				ans += c[i][j];
			}
		}
		return ans;
	}
	int query(int x1, int y1, int x2, int y2){
		return query(x2, y2) - query(x2, y1-1) - query(x1-1, y2) + query(x1-1, y1-1);
	}
} t[101];
long long n, m, Q;
int a[N][N];
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n >> m;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			cin >> a[i][j];
			t[a[i][j]].update(i, j, 1);
		}
	}
	for(cin >> Q; Q --> 0;){
		int op, c; cin >> op;
		if(op == 1){
			int x, y; cin >> x >> y >> c;
			t[a[x][y]].update(x, y, -1);
			t[c].update(x, y, 1);
			a[x][y] = c;
		}
		if(op == 2){
			int x1, y1, x2, y2; cin >> x1 >> x2 >> y1 >> y2 >> c;
			cout << t[c].query(x1, y1, x2, y2) << "\n";
		}
	}
	return 0;
}
要实现矩阵修改,也和一维时的方法一样,在二维数组上差分,维护差分数组即可。二维的差分数组定义为 di,j=ai,jai1,jai,j1+ai1,j1d_{i,j} = a_{i,j} - a_{i-1,j} - a_{i, j-1} + a_{i-1, j-1},它满足 ax,y=i=1xj=1ydi,ja_{x,y} = \sum\limits_{i=1}^x\sum\limits_{j=1}^y d_{i,j}
如果要同时实现矩阵修改和矩阵查询,也可以先差分,然后推一下式子:
p=1xq=1yap,q=p=1xq=1yi=1pj=1qdi,j=i=1xj=1ydi,j×(xi+1)×(yj+1)=i=1xj=1ydi,j×(xy+x+y+1)di,j×i×(y+1)di,j×j×(x+1)+di,j×i×j\begin{aligned} &\sum_{p=1}^x\sum_{q=1}^y a_{p,q} \\ =&\sum_{p=1}^x\sum_{q=1}^y\sum_{i=1}^p\sum_{j=1}^q d_{i,j} \\ =&\sum_{i=1}^x\sum_{j=1}^y d_{i,j}\times (x-i+1) \times (y-j+1) \\ =&\sum_{i=1}^x\sum_{j=1}^y d_{i,j}\times(xy+x+y+1)-d_{i,j}\times i\times(y+1)-d_{i,j}\times j\times (x+1)+d_{i,j}\times i\times j \end{aligned}
于是我们用四个二维树状数组,分别维护 di,j,di,j×i,di,j×j,di,j×i×jd_{i,j},d_{i,j}\times i,d_{i,j}\times j,d_{i,j}\times i \times j 即可。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 3e3 + 5;
struct BIT{
	int c[4][N][N], lowbit(int x){return x & -x;}
	void update(int x, int y, int k){
		for(int i=x; i<N; i+=lowbit(i)){
			for(int j=y; j<N; j+=lowbit(j)){
				c[0][i][j] += k;
				c[1][i][j] += k * x;
				c[2][i][j] += k * y;
				c[3][i][j] += k * x * y;
			}
		}
	}
	int query(int x, int y){
		int ans = 0;
		for(int i=x; i; i-=lowbit(i)){
			for(int j=y; j; j-=lowbit(j)){
				ans += c[0][i][j] * (x + 1) * (y + 1)
				     - c[1][i][j] * (y + 1)
				     - c[2][i][j] * (x + 1)
				     + c[3][i][j];
			}
		}
		return ans;
	}
	void update(int x1, int y1, int x2, int y2, int k){
		update(x1, y1, k), update(x2+1, y1, -k), update(x1, y2+1, -k), update(x2+1, y2+1, k);
	}
	int query(int x1, int y1, int x2, int y2){
		return query(x2, y2) - query(x2, y1-1) - query(x1-1, y2) + query(x1-1, y1-1);
	}
} t;
long long n, m;
char op;
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> op >> n >> m;
	while(cin >> op){
		int x1, y1, x2, y2, k;
		cin >> x1 >> y1 >> x2 >> y2;
		if(op == 'L') cin >> k, t.update(x1, y1, x2, y2, k);
		if(op == 'k') cout << t.query(x1, y1, x2, y2) << "\n";
	}
	return 0;
}

权值树状数组\color{00cd00}\text{权值树状数组}

所谓权值数组,就是将权值作为下标,统计每种权值出现的次数。权值树状数组就是使用树状数组维护权值数组。我们在前面的“逆序对”一节中已经用到了权值树状数组,这里我们利用权值树状数组解决“查询全局第 kk 小值”问题。
我们需要实现以下操作:
  1. 在序列中加入一个数 xx
  2. 在序列中删除一个数 xx
  3. 查询序列中第 kk 小的数是多少。
对于操作 1/21/2,就是在权值数组中将 cntxcntx±1cnt_x \gets cnt_x\pm 1。对于操作 33,可以考虑二分 xx,用树状数组查询小于 xx 的数的数量,不断调整直到找到一个 x0x_0 满足 Sum(1,x0)<k\operatorname{Sum}(1,x_0)<kSum(1,x0+1)k\operatorname{Sum}(1,x_0+1)\ge k,此时 x0+1x_0+1 即为第 kk 小的数。
二分法的时间复杂度为 O(log2n)O(\log^2 n)。实际上,我们有 O(logn)O(\log n) 的方法解决这个问题。
把二分换成倍增。设 x=0x=0sum=0sum=0,枚举 i=log2n0i=\log_2n\to 0
  • s=Sum(x+1,x+2i)s=\operatorname{Sum}(x+1, x+2^i)
  • 如果 sum+s<ksum+s<k,就将 sumsum+ssum\gets sum+sxx+2ix\gets x+2^i
最终得到的 xx 是满足 Sum(1,x)<k\operatorname{Sum}(1,x)<k 的最大值,x+1x+1 即为第 kk 小的数。
根据树状数组的美好性质,查询 Sum(x+1,x+2i)\operatorname{Sum}(x+1,x+2^i) 只需要访问 cx+2ic_{x+2^i} 就行了,不需要 O(logn)O(\log n) 查询一遍。因此倍增法的时间复杂度仅为 O(logn)O(\log n)
CPP
int get_kth(int k){
	int sum = 0, x = 0;
	for(int i=20; i>=0; i--){
		x += 1 << i;
		if(x > N || sum + c[x] >= k) x -= 1 << i;
		else sum += c[x];
	}
	return x + 1;
}
想不到吧,树状数组还能当平衡树用。
这题比上面还多了查询 xx 的排名、前驱、后继的操作。对于求 xx 的排名,就是 Sum(x1)+1\operatorname{Sum}(x-1)+1。对于求 xx 的前驱/后继,都可以转化为求排名和查询第 kk 小的操作。
因为需要离散化,所以只能离线下来做。
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
struct BIT{
	int c[N], lowbit(int x){return x & -x;}
	void update(int x, int k){while(x < N) c[x] += k, x += lowbit(x);}
	int query(int x){int s = 0; while(x) s += c[x], x -= lowbit(x); return s;}
	int get_kth(int k){
		int sum = 0, x = 0;
		for(int i=20; i>=0; i--){
			x += 1 << i;
			if(x > N || sum + c[x] >= k) x -= 1 << i;
			else sum += c[x];
		}
		return x + 1;
	}
} t;
int n, m, rl[N];
pair<int, int> q[N];
signed main(){
	cin.tie(nullptr) -> sync_with_stdio(false);
	cin >> n;
	for(int i=1; i<=n; i++){
		auto &[opt, x] = q[i]; 
		cin >> opt >> x;
		if(opt != 4) rl[++m] = x;
	}
	sort(rl+1, rl+1+m), m = unique(rl+1, rl+1+m) - (rl+1);
	for(int i=1; i<=n; i++){
		auto [opt, x] = q[i];
		if(opt != 4) x = lower_bound(rl+1, rl+1+m, x) - rl;
		if(opt == 1) t.update(x, 1);
		if(opt == 2) t.update(x, -1);
		if(opt == 3) cout << t.query(x - 1) + 1 << "\n";
		if(opt == 4) cout << rl[t.get_kth(x)] << "\n";
		if(opt == 5) cout << rl[t.get_kth(t.query(x - 1))] << "\n";
		if(opt == 6) cout << rl[t.get_kth(t.query(x) + 1)] << "\n";
	}
	return 0;
}
可以发现,用权值树状数组实现普通平衡树的代码只有约 1KB\tt{1KB},效率也比一众平衡树高不少。

树状数组与 min/max\color{00cd00}\text{树状数组与 min/max}

需要注意的是,因为 min/max\min/\max 不满足可差分性,所以普通的树状数组不能用于解决 RMQ 问题。但是查询前缀 min/max\min/\max 是可以的,这可以用于一些 DP 的优化,如 P9097gcd\gcd 等满足结合律但不可差分的运算也是同理。
代码就是把普通树状数组里的 ++ 换成 min/max\min/\max。有时候可能需要写一个构造函数将 cc 初始化为无穷大/无穷小。
CPP
struct BIT{
	int c[N], lowbit(int x){return x & -x;}; BIT(){memset(c, 0x3f, sizeof(c));}
	void update(int x, int k){while(x < N) c[x] = min(c[x], k), x += lowbit(x);}
	int query(int x){int s = N; while(x) s = min(s, c[x]), x -= lowbit(x); return s;}
};
事实上,有一种 O(log2n)O(\log^2n) 的方法让树状数组维护不可差分信息,但是用处不大,本文不再赘述。

以上所有代码的树状数组都使用了结构体封装,大家可以直接拿来用 QwQ。

评论

12 条评论,欢迎与作者交流。

正在加载评论...