社区讨论

这代码哪里有UB?

学术版参与者 2已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@lo99nh7g
此快照首次捕获于
2023/10/28 07:50
2 年前
此快照最后确认于
2023/10/28 07:50
2 年前
查看原帖
CF788E
我同一份代码交了两次C++11WA的点不同,并且在本地测是对的,分别是#4 和 #9
而且同一份代码交C++11和C++14 WA的点不同
不知道哪里写出UB了,求帮忙看看
CPP
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+5,mod=1e9+7;
int n,m,a[N],b[N],L[N],R[N];
int kd[N],ans;
namespace Bit{
	int c[N];
	void add(int x,int k){
		for(;x<=m;x+=x&-x)c[x]+=k;
	}
	int ask(int x){
		int ret=0;
		for(;x;x-=x&-x)ret+=c[x];
		return ret;
	}
	void cl(){memset(c,0,sizeof(c));}
}
using Bit::add;
using Bit::ask;
int rt[N],d[N],ls[N],rs[N],pos[N];
int suml[N],sumr[N],sl[N],sr[N];
int siz[N],cnt;
int New(int x){
	int p=++cnt;
	d[p]=1LL*rand()*rand()%mod;
	pos[p]=x;
	return p;
}
void update(int p){
	if(!p)return;
	siz[p]=siz[ls[p]]+siz[rs[p]]+1; 
	suml[p]=(1LL*suml[ls[p]]+suml[rs[p]]+L[pos[p]])%mod;
	sumr[p]=(1LL*sumr[ls[p]]+sumr[rs[p]]+R[pos[p]])%mod;
	sr[p]=(1LL*(R[pos[p]]+sumr[rs[p]])*(siz[ls[p]]+1)%mod+sr[ls[p]]+sr[rs[p]])%mod;
	sl[p]=(1LL*(L[pos[p]]+suml[ls[p]])*(siz[rs[p]]+1)%mod+sl[ls[p]]+sl[rs[p]])%mod;
}
void split(int p,int k,int &x,int &y){
	if(!p){x=y=0;return;}
	if(pos[p]<=k){x=p;split(rs[p],k,rs[p],y);}
	else{y=p;split(ls[p],k,x,ls[p]);}
	update(p);
}
int merge(int x,int y){
	if(!x||!y)return x|y;
	if(d[x]>d[y]){rs[x]=merge(rs[x],y);update(x);return x;}
	else{ls[y]=merge(x,ls[y]);update(y);return y;} 
}
void ins(int x){
	int val=a[x],p,q;
	split(rt[val],x,p,q);
	(ans+=1LL*suml[p]*sumr[q]%mod)%=mod;
	(ans+=1LL*R[x]*(sl[p]-suml[p])%mod)%=mod;
	(ans+=1LL*L[x]*(sr[q]-sumr[q])%mod)%=mod;
	rt[val]=merge(merge(p,New(x)),q);
}
void del(int x){
	int val=a[x],p,q,h;
	split(rt[val],x,p,q);
	split(p,x-1,p,h);
	(ans-=1LL*suml[p]*sumr[q]%mod)%=mod;
	(ans-=1LL*R[x]*(sl[p]-suml[p])%mod)%=mod;
	(ans-=1LL*L[x]*(sr[q]-sumr[q])%mod)%=mod;
	rt[val]=merge(p,q);
}
int main(){
	srand(time(0));srand(rand());
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
		b[i]=a[i];
		kd[i]=1;
	}
	sort(b+1,b+n+1);
	m=unique(b+1,b+n+1)-b-1;
	for(int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+m+1,a[i])-b;
	for(int i=1;i<=n;i++)L[i]=ask(a[i]),add(a[i],1);
	Bit::cl();
	for(int i=n;i>=1;i--)R[i]=ask(a[i]),add(a[i],1);
	for(int i=1;i<=n;i++)ins(i);
	int T;scanf("%d",&T);
	while(T--){
		int op,x;
		scanf("%d%d",&op,&x);
		if(op==1&&kd[x]==1)del(x),kd[x]=0;
		if(op==2&&kd[x]==0)ins(x),kd[x]=1;
		printf("%d\n",(ans%mod+mod)%mod);
	}
	return 0;
}//

回复

1 条回复,欢迎继续交流。

正在加载回复...