Description
- 给定一棵 n 个节点的树,每个节点有 qipi 的概率坍塌(指切断所有邻边)。
- 求期望的无序叶子节点对数。(其中叶子节点指恰有一条邻边的节点)
- 答案对 998244353 取模。
- 多测。
Solution
记
N(k) 为
k 在树上的相邻节点集,
dis(i,j) 为树上
i 与
j 路径的边数,简记
pi 为
i 号节点坍塌的概率。
首先考虑每个节点成为叶子节点的概率,由题可写出
vi=(1−pi)j∈N(i)∑(1−pj)k∈N(i)&k=j∏pk=(1−pi)(k∈N(i)∏pk)(j∈N(i)∑pj1−pj)。
Warning
期望点对数不可直接由期望点数求出。
Node=∑piai
Pair=∑pi2ai2−ai=2Node2−Node
因此我们只可直接考虑点对。
考虑到
dis(i,j)≥3 时,两点各自为叶子节点的两事件独立,而
dis(i,j)=1 和
dis(i,j)=2 的情况下,由于存在邻居节点重合事件不独立,需要另外考虑。
因此做出如下分讨:
dis(i,j)≥3
由于独立,贡献为
vi×vj。
dis(i,j)=1
显然
i 和
j 都不能坍塌,则若
(i,j) 要成为叶子点对,其连边则为唯一边,也即
i 和
j 的其余所有邻居节点坍塌。
贡献为
(1−pi)(1−pj)k∈N(i)∪N(j)&k=i,j∏pk。
dis(i,j)=2
考虑
i 与
j 路径上的唯一点
k。
若
k 不坍塌,则
(i,j) 要成为叶子点对,则除
k 以外的
i 与
j 的邻居节点全部坍塌。
此处贡献为
(1−pi)(1−pj)(1−pk)s∈N(i)∪N(j)&s=k∏ps。
若
k 坍塌,则
(i,j) 要成为叶子点对,则各有一个邻居节点未坍塌。
此处贡献为
pk(1−pi)(1−pj)(s∈N(i)&s=k∑(1−ps)t∈N(i)&t=k,s∏pt)(s∈N(j)&s=k∑(1−ps)t∈N(j)&t=k,s∏pt)。
两者加和即可。
考虑如何快速计算贡献。
令
Ai=j∈N(i)∏pj,
Bi=j∈N(i)∑pj1−pj。
两者都可以在输入的时候顺带线性处理。
-
dis(i,j)≥3
vi=(1−pi)AiBi。
可以先利用前缀和线性计算
i=1∑nj=i+1∑nvi×vj,然后再在计算
dis(i,j)=1 和
dis(i,j)=2 的时候将多计算的部分剔除。
-
dis(i,j)=1
枚举边即可线性计算
(1−pi)(1−pj)k∈N(i)∪N(j)&k=i,j∏pk=pipj(1−pi)(1−pj)AiAj。
-
dis(i,j)=2
分别计算
(1−pi)(1−pj)(1−pk)s∈N(i)∪N(j)&s=k∏ps 和
pk(1−pi)(1−pj)(s∈N(i)&s=k∑(1−ps)t∈N(i)&t=k,s∏pt)(s∈N(j)&s=k∑(1−ps)t∈N(j)&t=k,s∏pt),
可分别简记为
pk21−pk(1−pi)Ai(1−pj)Aj 和
pk1(1−pi)Ai(Bi−pk1−pk)(1−pj)Aj(Bj−pk1−pk)。
由于
i 与
j 在贡献中对称,易发现可以类似
dis(i,j)≥3 的方式计算前缀和线性得。
Time Complexity
计算逆元时间复杂度为
O(logM),每部分遍历都是线性的,因此总时间复杂度为
O(nlogM)。
Space Complexity
记录树以及每个节点的不同权值即可,
O(n)。
Code
CPPconst int N=1e5+5,mod=998244353;
int t,n,p[N],v[N],A[N],B[N];
vector<int> e[N];
inline int pw(int x,int y){
int sum=1;
while(y){
if(y&1) sum=1ll*sum*x%mod;
x=1ll*x*x%mod;y>>=1;
}
return sum;
}
inline int inv(int x){return pw(x,mod-2);}
inline void solve(){
rd(n);int sum1=0,sum21=0,sum22=0,sum3=0;
for(re i=1;i<=n;++i){
int q;rd(p[i]);rd(q);p[i]=1ll*p[i]*inv(q)%mod;
e[i].clear();A[i]=1;B[i]=0;
}
for(re i=1;i<n;++i){
int u,v;rd(u);rd(v);
e[u].pb(v);e[v].pb(u);
A[u]=1ll*A[u]*p[v]%mod;A[v]=1ll*A[v]*p[u]%mod;
B[u]=(1ll*(1-p[v]+mod)%mod*inv(p[v])%mod+B[u])%mod;B[v]=(1ll*(1-p[u]+mod)%mod*inv(p[u])%mod+B[v])%mod;
}
for(re i=1;i<=n;++i) v[i]=1ll*(1-p[i]+mod)%mod*A[i]%mod*B[i]%mod;
int tmp=0;
for(re i=1;i<=n;++i) sum3=(sum3+1ll*tmp*v[i]%mod)%mod,tmp=(tmp+v[i])%mod;
for(re i=1;i<=n;++i)
for(re j=0;j<e[i].size();++j)
if(i<e[i][j]){
sum1=(sum1+1ll*(1-p[i]+mod)%mod*(1-p[e[i][j]]+mod)%mod*inv(p[i])%mod*inv(p[e[i][j]])%mod*A[i]%mod*A[e[i][j]]%mod)%mod;
sum3=(sum3-1ll*v[i]*v[e[i][j]]%mod+mod)%mod;
}
for(re i=1;i<=n;++i){
int tmp1=0,tmp2=0,tmp3=0,p1=1ll*(1-p[i]+mod)%mod*inv(p[i])%mod*inv(p[i])%mod,p2=inv(p[i]),p3=1ll*(1-p[i]+mod)%mod*inv(p[i])%mod;
for(re j=0;j<e[i].size();++j){
sum21=(sum21+1ll*tmp1*(1-p[e[i][j]]+mod)%mod*A[e[i][j]]%mod*p1%mod)%mod;tmp1=(tmp1+1ll*(1-p[e[i][j]]+mod)%mod*A[e[i][j]]%mod)%mod;
sum22=(sum22+1ll*tmp2*(1-p[e[i][j]]+mod)%mod*A[e[i][j]]%mod*(B[e[i][j]]-p3+mod)%mod*p2%mod)%mod;tmp2=(tmp2+1ll*(1-p[e[i][j]]+mod)%mod*A[e[i][j]]%mod*(B[e[i][j]]-p3+mod)%mod)%mod;
sum3=(sum3-1ll*tmp3*v[e[i][j]]%mod+mod)%mod;tmp3=(tmp3+v[e[i][j]])%mod;
}
}
wr((sum1+(sum21+(sum22+sum3)%mod)%mod)%mod);puts("");
}