专栏文章

题解:P8334 [ZJOI2022] 深搜

P8334题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mipzn1en
此快照首次捕获于
2025/12/03 20:33
3 个月前
此快照最后确认于
2025/12/03 20:33
3 个月前
查看原文
怎么大家都会拆贡献?来写一篇与拆贡献无关的做法。
若对每个点 uu 钦定其儿子访问顺序,我们发现 f(x,y)f(x,y) 等于 xyx\to y 路径上的最小点权或每次 uv(ysubtree(v))u\to v(y\in subtree(v)) 之前 uu 访问的其他儿子的子树最小值。即若干个子树 min\min 与单点点权取 min\min 的形式。
我们不妨对每个 yy 考虑其所有祖先对它的贡献。我们称上述的某个子树访问或单点访问为一个贡献点,钦定是哪个贡献点对 yy 产生了贡献,那么就要求所有比它小的贡献点均没有被取到,也就是对于某个点 uu 钦定了一些儿子的访问顺序在我们 xyx\to y 路径上的儿子的后面。不难发现若钦定了 ww 个儿子,则其概率为 1i=1w1i×(i+1)1-\sum_{i=1}^w \frac 1 {i\times (i+1)}。枚举 yyxx,按权值从小往大处理所有 xyx\to y 的贡献点权值,即可做到 O(n3)O(n^3)
考虑上述过程的优化:对于一个贡献点,每次加入新的贡献点导致其权值改变时,其变化量不与 xx 相关。于是我们可以按照 dfs 序处理 yy,每次加入一些贡献点,处理其与前面的贡献点的影响,简单统计答案即可。值得注意的是对于点 uu 的每个儿子 vv,贡献点存在一定差别,但若将 vv 按照其子树 min\min 排序后则邻项贡献点差别数量为 11。因此可以简单做到 O(n2)O(n^2)
放一份 2525 分的代码:
CPP
#include <bits/stdc++.h>
using namespace std;
const int N=4e5+5,mod=998244353;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*f;
}
inline int ksm(int a,int b){
	int ans=1;
	while(b){
		if(b&1)ans=1ll*ans*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}return ans;
}
int T,n,rt,a[N],mn[N],ans,cnt,inv[N],vvv[N],d[N];
vector<int>e[N];
inline void dfs(int x,int fa){
	mn[x]=a[x];
	if(fa)mn[x]=min(mn[x],a[fa]);
	for(int y : e[x]){
		if(y^fa)d[y]=d[x]+1,dfs(y,x),mn[x]=min(mn[x],mn[y]);
	}
}
struct node{
	int val,w,w2,op,d,id;
}tmp[N];
struct DATA{
	int val,d,id;
}p[N<<2];
inline bool cmp(int a,int b){
	return mn[a]<mn[b];
}
inline void insert(int val,int w,int w2,int d,int r,int id){
	for(int j = 1;j<=cnt;j++)if(tmp[j].val>val||tmp[j].val==val&&(tmp[j].d<d||tmp[j].d==d&&tmp[j].id>id)){
		if(w2==0)tmp[j].op++;
		else tmp[j].w=tmp[j].w*1ll*w2%mod;
	}
	tmp[++cnt]={val,0,w2,0,d,id};
	int ggg=d;
	for(int j = 1;j<cnt;j++)if((tmp[j].val<val||tmp[j].d==d&&tmp[j].id<id)&&!tmp[j].op)ggg=(ggg-tmp[j].w+mod)%mod;
	tmp[cnt].w=ggg*1ll*r%mod;
}
inline void del(int id){
	swap(tmp[id],tmp[cnt]);
	int w=ksm(tmp[cnt].w2,mod-2);
	for(int i = 1;i<cnt;i++)if(tmp[i].val>tmp[cnt].val||tmp[i].val==tmp[cnt].val&&(tmp[i].d<tmp[cnt].d||tmp[i].d==tmp[cnt].d&&tmp[i].id>tmp[cnt].id)){
		if(tmp[cnt].w2==0)tmp[i].op--;
		else tmp[i].w=1ll*tmp[i].w*1ll*w%mod;
	}
	cnt--;
}
inline void dfs2(int x,int fa){
	int gg=d[x];
	for(int i = 1;i<=cnt;i++){
		if(!tmp[i].op&&tmp[i].val<a[x])ans=(ans + tmp[i].val*1ll*tmp[i].w)%mod,gg=(gg+mod-tmp[i].w)%mod;
	}
	ans=(ans+1ll*a[x]*gg)%mod;
	vector<int>v;
	for(auto y : e[x])if(y!=fa)v.push_back(y);
	sort(v.begin(),v.end(),cmp);
	if(v.size())insert(a[x],inv[v.size()],0,d[x],1,v.size());
	for(int i = v.size()-1,y;i>=1;i--){
		y=v[i];
		insert(mn[y],inv[i]*1ll*inv[i+1]%mod,i*1ll*inv[i+1]%mod,d[x],inv[i+1],i);
	}
	int now = cnt;
	for(int i = 0,y;i<v.size();i++){
		y=v[i];
		dfs2(y,x);
		del(now--);
		insert(mn[y],inv[i+1]*1ll*inv[i+2]%mod,(i+1ll)*inv[i+2]%mod,d[x],inv[i+2],i);
	}
	for(int i = 0;i<v.size();i++)del(cnt);
}
int main(){
	T=read();
	while(T--){
		n=read(),rt=read();ans=0;
		inv[0]=inv[1]=1;
		cnt=0;
		for(int i = 2;i<=n+1;i++)inv[i]=inv[mod%i]*1ll*(mod-mod/i)%mod;
		for(int i = 1;i<=n;i++)e[i].clear(),a[i]=read();
		for(int i = 1,x,y;i<n;i++)x=read(),y=read(),e[x].push_back(y),e[y].push_back(x);
		d[rt]=1;dfs(rt,0);
		dfs2(rt,0);
		printf("%d\n",ans);
	}
	return 0;
}
不难发现,贡献点之间影响操作为区间乘,单点改,询问操作为区间和。不过需要维护乘 00 的撤销操作。我们可以暴力打标记记录哪些点被归零,并放入栈中等待撤销。由于线段树结点总访问次数为 O(nlogn)O(n\log n) 级别,因此撤销操作复杂度也为 O(nlogn)O(n\log n)
时空复杂度:O(nlogn)O(n\log n)
贴一份代码:
CPP
#include <bits/stdc++.h>
using namespace std;
const int N=4e5+5,mod=998244353;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*f;
}
inline int ksm(int a,int b){
	int ans=1;
	while(b){
		if(b&1)ans=1ll*ans*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}return ans;
}
inline int md(int x){
	return x>=mod?x-mod:x;
}
int T,n,rt,a[N],mn[N],ans,cnt,inv[N],vvv[N],d[N],num,q[N],C;
vector<int>e[N];
inline void dfs(int x,int fa){
	mn[x]=a[x];
	if(fa)mn[x]=min(mn[x],a[fa]);
	for(int y : e[x]){
		if(y^fa)d[y]=d[x]+1,dfs(y,x),mn[x]=min(mn[x],mn[y]);
	}
}
inline bool cmp(int a,int b){
	return mn[a]<mn[b];
}
struct node{
	int val,w,w2,op,d,id;
}tmp[N];
struct DATA{
	int val,d,id;
	inline bool operator<(DATA b){
		if(val^b.val)return val<b.val;
		if(d^b.d)return d>b.d;
		return id<b.id;
	}
}p[N<<1];
struct tree{
	int l,r,tag1,tag2,s1,s3;
}t[N<<3];
#define mid (t[p].l+t[p].r>>1)
inline void up(int p){
	t[p].s1=md(t[p<<1].s1+t[p<<1|1].s1),t[p].s3=md(t[p<<1].s3+t[p<<1|1].s3);
}
inline void build(int p,int l,int r){
	t[p]={l,r,0,1,0,0};
	if(l==r)return;
	build(p<<1,l,mid),build(p<<1|1,mid+1,r);
}
vector<pair<int,tree>>g[N];
inline void cg2(int p,int v);
inline void cg1(int p,int v){
	g[v].push_back({p,t[p]});
	t[p].tag1=v;t[p].tag2=1;t[p].s1=t[p].s3=0;
}
inline void cg2(int p,int v){
	t[p].tag2=1ll*v*t[p].tag2%mod,t[p].s1=1ll*v*t[p].s1%mod;t[p].s3=1ll*v*t[p].s3%mod;
}
inline void spread(int p,int op=1){
	if(t[p].tag1&&op){
		cg1(p<<1,t[p].tag1),cg1(p<<1|1,t[p].tag1);
		t[p].tag1=0;
	}
	if(t[p].tag2!=1){
		cg2(p<<1,t[p].tag2),cg2(p<<1|1,t[p].tag2);
		t[p].tag2=1;
	}
}
inline void change(int p,int l,int r,int v,int op){
	if(l<=t[p].l&&t[p].r<=r){
		if(!v)cg1(p,op);
		else cg2(p,v);
		return;
	}
	spread(p);
	if(!v)g[op].push_back({p,t[p]});
	if(l<=mid)change(p<<1,l,r,v,op);
	if(r>mid)change(p<<1|1,l,r,v,op);
	up(p);
}
inline void insert(int p,int pos,int val,int w){
	if(t[p].l==t[p].r){
		t[p].s1=w,t[p].s3=val*1ll*w%mod;
		return;
	}
	spread(p);
	if(pos<=mid)insert(p<<1,pos,val,w);
	else insert(p<<1|1,pos,val,w);
	up(p);
}
inline int query(int p,int l,int r,int op){
	if(l<=t[p].l&&t[p].r<=r){
		if(op==1)return t[p].s1;
		return t[p].s3;
	}
	spread(p);int ans=0;
	if(l<=mid)ans=query(p<<1,l,r,op);
	if(r>mid)ans=md(ans+query(p<<1|1,l,r,op));
	return ans;
}
inline void insert(int val,int w,int w2,int d,int r,int id){
	int pos = lower_bound(p+1,p+num+1,(DATA){val,d,id})-p;
	if(pos^num)change(1,pos+1,num,w2,d);
	tmp[++cnt]={val,0,w2,0,d,id};
	int ggg = d;
	if(pos>1)ggg=(ggg+mod-query(1,1,pos-1,1))%mod;
	insert(1,pos,val,ggg*1ll*r%mod);
}
inline void del(int id){
	swap(tmp[id],tmp[cnt]);
	int w = ksm(tmp[cnt].w2,mod-2);
	int pos = lower_bound(p+1,p+num+1,(DATA){tmp[cnt].val,tmp[cnt].d,tmp[cnt].id})-p;
	insert(1,pos,0,0);
	if(pos^num){
		if(w)change(1,pos+1,num,w,-1);
		else{
			for(auto[i,k]:g[tmp[cnt].d]){
				spread(i,0);
				t[i]=k;
			}
			g[tmp[cnt].d].clear();
		}
	}
	cnt--;
}
inline void dfs2(int x,int fa){
	int pos = lower_bound(p+1,p+num+1,(DATA){a[x],n+1,0})-p-1;
	if(pos>=1)ans=(ans+query(1,1,pos,2)+1ll*a[x]*(d[x]+mod-query(1,1,pos,1)))%mod;
	else ans=(ans+1ll*a[x]*d[x])%mod;
	vector<int>v;
	for(auto y : e[x])if(y!=fa)v.push_back(y);
	sort(v.begin(),v.end(),cmp);
	if(v.size())insert(a[x],inv[v.size()],0,d[x],1,v.size());
	for(int i = v.size()-1,y;i>=1;i--){
		y=v[i];
		insert(mn[y],inv[i]*1ll*inv[i+1]%mod,i*1ll*inv[i+1]%mod,d[x],inv[i+1],i);
	}
	int now = cnt;
	for(int i = 0,y;i<v.size();i++){
		y=v[i];
		dfs2(y,x);
		if(i!=v.size()-1){
			del(now--);
			insert(mn[y],inv[i+1]*1ll*inv[i+2]%mod,(i+1ll)*inv[i+2]%mod,d[x],inv[i+2],i);
		}
	}
	for(int i = 0;i<v.size();i++)del(cnt);
}
inline void predfs(int x,int fa){
	vector<int>v;
	for(auto y : e[x])if(y!=fa)v.push_back(y);
	sort(v.begin(),v.end(),cmp);
	if(v.size()){
		p[++num]=(DATA){a[x],d[x],(int)v.size()};
		for(int i = 0,y;i<v.size();i++){
			y=v[i];
			predfs(y,x);
			p[++num]=(DATA){mn[y],d[x],i};
		}
	}
}
int main(){
	T=read();
	while(T--){
		n=read(),rt=read();ans=0;
		inv[0]=inv[1]=1;
		cnt=0;num=0;
		for(int i = 2;i<=n+1;i++)inv[i]=inv[mod%i]*1ll*(mod-mod/i)%mod;
		for(int i = 1;i<=n;i++)e[i].clear(),a[i]=read();
		for(int i = 1,x,y;i<n;i++)x=read(),y=read(),e[x].push_back(y),e[y].push_back(x);
		d[rt]=1;dfs(rt,0);predfs(rt,0);
		sort(p+1,p+num+1);
		if(num)build(1,1,num);
		dfs2(rt,0);
		printf("%d\n",ans);
	}
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...