社区讨论

求助站外题

题目总版参与者 2已保存回复 2

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
2 条
当前快照
1 份
快照标识符
@lo183m1s
此快照首次捕获于
2023/10/22 16:44
2 年前
此快照最后确认于
2023/11/02 16:34
2 年前
查看原帖
就是给定一棵带权树(节点数50000,边权<220)(节点数\leq 50000,边权<2^{20})
定义
f(u,v)f(u,v)为从uuvv的路经上边的 按位与 值,
g(u,v)g(u,v)为从uuvv的路经上边的 按位或 值
i=1nj=i+1nf(i,j)g(i,j)mod998244353\sum^{n}_{i=1}\sum^{n}_{j=i+1}f(i,j)g(i,j) \mod 998244353
我的想法是,根据乘法分配律,f(i,j)g(i,j)f(i,j)和g(i,j)每一位都可单独处理,即枚举x,yx,y,计算有多少对i,ji,j满足f(i,j)f(i,j)的第xx位和g(i,j)g(i,j)的第yy位为11
枚举了x,yx,y,那么可以通过将f(i,j)f(i,j)的第xx位为11的对数、和f(i,j)f(i,j)的第xx位为11g(i,j)g(i,j)的第y位为00的对数相减得到。对于前者,只留下第xx位为11的边,然后计算连通的点对数;对于后者,只留下第xx位为11且第yy位为00的边,然后计算连通的点对数。
我这样做TLE了3个点,有没有人能优化一下计算点对的速度
下面是这个菜鸡的代码
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int hd[50007],nxt[100007],to[100007],w[100007],cnt,sum;
void adde(int u,int v,int wg){
	nxt[++cnt]=hd[u];
	hd[u]=cnt;
	to[cnt]=v;
	w[cnt]=wg;
}
int n,ls[50007],ls2[50007],cnt2;
void dfs1(int u,int fa,int he){
	ls[u]=cnt2;
	ls2[cnt2]++;
	for(int i=hd[u];i;i=nxt[i]){
		if(to[i]==fa)
		{
			continue;
		}
		if((w[i]&he)==0)
		{
			continue;
		}
		dfs1(to[i],u,he);
	}
}
void dfs2(int u,int fa,int he,int hou){
	ls[u]=cnt2;
	ls2[cnt2]++;
	for(int i=hd[u];i;i=nxt[i]){
		if(to[i]==fa)
		{
			continue;
		}
		if((w[i]&he)==0)
		{
			continue;
		}
		if((w[i]&hou)!=0)
		{
			continue;
		}
		dfs2(to[i],u,he,hou);
	}
}
signed main(){
	cin>>n;	
	for(int i=1;i<n;i++){
		int p1,p2,p3;
		cin>>p1>>p2>>p3;
		adde(p1,p2,p3);
		adde(p2,p1,p3);
	}
	for(int i=0;i<20;i++){
		for(int j=0;j<20;j++){
			int p1=(1<<i),p2=(1<<j);
			cnt2=0;
			for(int k=1;k<=n;k++){
				if(!ls[k])
				{
					++cnt2;
					dfs1(k,0,p1);
				}
			}
			int s1=0;
			for(int i=1;i<=cnt2;i++){
				s1=(s1+((ls2[i]*(ls2[i]-1)/2)%mod))%mod;
			}
			for(int i=1;i<=n;i++){
				ls[i]=0;
			}
			for(int i=1;i<=cnt2;i++){
				ls2[i]=0;
			}
			cnt2=0;
			for(int k=1;k<=n;k++){
				if(!ls[k])
				{
					++cnt2;
					dfs2(k,0,p1,p2);
				}
			}
			int s2=0;
			for(int i=1;i<=cnt2;i++){
				s2=(s2+((ls2[i]*(ls2[i]-1)/2)%mod))%mod;
			}
			for(int i=1;i<=n;i++){
				ls[i]=0;
			}
			for(int i=1;i<=cnt2;i++){
				ls2[i]=0;
			}
			sum=(sum+((s1-s2)*((p1*p2)%mod)%mod))%mod;
		}
	}
	cout<<sum;
	return 0;
}

回复

2 条回复,欢迎继续交流。

正在加载回复...