专栏文章

【树状数组】学习笔记

算法·理论参与者 2已保存评论 2

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
2 条
当前快照
1 份
快照标识符
@mir0srpf
此快照首次捕获于
2025/12/04 13:53
3 个月前
此快照最后确认于
2025/12/04 13:53
3 个月前
查看原文
本文中,为了避免歧义,定义:
  • nn:数据数量。
  • a1ana_1 \sim a_n 原始数组。
  • b1bnb_1 \sim b_n 树状数组。

引入

给你一个数列 a1ana_1 \sim a_n,你需要实现两个函数:
  • 单点修改:将数列中的一个数值加 xx
  • 区间查询:求出序列中前几个数的和。
如果用暴力或者前缀和显然是不行的,考虑优化。

理论

树状数组的原理

对于这个数列:
11004466552214143344661313221199551212
可以把他相邻两个数求和,并归为新的一层,一直这样直到只剩下一个数字。
8787
35355252
1111242425252727
1110107717171010151510101717
11004466552214143344661313221199551212
他们的关系是这样的: 这样就可以用额外计算出的数来优化时间。
到这个时候,求区间的和操作就可以找一些上面的大数,再拿下面的小数凑整。
比如计算前 1313 个只需要这些标红数字即可:
8787
35\color{red}355252
1111242425\color{red}252727
1110107717171010151510101717
11004466552214143344661313221\color{red}{1}99551212
大大优化了运算速度。
这里注意到比如我想求前三个的和,那么第四行第二个数用不到,求前四个和时用第三行第一个更优,所以第四行第二个数没有任何用处。像这样无意义的数据还有很多,每行的第偶数个数据都没用,可以删掉。
8787
3535
11112525
117710101010
11445514144413131155
这时候,每一列恰好都只有一个数,我们把每个数取出来组成一个数组:
111144111155771414353544101013132525111010558787
这个数组就是树状数组,里面的每一个元素都对应着一个区间和。
求和时,只需要找到对应的区间相加求和即可。
修改时,只需要找到向上包含它的区间再修改。

lowbit\operatorname{lowbit} 函数

lowbit(x)\operatorname{lowbit}(x) 可以求出 xx 在二进制下最低位代表哪个数字。
比如二进制数字 1001010010010010100100(十进制 11881188),它的最低位是 1001010010010010100\color{red}{1}\color{black}{00},所代表的数就是 1001010010010010100\color{red}100。二进制数 100100 对应的十进制数就是 44,所以lowbit(1188)=4\operatorname{lowbit}(1188)=4
代码使用位运算来完成。
CPP
int lowbit(int x)
{
    return x&(-x);
}
证明很简单,自己按位与自己的反码,除了最低有效位其他都会直接抵消。

使用 lowbit\operatorname{lowbit} 实现树状数组

观察树状数组,最后一行的序列长度都为 11,而这些区间对应的树状数组序号的 lowbit\operatorname{lowbit} 也为 11。倒数第二行的序列长度为 22,他们对应序号的 lowbit\operatorname{lowbit} 也为 22
其他的几行也是如此,依次是 2,4,8,16,2,4,8,16,\cdots 依次是二的整数次幂。
比如 b14b_{14},它对应的序列长度就是 lowbit(14)=4\operatorname{lowbit}(14)=4。其他也是同理。
也就是说,bib_i 对应的序列就是长度为 lowbit(i)\operatorname{lowbit}(i) 且以 ii 结尾的序列。
这个时候,如果我们要求前 1414 个数的和,14lowbit(14)=1214-\operatorname{lowbit}(14)=12,那么,只需要计算 b14b_{14} 加上前十二个数的和就好了。计算前十二个数的和可以仿照同样的办法。
求解过程可记作:
sum(pos)={0pos0bpos+sum(poslowbit(pos))pos>0\operatorname{sum}(pos) = \begin{cases} 0 & pos \le 0\\ b_{pos}+ \operatorname{sum}(pos-\operatorname{lowbit}(pos))& pos>0\\ \end{cases}
也可以不用递归,不用递归的版本也很好写。
  • 递归版本
CPP
int sum(int pos)
{
    if(pos<=0)
	{
		return 0; 
	}
    return b[pos]+sum(pos-lowbit(pos));
}
  • 非递归版本
CPP
int sum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=t[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}

还有一个性质,就是 bib_i 正上方的序列刚好就是 bi+lowbit(i)b_{i+\operatorname{lowbit}(i)}
所以只要在修改的时候不断加上 lowbit(i)\operatorname{lowbit}(i) 就可以找到包含自己的所有序列进行修改。
CPP
void add(int pos,int x)//将第 pos 加上 x 并更新树状数组相关的元素
{
    while(pos<=n)
	{
        t[pos]+=x;
        pos+=lowbit(pos);
    }
}

例题

【单点修改】&【求区间和】

很板的树状数组,不妨在建树的时候输入一个 aa 把他当作单点修改操作。
CPP
#include<bits/stdc++.h>
using namespace std;
int n,m;
int t[1000005];
int lowbit(int x)
{
    return x&(-x);
}
void add(int pos,int x)
{
    while(pos<=n)
	{
        t[pos]+=x;
        pos+=lowbit(pos);
    }
}
int sum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=t[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++)
	{
		int x;
		cin>>x;
		add(i,x);
	}
	while(m--)
	{
		int q;
		cin>>q;
		if(q==1)
		{
			int x,k;
			cin>>x>>k;
			add(x,k);
		}
		else
		{
			int l,r;
			cin>>l>>r;
			cout<<sum(r)-sum(l-1)<<'\n';
		}
	}
	return 0;
}

