专栏文章

线段树的一些应用

个人记录参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@min6zkbp
此快照首次捕获于
2025/12/01 21:35
3 个月前
此快照最后确认于
2025/12/01 21:35
3 个月前
查看原文
最近几十天总算把线段树的一些模板题吃透了。
【例 1】:序列,单点加 kk,区间求和。(洛谷 P3368)
代码:
CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=500009;
struct Segment{
	ll l,r,sum;
}tr[4*N];
ll x[N];
void build(ll u,ll l,ll r){
	ll mid=(l+r)/2;
	if(l==r){
		tr[u]=(Segment){l,r,x[l]};
		return;
	}
	build(u*2,l,mid);
	build(u*2+1,mid+1,r);
	tr[u]=(Segment){l,r,tr[u*2].sum+tr[u*2+1].sum};
}
void add(ll u,ll x,ll delta){
	tr[u].sum+=delta;
	if(tr[u].l==tr[u].r) return; 
	if(x<=tr[u*2].r) add(u*2,x,delta);
	else add(u*2+1,x,delta);
}
ll rsq(ll u,ll l,ll r){
	if(tr[u].l>=l&&tr[u].r<=r) return tr[u].sum;
	else if(tr[u].r<l||tr[u].l>r) return 0;
	else return rsq(u*2,l,r)+rsq(u*2+1,l,r);
}
int main(){
	ll n,m;
	cin>>n>>m;
	for(ll i=1;i<=n;i++) cin>>x[i];
	build(1,1,n);
	for(ll i=1;i<=m;i++){
		ll op,x,y;
		cin>>op>>x>>y;
		if(op==1) add(1,x,y);
		else cout<<rsq(1,x,y)<<endl;
	}
	return 0;
}
【例 2】序列,区间加 kk,区间求和。(洛谷 P3372)
只需简单 pushdown。代码:
CPP
#include<bits/stdc++.h>
using namespace std;
const int N=100010,M=4*N;
int n,m,a[N],b[M];
typedef long long ll;
ll sum[M];
void build(int k,int l,int r){
	if(l==r){
		sum[k]=a[l];
		return;
	}
	int mid=(l+r)>>1;
	build(k*2,l,mid);
	build(k<<1|1,mid+1,r);
	sum[k]=sum[k*2]+sum[k<<1|1];
}
void add(int k,int l,int r,int x){
	b[k]+=x;
	sum[k]+=(ll)x*(r-l+1);
}
void pushdown(int k,int l,int r,int mid){
	if(b[k]==0) return;
	add(k*2,l,mid,b[k]);
	add(k<<1|1,mid+1,r,b[k]);
	b[k]=0;
}
ll query(int k,int l,int r,int x,int y){
	if(l>=x && r<=y) return sum[k];
	int mid=(l+r)>>1;
	ll res=0;
	pushdown(k,l,r,mid);
	if(x<=mid) res+=query(k*2,l,mid,x,y);
	if(mid<y) res+=query(k<<1|1,mid+1,r,x,y);
	return res;
}
void modify(int k,int l,int r,int x,int y,int t){
	if(l>=x && r<=y) return add(k,l,r,t);
	int mid=(l+r)>>1;
	pushdown(k,l,r,mid);
	if(x<=mid) modify(k*2,l,mid,x,y,t);
	if(mid<y) modify(k<<1|1,mid+1,r,x,y,t);
	sum[k]=sum[k*2]+sum[k<<1|1];
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	build(1,1,n);
	while(m--){
		int op,x,y,k;
		scanf("%d%d%d",&op,&x,&y);
		if(op==1){
			scanf("%d",&k);
			modify(1,1,n,x,y,k);
		}
		else printf("%lld\n",query(1,1,n,x,y));
	}
	return 0;
}
【例 3】序列,区间加乘 kk,区间求和。
注意 pushdown 时先乘后加。
CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=800009;
ll n,m,MOD;
struct Segment{
	ll l;
	ll r;
	ll len;
	ll sum;
	ll toMul;
	ll toAdd;
};
Segment tr[N*4];
int x[N];
void build(ll u,ll l,ll r){
	if(l==r){
		tr[u]=(Segment){l,r,r-l+1,x[l],1,0};
		return;
	}
	ll m=(l+r)/2; 
	build(u*2,l,m);
	build(u*2+1,m+1,r);
	tr[u]=(Segment){l,r,r-l+1,tr[u*2].sum+tr[u*2+1].sum,1,0};
}
void pushdown(ll u){
	if(tr[u].l==tr[u].r) return;
	ll M=tr[u].toMul;
	ll A=tr[u].toAdd;
	tr[u].toMul=1;
	tr[u].toAdd=0;
	(tr[u*2].toMul*=M)%=MOD;
	((tr[u*2].toAdd*=M)+=A)%=MOD;
	((tr[u*2].sum*=M)+=tr[u*2].len*A)%=MOD;
	(tr[u*2+1].toMul*=M)%=MOD;
	((tr[u*2+1].toAdd*=M)+=A)%=MOD;
	((tr[u*2+1].sum*=M)+=tr[u*2+1].len*A)%=MOD;
}
void add(ll u,ll &l,ll &r,ll &delta){
	pushdown(u);
	if(r<tr[u].l || tr[u].r<l) return;
	if(l<=tr[u].l && tr[u].r<=r){
		(tr[u].toAdd+=delta)%=MOD;
		(tr[u].sum+=tr[u].len*delta)%=MOD;
		return;
	}
	add(u*2,l,r,delta);
	add(u*2+1,l,r,delta);
	tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}
