专栏文章

浅谈矩阵乘法在线段树标记下传的运用

算法·理论参与者 2已保存评论 1

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
1 条
当前快照
1 份
快照标识符
@mlmkiskj
此快照首次捕获于
2026/02/15 01:06
3 周前
此快照最后确认于
2026/03/09 01:27
前天
查看原文

前言

对于复杂的区间操作,我们自然会想到使用线段树。但是,这也意味着会出现复杂的标记下传与维护,许多初学者也因此开始打退堂鼓。本文将介绍另一种用矩阵乘法来解决复杂标记下传的问题。
前置知识:简单线段树、线性代数基础(矩阵乘法)。

一、多标记下传

(一)P3373 线段树 2

由于有区间乘法和区间加法两种操作,所以用普通的标记则需要分别维护乘法标记和加法标记,并且需要进行复杂的分类讨论。而使用矩阵乘法就不一样了,只需要一个标记,自然也没有了分类讨论。
我们让线段树的每个节点存储一个向量 [xlen]\begin{bmatrix}x\\len\end{bmatrix} ,其中 xx 表示当前的区间和,lenlen 表示区间长度。现对其进行加 kk ,也就是使这个向量变为 [x+len×klen]\begin{bmatrix}x+len\times k\\len\end{bmatrix}lenlen 不能变)。不难得到,这就是对原向量[1k01]\begin{bmatrix}1&k\\0&1\end{bmatrix} ,即 [1k01][xlen]=[x+len×klen]\begin{bmatrix}1&k\\0&1\end{bmatrix}\begin{bmatrix}x\\len\end{bmatrix}=\begin{bmatrix}x+len\times k\\len\end{bmatrix} 。同理可得,对区间乘 kk 就是对原向量[k001]\begin{bmatrix}k&0\\0&1\end{bmatrix}
注意
若想用右乘,只需将列向量变为行向量,再将原来用于左乘的矩阵进行转置后拿去右乘,但本文中的矩阵乘法均使用左乘。
此时,我们就将区间加与区间乘两个完全不同的操作,统一为了矩阵乘法。于是我们就只用维护一个矩阵乘法的标记。因为矩阵乘法满足结合律,所以可以放心打标记。
对于 pushup ,只需要对两个子节点的向量进行矩阵加法后放到父节点上就行了,这一点很容易理解。
参考代码CPP
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,mod=571373;
struct Matrix{
	int n,m,a[3][3];
	void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
	void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
	void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
	Matrix friend operator+(const Matrix&A,const Matrix&B){
		if(A.n!=B.n||A.m!=B.m)cout<<"Error:add",exit(0);
		Matrix C;C.n=A.n,C.m=A.m;
		for(int i=1;i<=A.n;i++)for(int j=1;j<=A.m;j++)
			C.a[i][j]=(A.a[i][j]+B.a[i][j])%mod;
		return C;
	}
	Matrix friend operator*(const Matrix&A,const Matrix&B){
		if(A.m!=B.n)cout<<"Error:mul",exit(0);
		Matrix C;C.init(A.n,B.m,0);
		for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
			for(int k=1;k<=A.m;k++)
				C.a[i][j]=(1ll*A.a[i][k]*B.a[k][j]%mod+C.a[i][j])%mod;
		return C;
	}
};
struct segment{
	#define mid (l+r>>1)
	int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
	void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
	void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
	void pushup(int u){tr[u]=tr[u<<1]+tr[u<<1|1];}
	void update(int u,int l,int r,int L,int R,int v,int op){
		if(L<=l&&r<=R){
			Matrix T;T.init(2,2,1);
			if(op==1)T.a[1][1]=v;
			else T.a[1][2]=v;
			return push(u,T);
		}
		pushdown(u);
		if(L<=mid)update(u<<1,l,mid,L,R,v,op);
		if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
		pushup(u);
	}void update(int L,int R,int v,int op){
		update(1,1,n,L,R,v,op);}
	Matrix query(int u,int l,int r,int L,int R){
		if(L<=l&&r<=R)return tr[u];
		pushdown(u);Matrix ans;ans.init(2,1,0);
		if(L<=mid)ans=ans+query(u<<1,l,mid,L,R);
		if(R>mid)ans=ans+query(u<<1|1,mid+1,r,L,R);
		return ans;
	}int query(int L,int R){
		return query(1,1,n,L,R).a[1][1];}
	void build(int u,int l,int r,int*a){
		tr[u].init(2,1,0),tag[u].init(2,2,1);
		if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
		build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
	}void build(int n_,int*a){n=n_,build(1,1,n,a);}
	#undef mid
}seg;
int n,q,m,a[N];
int main(){
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	cin>>n>>q>>m;
	for(int i=1;i<=n;i++)cin>>a[i];
	seg.build(n,a);
	int op,l,r,v;
	for(;q--;){
		cin>>op>>l>>r;
		if(op<3)cin>>v,seg.update(l,r,v,op);
		else cout<<seg.query(l,r)<<"\n";
	}
	return 0;
}

(二)P1253 扶苏的问题