\lfloor区间修改\rceil&\lfloor单点查询\rceil

可以考虑维护差分树状数组,利用差分思想来预处理出差分数组。
CPP
#include<bits/stdc++.h>
using namespace std;
int n,m;
int t[1000005];
int lowbit(int x)
{
    return x&(-x);
}
void add(int pos,int x)
{
    while(pos<=n)
	{
        t[pos]+=x;
        pos+=lowbit(pos);
    }
}
int sum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=t[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++)
	{
		int x;
		cin>>x;
		add(i,x);
		add(i+1,-x);
		/*
		可以理解为在 i~i 区间内加 x。
		*/
	}
	while(m--)
	{
		int q;
		cin>>q;
		if(q==1)
		{
			int x,y,k;
			cin>>x>>y>>k;
			add(x,k);
			add(y+1,-k);
		}
		else
		{
			int x;
			cin>>x;
			cout<<sum(x)<<'\n';
		}
	}
	return 0;
}

\lfloor区间修改\rceil&\lfloor求区间和\rceil

区间修改利用差分维护即可,重点看求区间和。
那么
i=1xai=i=11bi+i=12bi+i=13bi+i=1xbi=b1×x+b2×(x1)+b3×(x2)++bx×1=(x+1)i=1xdi1×d12×d2++x×dx=(x+1)i=1xdii=1x(i×di)\begin{aligned} \sum_{i=1}^{x} a_i &= \sum_{i=1}^{1} b_i + \sum_{i=1}^{2} b_i + \sum_{i=1}^{3} b_i + \cdots \sum_{i=1}^{x} b_i \\ &= b_1 \times x + b_2 \times (x-1) + b_3 \times (x-2) + \cdots + b_x \times 1 \\ &= (x+1) \sum_{i=1}^{x} d_i - 1 \times d_1 -2 \times d_2 + \cdots + x \times d_x \\ &= (x+1) \sum_{i=1}^{x} d_i - \sum_{i=1}^{x} (i \times d_i) \end{aligned}
我们给两个 \sum 都做一个树状数组就可以了。
CPP
#include<bits/stdc++.h>
using namespace std;
#define int long long					//开ll(偷懒写法
int n,m,a[1000005];						//要用数组输入来保存差分数组
int At[1000005];
int Bt[1000005];
int lowbit(int x)
{
    return x&(-x);
}
void Aadd(int pos,int x)
{
    while(pos<=n)
	{
        At[pos]+=x;
        pos+=lowbit(pos);
    }
}
int Asum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=At[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}

void Badd(int pos,int x)
{
    while(pos<=n)
	{
        Bt[pos]+=x;
        pos+=lowbit(pos);
    }
}
int Bsum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=Bt[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}
 /*
((y+1ll)*Asum(y)-Bsum(y))-((x+1ll)*Asum(x)-Bsum(x))
*/
int getSum(int p)
{
	return (p+1LL)*Asum(p)-Bsum(p);
}
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++)
	{
		cin>>a[i];
		Aadd(i,a[i]-a[i-1]);
		Badd(i,i*(a[i]-a[i-1]));
	}
	while(m--)
	{
		int q;
		cin>>q;
		if(q==1)
		{
			int x,y,k;
			cin>>x>>y>>k;
			Aadd(x,k);
			Aadd(y+1,-k);
			Badd(x,k*x);
			Badd(y+1,-(k*(y+1)));
		}
		else
		{
			int x,y;
			cin>>x>>y;
			cout<<getSum(y)-getSum(x-1)<<'\n';
		}
	}
	return 0;
}

