社区讨论

不明白为什么我的线段树得开八倍空间

P3384【模板】重链剖分 / 树链剖分参与者 5已保存回复 8

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
8 条
当前快照
1 份
快照标识符
@mjgivzi1
此快照首次捕获于
2025/12/22 10:14
3 个月前
此快照最后确认于
2025/12/24 20:55
2 个月前
查看原帖
想不明白,正常线段树写法是tag记录子节点修改,我写的tag记录当前节点修改,同学说我写法很石,但是实测能过,就是必须得开八倍空间(因为学完线段树写的都是只用查询或者只用修改的,写法一直不规范
我写的pushup会往下跑两次,但叶子节点不pushup的断言(算是吧)过了,最多也就次底层往下跑2层,理论上4倍空间足矣啊(因为线段树板子直接40倍空间没发现问题,这里只开4倍wa了,才发现有问题
代码:
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;

int n,m,r,q,idx;
int a[100005],fa[100005],de[100005],siz[100005],dfn[100005],top[100005],rnk[100005],sun[100005],bot[100005],tag[800005],ans[800005];
vector<int> e[100005];

void pushdown(int u,int l,int r) {
	ans[u] = ((r-l+1)*tag[u]+ans[u])%q;
	tag[u<<1] = (tag[u<<1]+tag[u])%q;
	tag[u<<1|1] = (tag[u<<1|1]+tag[u])%q;
	tag[u] = 0;
  //pushdown只是*2
}
void pushup(int u,int l,int r) {
	int mid=(l+r)>>1;
	if(tag[u<<1])
		pushdown(u<<1,l,mid);
	if(tag[u<<1|1])
		pushdown(u<<1|1,mid+1,r);
	ans[u] = (ans[u<<1]+ans[u<<1|1])%q;
  //pushup虽然*4但叶子节点不可能触发
}
void build(int u,int l,int r) {
	if(l == r) {
		ans[u] = a[rnk[l]]%q;
		return;
	}
	int mid=(l+r)>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	ans[u] = (ans[u<<1]+ans[u<<1|1])%q;
}
int querry(int u,int l,int r,int x,int y) {
	if(tag[u])
		pushdown(u,l,r);
	if(x<=l&&r<=y)
		return ans[u];
	int mid=(l+r)>>1,sum=0;
	if(x<=mid)
		sum += querry(u<<1,l,mid,x,y);
	if(mid<y)
		sum += querry(u<<1|1,mid+1,r,x,y);
	return sum%q;
}
void update(int u,int l,int r,int x,int y,int num) {
	if(tag[u])
		pushdown(u,l,r);
	if(x<=l&&r<=y) {
		tag[u] = num;
		pushdown(u,l,r);
		return;
	}
	int mid=(l+r)>>1;
	if(x<=mid)
		update(u<<1,l,mid,x,y,num);
	if(mid<y)
		update(u<<1|1,mid+1,r,x,y,num);
    if(l == r)
        exit(0);//这么写交上去过了
	pushup(u,l,r);
}
void dfs1(int u,int f) {
	fa[u]=f,de[u]=de[f]+1,siz[u]=1;
	for(int v:e[u])
		if(v!=f) {
			dfs1(v,u);
			siz[u] += siz[v];
			if(siz[v] > siz[sun[u]])
				sun[u] = v;
		}
}
void dfs2(int u,int ftop) {
	top[u]=ftop,dfn[u]=++idx,rnk[idx]=u;
	if(sun[u])
		dfs2(sun[u],ftop); 
	for(int v:e[u])
		if(v!=fa[u]&&v!=sun[u])
			dfs2(v,v);
	bot[u]=idx;
}
int pathsum(int u,int v) {
	int sum=0;
	while(top[u]!=top[v]) {
		if(de[top[u]]<de[top[v]])
			swap(u,v);
		sum = (sum+querry(1,1,n,dfn[top[u]],dfn[u]))%q;
		u = fa[top[u]];
	}
	if(dfn[u]<dfn[v])
		swap(u,v);
	sum = (sum+querry(1,1,n,dfn[v],dfn[u]))%q;
	return sum;
}
void pathadd(int u,int v,int num) {
	while(top[u]!=top[v]) {
		if(de[top[u]]<de[top[v]])
			swap(u,v);
		update(1,1,n,dfn[top[u]],dfn[u],num);
		u = fa[top[u]];
	}
	if(de[u]<de[v])
		swap(u,v);
	update(1,1,n,dfn[v],dfn[u],num);
	return;
}

signed main() {
	int x,y,c,z;
	cin>>n>>m>>r>>q;
	for(int i=1;i<=n;i++)
		cin>>a[i];
	for(int i=1;i<n;i++) {
		cin>>x>>y;
		e[x].push_back(y);
		e[y].push_back(x);
	}
	dfs1(r,0);
	dfs2(r,r);
	build(1,1,n);
	for(int i=1;i<=m;i++) {
		cin>>c;
		if(c==1) {
			cin>>x>>y>>z;
			z=z%q;
			pathadd(x,y,z);
		}
		if(c==2) {
			cin>>x>>y;
			cout<<pathsum(x,y)<<endl;
		}
		if(c==3) {
			cin>>x>>z;
			z=z%q;
			update(1,1,n,dfn[x],bot[x],z);
		}
		if(c==4) {
			cin>>x;
			cout<<querry(1,1,n,dfn[x],bot[x])<<endl;
		}
	}
	return 0;
}

回复

8 条回复,欢迎继续交流。

正在加载回复...