对于这道题,我们让线段树的每个节点存储一个向量 [x1]\begin{bmatrix}x\\1\end{bmatrix} ,其中 xx 表示当前的区间最大值。根据上道题的思路,我们可以将区间覆盖操作变为左乘 [0k01]\begin{bmatrix}0&k\\0&1\end{bmatrix} ,区间加法变为左乘 [1k01]\begin{bmatrix}1&k\\0&1\end{bmatrix}(想想区间加为什么和上一题一样)。pushup 则取左右两个儿子中向量 xx 值更大的。所以只需要将上一题的代码改一下就是这题的代码了。
参考代码CPP
//因为常数问题,该代码仅能得到90tps
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5;
struct Matrix{
	int n,m;ll a[3][3];
	void clear(){for(int i=1;i<=n;i++)for(int j=1;j<=m;j++)a[i][j]=0;}
	void reset(){clear();for(int i=1;i<=min(n,m);i++)a[i][i]=1;}
	void init(int _n,int _m,int op){n=_n,m=_m;if(op==0)clear();else reset();}
	Matrix friend operator*(const Matrix&A,const Matrix&B){
		if(A.m!=B.n)cout<<"Error:mul",exit(0);
		Matrix C;C.init(A.n,B.m,0);
		for(int i=1;i<=A.n;i++)for(int j=1;j<=B.m;j++)
			for(int k=1;k<=A.m;k++)
				C.a[i][j]+=A.a[i][k]*B.a[k][j];
		return C;
	}
};
struct segment{
	#define mid (l+r>>1)
	int n;Matrix tr[(N<<2)+1],tag[(N<<2)+1];
	void push(int u,Matrix V){tr[u]=V*tr[u],tag[u]=V*tag[u];}
	void pushdown(int u){push(u<<1,tag[u]),push(u<<1|1,tag[u]),tag[u].init(2,2,1);}
	void pushup(int u){tr[u].a[1][1]=max(tr[u<<1].a[1][1],tr[u<<1|1].a[1][1]);}
	void update(int u,int l,int r,int L,int R,int v,int op){
		if(L<=l&&r<=R){
			Matrix T;T.init(2,2,1);
			T.a[1][1]=op-1,T.a[1][2]=v;
			return push(u,T);
		}
		pushdown(u);
		if(L<=mid)update(u<<1,l,mid,L,R,v,op);
		if(R>mid)update(u<<1|1,mid+1,r,L,R,v,op);
		pushup(u);
	}void update(int L,int R,int v,int op){update(1,1,n,L,R,v,op);}
	ll query(int u,int l,int r,int L,int R){
		if(L<=l&&r<=R)return tr[u].a[1][1];
		pushdown(u);ll ans=-1e18;
		if(L<=mid)ans=max(ans,query(u<<1,l,mid,L,R));
		if(R>mid)ans=max(ans,query(u<<1|1,mid+1,r,L,R));
		return ans;
	}ll query(int L,int R){return query(1,1,n,L,R);}
	void build(int u,int l,int r,int*a){
		tr[u].init(2,1,0),tag[u].init(2,2,1),tr[u].a[2][1]=1;
		if(l==r)return tr[u].a[1][1]=a[l],tr[u].a[2][1]=1,void();
		build(u<<1,l,mid,a),build(u<<1|1,mid+1,r,a),pushup(u);
	}void build(int n_,int*a){n=n_,build(1,1,n,a);}
	#undef mid
}seg;
int n,q,a[N];
int main(){
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	cin>>n>>q;
	for(int i=1;i<=n;i++)cin>>a[i];
	seg.build(n,a);
	int op,l,r,v;
	for(;q--;){
		cin>>op>>l>>r;
		if(op<3)cin>>v,seg.update(l,r,v,op);
		else cout<<seg.query(l,r)<<"\n";
	}
	return 0;
}

二、区间历史最值

前置知识1:区间历史最值的概念

历史最大值

简单地说,一个位置的历史最大值就是当前位置下曾经出现过的数的最大值。形式化地定义,我们定义一个辅助数组 BB ,一开始与 AA 完全相同。在 AA 的每次操作后,我们对整个数组取 max\maxi[1,n], Bi=max(Bi,Ai)\forall i\in[1,n],\ B_i=\max(B_i,A_i) 这时,我们将 BiB_i 称作这个位置的历史最大值。

历史最小值

定义与历史最大值类似,在 AA 的每次操作后,我们对整个数组取 min\min 。这时,我们将 BiB_i 称作这个位置的历史最小值。

历史版本和

辅助数组 BB 一开始全部是 00。在每一次操作后,我们把整个 AA 数组累加到 BB 数组上: i[1,n], Bi=Bi+Ai\forall i\in[1,n], \ B_i=B_i+A_i 我们称 BiB_iii 这个位置上的历史版本和。
※ 以上内容摘自 OI Wiki
前置知识2:广义矩阵乘

一、定义