\lfloor权值树状数组求逆序对\rceil

要离散化。
按价值从大到小排序,排完序之后用树状数组维护,每次把这个数的位置加入到树状数组中。之前加入的一定比后加入的大,然后在查询当前这个数前面位置的数(是前面位置的数,要当前这个数减1)。就是逆序对的个数了
求逆序对。设树状数组为 tt
检查多少组 ajai(j<i)a_{j} \sim a_i(j <i ) 逆序对。
检查 a1ai1a_1 \sim a_{i-1} 有几个大于 aia_i 的数。
检查 tai+1tnt_{a_i+1} \sim t_n 和为多少即可。
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
int ans=0;
struct node
{
	int x;//原数 
	int id;//在原数组里的编号 
	int t;//离散化之后的数字 
}a[500005];
bool cmp(node x,node y)
{
	return x.x<y.x;
}
bool cmp2(node x,node y)
{
	return x.id<y.id;
}
int t[500005];
int n;
int lowbit(int x)
{
    return x&(-x);
}
void add(int pos,int x)
{
    while(pos<=n)
	{
        t[pos]+=x;
        pos+=lowbit(pos);
    }
}
int sum(int pos)
{
    int cnt=0;
    while(pos>0)
	{
        cnt+=t[pos];
        pos-=lowbit(pos);
    }
    return cnt;
}


signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin>>n;
	for(int i=1;i<=n;i++)
	{
		cin>>a[i].x;
		a[i].id=i;
	}
	//-----------------------------抽象离散化 
	sort(a+1,a+1+n,cmp);
	int tot=1;
	for(int i=1;i<=n;)
	{
		int X=a[i].x;
		while(a[i].x==X)
		{
			a[i].t=tot;
			i++;
		}
		tot++;
	}
	sort(a+1,a+1+n,cmp2);
	//--------------------------- 
	for(int i=1;i<=n;i++)
	{
		int x=a[i].t;
		add(x,1);
		ans+=i-sum(x);
	}
	cout<<ans;
	return 0;
}

二维树状数组

可以维护二维数组。
一维树状数组套一维树状数组。
根一维很像,多了一个维度。比较麻烦的是区间求和,涉及了二维前缀和与二维差分。
单点修改:
CPP
void add(int x,int y,int k)
{
	for(int i=x;i<=n;i+=lowbit(i))
	{
		for(int j=y;j<=m;j+=lowbit(j))
		{
			t[i][j]+=k;
		}
	}
}
求区间和:
i=1xj=1yai,j\sum_{i=1}^{x} \sum_{j=1}^{y} a_{i,j} CPP
int sum(int x,int y)
{
	int cnt=0;
	for(int i=x;i>=1;i-=lowbit(i))
	{
		for(int j=y;j>=1;j-=lowbit(j))
		{
			cnt+=t[i][j];
		}
	}
	return cnt;
}

\lfloor二维单点修改\rceil&\lfloor二维区间求和\rceil

没有原题,所以先规定一个题面来避免歧义:problem
这里涉及了二维前缀和,求一个区间的和可以进行类似这样的操作:
先设原二维数组为 aa,设前缀和数组 sumsum
sumi,j=x=1iy=1jax,ysum_{i,j}= \sum_{x=1}^{i} \sum_{y=1}^j a_{x,y}
根据容斥原理就可以推出来求区间和的公式。
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,op;
int t[5000][5000];
int lowbit(int x)
{
	return x&-x;
}
void add(int x,int y,int k)
{
	for(int i=x;i<=n;i+=lowbit(i))
	{
		for(int j=y;j<=m;j+=lowbit(j))
		{
			t[i][j]+=k;
		}
	}
}
int sum(int x,int y)
{
	int cnt=0;
	for(int i=x;i>=1;i-=lowbit(i))
	{
		for(int j=y;j>=1;j-=lowbit(j))
		{
			cnt+=t[i][j];
		}
	}
	return cnt;
}
signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);
	cin>>n>>m;
	while(cin>>op)
	{
		if(op==1)
		{
			int x,y,k;
			cin>>x>>y>>k;
			add(x,y,k);
		}
		if(op==2)
		{
			int x,y,z,t;
			cin>>x>>y>>z>>t;
			cout<<sum(z,t)-sum(x-1,t)-sum(z,y-1)+sum(x-1,y-1)<<"\n";
		}
	}
	return 0;
}

评论

2 条评论,欢迎与作者交流。

正在加载评论...