void mul(ll u,ll &l,ll &r,ll &delta){
	pushdown(u);
	if(r<tr[u].l || tr[u].r<l) return;
	if(l<=tr[u].l && tr[u].r<=r){
		(tr[u].toMul*=delta)%=MOD;
		(tr[u].toAdd*=delta)%=MOD;
		(tr[u].sum*=delta)%=MOD;
		return;
	}
	mul(u*2,l,r,delta);
	mul(u*2+1,l,r,delta);
	tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}
ll rsq(ll u,ll &l,ll &r){
	pushdown(u);
	if(r<tr[u].l || tr[u].r<l) return 0;
	else if(l<=tr[u].l && tr[u].r<=r) return tr[u].sum;
	return (rsq(u*2,l,r)+rsq(u*2+1,l,r))%MOD;
}
int main(){
	scanf("%lld %lld %lld",&n,&m,&MOD);
	for(int i=1;i<=n;i++) scanf("%d",&x[i]);
	build(1,1,n);
	for(int i=1;i<=m;i++){
		ll t,x,y,z;
		scanf("%lld %lld %lld",&t,&x,&y);
		if(t==3) printf("%lld\n",rsq(1,x,y));
		else if(t==2){
			scanf("%lld",&z);
			add(1,x,y,z);
		}
		else{
			scanf("%lld",&z);
			mul(1,x,y,z);
		}
	}
	return 0;
}
【例 4】区间加,求区间平均值。(洛谷 P14233)
注意这题需要将时间哈希。区间平均值 = 区间和 / 区间长度。
代码:
CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=300009,INF=2e9;
ll n,m,MOD;
struct Segment{
	ll l;
	ll r;
	ll len;
	ll sum;
	ll toAdd;
};
Segment tr[N*4];
void build(ll u,ll l,ll r){
	if(l==r){
		tr[u]=(Segment){l,r,r-l+1,0,0};
		return;
	}
	ll m=(l+r)/2; 
	build(u*2,l,m);
	build(u*2+1,m+1,r);
	tr[u]=(Segment){l,r,r-l+1,tr[u*2].sum+tr[u*2+1].sum,0};
}
void pushdown(ll u){
	if(tr[u].l==tr[u].r) return;
	ll A=tr[u].toAdd;
	tr[u].toAdd=0;
	tr[u*2].sum+=tr[u*2].len*A;
	tr[u*2+1].sum+=tr[u*2+1].len*A;
	tr[u*2].toAdd+=A;
	tr[u*2+1].toAdd+=A;
}
void add(ll u,ll l,ll r,ll delta){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return;
	if(l<=tr[u].l&&tr[u].r<=r){
		tr[u].toAdd+=delta;
		tr[u].sum+=tr[u].len*delta;
		return;
	}
	add(u*2,l,r,delta);
	add(u*2+1,l,r,delta);
	tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
}
ll rsq(ll u,ll l,ll r){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return 0;
	else if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
	return rsq(u*2,l,r)+rsq(u*2+1,l,r);
}
int main(){
	ll n;
	cin>>n;
	build(1,1,86400);
	for(ll i=1;i<=n;i++){
		ll x,y,z,a,b,c;
		char op;
		cin>>x>>op>>y>>op>>z>>op>>a>>op>>b>>op>>c;
		ll l=x*3600+y*60+z;
		ll r=a*3600+b*60+c;
		l++;r++;
		if(l>r){
			add(1,l,86400,1);
			add(1,1,r,1);
		}
		else add(1,l,r,1); 
	}
	ll q;
	cin>>q;
	for(ll i=1;i<=q;i++){
		ll x,y,z,a,b,c;
		char op;
		cin>>x>>op>>y>>op>>z>>op>>a>>op>>b>>op>>c;
		ll l=x*3600+y*60+z;
		ll r=a*3600+b*60+c;
		l++;r++;
		if(l>r)
			cout<<fixed<<setprecision(10)<<((double)rsq(1,l,86400)+rsq(1,1,r))/(r+86400-l+1)<<endl;
		else
			cout<<fixed<<setprecision(10)<<((double)rsq(1,l,r))/(r-l+1)<<endl;
	}
	return 0;
}
【例 5】区间加,区间 sin 和。(洛谷 P6327)
注意到 sin(α+β)=sinαcosβ+sinβcosα\sin(\alpha+\beta)=\sin \alpha \cos\beta+\sin\beta\cos\alphacos(α+β)=cosαcosβsinαsinβ\cos(\alpha+\beta)=\cos \alpha \cos \beta-\sin \alpha \sin \beta,pushdown 显然。
CPP
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=300009;
int x[N],n;
struct Segment{
	int l;
	int r;
	double sumsin;
	double sumcos;
	int toAdd;
}tr[N*4];
void build(int u,int l,int r){
	if(l==r){
		tr[u]=(Segment){l,r,sin(x[l]),cos(x[l]),0};
		return;
	}
	int mid=(l+r)/2;
	build(u*2,l,mid);
	build(u*2+1,mid+1,r);
	tr[u]=(Segment){l,r,tr[u*2].sumsin+tr[u*2+1].sumsin,tr[u*2].sumcos+tr[u*2+1].sumcos,0};
}
void sincosadd(int u,double sina,double cosa){
	double x=tr[u].sumsin,y=tr[u].sumcos;
	tr[u].sumsin=x*cosa+y*sina;
	tr[u].sumcos=y*cosa-x*sina;
}
void pushdown(int u){
	int A=tr[u].toAdd;
	tr[u].toAdd=0;
	if(!A) return;
	double sinx=sin(A),cosx=cos(A);
	sincosadd(u*2,sinx,cosx);
	sincosadd(u*2+1,sinx,cosx);
	tr[u*2].toAdd+=A;
	tr[u*2+1].toAdd+=A;
}
void add(int u,int l,int r,int delta){
	pushdown(u);
	if(tr[u].r<l||r<tr[u].l) return;
	if(l<=tr[u].l&&tr[u].r<=r){
		sincosadd(u,sin(delta),cos(delta));
		tr[u].toAdd+=delta;
		return;
	}
	add(u*2,l,r,delta);
	add(u*2+1,l,r,delta);
	tr[u].sumsin=tr[u*2].sumsin+tr[u*2+1].sumsin;
	tr[u].sumcos=tr[u*2].sumcos+tr[u*2+1].sumcos;
}
double rsq(int u,int l,int r){
	pushdown(u);
	if(tr[u].r<l||r<tr[u].l) return 0;
	if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sumsin;
	return rsq(u*2,l,r)+rsq(u*2+1,l,r);
}
signed main(){
	ios::sync_with_stdio(0);
	cin>>n;
	for(int i=1;i<=n;i++) cin>>x[i];
	build(1,1,n);
	int m;
	cin>>m;
	for(int i=1;i<=m;i++){
		int op,l,r;
		cin>>op>>l>>r;
		if(op==2) cout<<fixed<<setprecision(1)<<rsq(1,l,r)<<'\n';
		else{
			int v;
			cin>>v;
			add(1,l,r,v);
		}
	}
	return 0;
}
【例 6】区间加,区间方差。(洛谷 P1471)
简单的题目,平均值极其简单,方差 = 平方和/区间长度-平均值*平均值。
CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll MOD=1e9+7;
const ll N=300009;
struct Segment{
	ll l;
	ll r;
	ll len;
	double sum;
	double sqsum;
	double toAdd;
}tr[N*4];
double a[N];
ll n,m;
void build(ll u,ll l,ll r){
	if(l==r){
		tr[u]=(Segment){l,r,r-l+1,a[l],a[l]*a[l],0};
		return;
	}
	ll mid=(l+r)/2;
	build(u*2,l,mid);
	build(u*2+1,mid+1,r);
	tr[u]=(Segment){l,r,r-l+1,tr[u*2].sum+tr[u*2+1].sum,tr[u*2].sqsum+tr[u*2+1].sqsum,0};
}
void pushdown(ll u){
	double A=tr[u].toAdd;
	if(!A) return;
	tr[u].toAdd=0;
	tr[u*2].sqsum+=2*A*tr[u*2].sum+tr[u*2].len*A*A;
	tr[u*2+1].sqsum+=2*A*tr[u*2+1].sum+tr[u*2+1].len*A*A;
	tr[u*2].sum+=tr[u*2].len*A;
	tr[u*2+1].sum+=tr[u*2+1].len*A;
	tr[u*2].toAdd+=A;
	tr[u*2+1].toAdd+=A;
}
void add(ll u,ll l,ll r,double delta){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return;
	if(l<=tr[u].l&&tr[u].r<=r){
		tr[u].sqsum+=2*delta*tr[u].sum+delta*delta*tr[u].len;
		tr[u].sum+=tr[u].len*delta;
		tr[u].toAdd+=delta;
		return;
	}
	add(u*2,l,r,delta);
	add(u*2+1,l,r,delta);
	tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
	tr[u].sqsum=tr[u*2].sqsum+tr[u*2+1].sqsum;
}
double rsq1(ll u,ll l,ll r){
	pushdown(u);
	if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
	else if(tr[u].r<l||r<tr[u].l) return 0;
	else return rsq1(u*2,l,r)+rsq1(u*2+1,l,r);
}
double rsq2(ll u,ll l,ll r){
	pushdown(u);
	if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sqsum;
	else if(tr[u].r<l||r<tr[u].l) return 0;
	else return rsq2(u*2,l,r)+rsq2(u*2+1,l,r);
}
int main(){
	cin>>n>>m;
	for(ll i=1;i<=n;i++) cin>>a[i];
	build(1,1,n);
	for(ll i=1;i<=m;i++){
		ll op,l,r;
		cin>>op>>l>>r;
		if(op==1){
			double delta;
			cin>>delta;
			add(1,l,r,delta);
		}
		else if(op==2) cout<<fixed<<setprecision(4)<<rsq1(1,l,r)/(r-l+1)<<endl;
		else{
			double ave=rsq1(1,l,r)/(r-l+1);
			cout<<fixed<<setprecision(4)<<((rsq2(1,l,r)/(r-l+1))-ave*ave)<<endl;
		}
	}
	return 0;
}
【例 7】区间加,区间求和,区间求 min/max。(洛谷 P3130)
pushdown 在例 2 基础上,min 值就加上懒标记即可。
CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=800009,INF=2e9;
ll n,m,MOD;
struct Segment{
	ll l;
	ll r;
	ll len;
	ll sum;
	ll mn;
	ll toAdd;
};
Segment tr[N*4];
ll x[N];
void build(ll u,ll l,ll r){
	if(l==r){
		tr[u]=(Segment){l,r,r-l+1,x[l],x[l],0};
		return;
	}
	ll m=(l+r)/2; 
	build(u*2,l,m);
	build(u*2+1,m+1,r);
	tr[u]=(Segment){l,r,r-l+1,tr[u*2].sum+tr[u*2+1].sum,min(tr[u*2].mn,tr[u*2+1].mn),0};
}
void pushdown(ll u){
	if(tr[u].l==tr[u].r) return;
	ll A=tr[u].toAdd;
	tr[u].toAdd=0;
	tr[u*2].sum+=tr[u*2].len*A;
	tr[u*2+1].sum+=tr[u*2+1].len*A;
	tr[u*2].mn+=A;
	tr[u*2+1].mn+=A;
	tr[u*2].toAdd+=A;
	tr[u*2+1].toAdd+=A;
}
void add(ll u,ll&l,ll&r,ll&delta){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return;
	if(l<=tr[u].l&&tr[u].r<=r){
		tr[u].toAdd+=delta;
		tr[u].sum+=tr[u].len*delta;
		tr[u].mn+=delta;
		return;
	}
	add(u*2,l,r,delta);
	add(u*2+1,l,r,delta);
	tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
	tr[u].mn=min(tr[u*2].mn,tr[u*2+1].mn);
}
ll rsq(ll u,ll&l,ll&r){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return 0;
	else if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
	return rsq(u*2,l,r)+rsq(u*2+1,l,r);
}
ll rmq(ll u,ll&l,ll&r){
	pushdown(u);
	if(r<tr[u].l||tr[u].r<l) return INF;
	else if(l<=tr[u].l&&tr[u].r<=r) return tr[u].mn;
	return min(rmq(u*2,l,r),rmq(u*2+1,l,r));
}
int main(){
	scanf("%lld %lld",&n,&m);
	for(ll i=1;i<=n;i++) scanf("%lld",&x[i]);
	build(1,1,n);
	for(ll i=1;i<=m;i++){
		char t;
		ll x,y,z;
		cin>>t>>x>>y; 
		if(t=='M') cout<<rmq(1,x,y)<<endl;
		else if(t=='S') cout<<rsq(1,x,y)<<endl;
		else{
			cin>>z;
			add(1,x,y,z);
		}
	}
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...