设两个 n×nn\times n 的矩阵 A,BA,B ,定义广义矩阵乘 C=ABC=A\odot B(符号 \odot 表示广义乘),则其元素 Ci,jC_{i,j} 满足:Ci,j=k=1n(Ai,kBk,j)C_{i,j}=\bigoplus_{k=1}^{n}(A_{i,k}\otimes B_{k,j})
  • \oplus 是“广义加法”(比如 min\minmax\max、普通加法);
  • \otimes 是“广义乘法”(比如普通加法、普通乘法、减法);
  • 运算范围需一致(比如实数集、整数集、有限集合)。
在此基础上还需保证设计出来的广义矩阵乘满足结合律

二、判断结合律:44 条半环公理

广义矩阵乘 \odot 满足结合律的充要条件是:运算对 (,)(\oplus,\otimes) 满足以下 44 条公理(构成“半环”结构):
  1. \oplus 的交换律:对任意元素 a,ba,b ,有 ab=baa\oplus b=b\oplus a
  2. \oplus 的结合律:对任意元素 a,b,ca,b,c ,有 (ab)c=a(bc)(a\oplus b)\oplus c=a\oplus(b\oplus c)
  3. \otimes 的结合律:对任意元素 a,b,ca,b,c ,有 (ab)c=a(bc)(a\otimes b)\otimes c=a\otimes(b\otimes c)
  4. \otimes\oplus 的左右分配律:对任意元素 a,b,ca,b,c ,有:a(bc)=(ab)(ac)a\otimes (b\oplus c)=(a\otimes b)\oplus(a\otimes c)(ab)c=(ac)(bc)(a\oplus b)\otimes c=(a\otimes c)\oplus(b\otimes c)
我们知道这一类问题的标准解法为吉司机线段树,但吉司机线段树也需要复杂的标记下传。而这里将介绍如何用矩阵乘法来解决,或者说用矩阵乘法来理解吉司机线段树。

(一)区间历史最大/最小

先来看区间历史最大。这次我们让线段树的每个节点存储的向量为 [ab]\begin{bmatrix}a\\b\end{bmatrix} ,其中 aa 表示当前区间最大值,bb 表示区间历史最大值。并定义广义矩阵乘中广义加法为取 max\max ,广义乘法为普通加法。此时,对于区间加操作,有 [kk0][ab]=[a+kmax(b,a+k)]\begin{bmatrix}k&-\infty\\k&0\end{bmatrix}\begin{bmatrix}a\\b\end{bmatrix}=\begin{bmatrix}a+k\\\max(b,a+k)\end{bmatrix}(使用广义矩阵乘,下同)。若还有区间覆盖操作,就需要将向量增加一维到 [ab0]\begin{bmatrix}a\\b\\0\end{bmatrix} ,于是就有了 [k0k0][ab0]=[kmax(b,k)0]\begin{bmatrix}-\infty&-\infty&k\\-\infty&0&k\\-\infty&-\infty&0\end{bmatrix}\begin{bmatrix}a\\b\\0\end{bmatrix}=\begin{bmatrix}k\\\max(b,k)\\0\end{bmatrix}(此时的区间加请下来自己思考)。而 pushup ,也就变为了将左右儿子的向量用此时的广义加法(取 max\max )加起来再赋给父节点。
对于区间历史最小,就是将广义加法定义为取 min\min ,剩下的与区间历史最大一样。

(二)区间历史和

这里使用普通矩阵乘。让线段树的每个节点存储的向量为 [ablen]\begin{bmatrix}a\\b\\len\end{bmatrix} ,其中 aa 表示当前区间和,bb 表示区间历史和,lenlen 表示区间长度。对于区间加,有 [10k11k001][ablen]=[a+len×kb+a+len×klen]\begin{bmatrix}1&0&k\\1&1&k\\0&0&1\end{bmatrix}\begin{bmatrix}a\\b\\len\end{bmatrix}=\begin{bmatrix}a+len\times k\\b+a+len\times k\\len\end{bmatrix} ;对于区间覆盖,有 [00k01k001][ablen]=[len×kb+len×klen]\begin{bmatrix}0&0&k\\0&1&k\\0&0&1\end{bmatrix}\begin{bmatrix}a\\b\\len\end{bmatrix}=\begin{bmatrix}len\times k\\b+len\times k\\len\end{bmatrix}

总结

如你所见,矩阵乘法对线段树的标记下传的简化效果是很大的。但是,也会带来一些复杂度上的问题。假设矩阵大小为 w×ww\times w ,就会对时间复杂度增加一个大小为 w3w^3 的复杂度,对空间复杂度增加一个大小为 w/w2w/w^2 的复杂度。所以,虽然可以对线段树所维护的向量上灵活地增加 0/1/len0/1/len ,但最好不要让向量大小超过 44
最后,也希望这篇博客能为你带来帮助,感谢你的浏览!

参考资料

  1. OI-Wiki
  2. https://www.luogu.com.cn/article/ypgkm4vg
  3. https://www.luogu.com.cn/article/d04azg6j
  4. https://www.luogu.com.cn/article/tafs5gxk

评论

1 条评论,欢迎与作者交流。

正在加载评论...