专栏文章

二宫凉夏

AT_arc142_d题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@miqkui1r
此快照首次捕获于
2025/12/04 06:27
3 个月前
此快照最后确认于
2025/12/04 06:27
3 个月前
查看原文
没懂为啥题解全是八个状态。
首先肯定要能做 KK 次操作,不过容易发现这个其实等价于能做 11 次操作,因为一条边不能被两个不同的棋子经过,却可以被一个棋子反复蹂躏,然后就可以进进退退搞很多很多步。
然后我们考虑在这个条件下,我们把一个节点进进退退所涉及到的两个节点之间的边直接点亮,那么最后点亮的边一定是若干条链。
这是为啥呢,你考虑假设不是链,那就有地方度数在 33 以上,那就必然有一个回合这个点上必须塞 22 个棋子,那就是非法状态。
然后考虑这是条链还不够,还得每一步点集唯一,那我已经构造出一组合法方案,只要踏出一步使得和这个方案不一样那就是非法方案。
什么情况下会出现一个非法状态呢,我们先把这个图的大概样式写出来,对于每条链他由两个部分组成,链头一回合有点一回合没有点,链中间始终有点。
然后我们发现如果链头即将消失,旁边有一个不在链上的点的话,这个点就可以跑过来使得这个方案不唯一。
这启发我们链中间不能和链头拼在一起。
链头是可以和链头拼在一起的,但是显然他们的状态受到限制。
而链中间也是可以和链中间拼在一起的。
还有把所有点亮的边都点亮之后,点也都被点亮了,一旦出现空点我就可以直接踩过去然后方案就不唯一了。
好的,那么有这个我们就可以来设计一个状态,设 fif_i 表示以第 ii 个节点为根的子树中点 ii 是链头的方案数,gig_i 表示以第 ii 个节点为根的子树中点 ii 是链中间的方案数,转移就是考虑如果你要链中间就把两个链头拉起来,然后别的子树全部取链中间,如果你要链头就挑个链头给你连起来,然后别的子树也全都链头,显然单纯计数链还需要考虑他的节点堆在哪边,但是如果链头连起来了联通块就减少了,所以记得除掉这个系数。
然后你发现这个东西有个问题,那就是链头如果被上面的点接过去了,那他不就变成链中间了吗?
所以加一维,设 fi,0/1f_{i,0/1} 表示以点 ii 为根节点的子树,点 ii 是链头,是否要接上去的方案数,那么 fi,1f_{i,1} 就有两种转移方式,一种是找一条链上来,然后对于别的子树就直接取出 gvg_v 统计,第二种是直接自己当链头,底下的每个子树取 fv,0f_{v,0} 统计。
然后就做完啦!不过注意到因为有 00 存在所以不能直接除以逆元来算这个乘积,把 00 挑出来就好了。
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int P=998244353;
int n;
vector<int> ljb[200005];
int f[200005][2];
int g[200005];
int power(int x,int y=P-2){
    if(y==0)return 1;
    int tmp=power(x,y>>1);
    if(y&1)return tmp*tmp%P*x%P;
    return tmp*tmp%P;
}
void getf(int cur,int fa){
    for(int i=0;i<ljb[cur].size();i++){
        int v=ljb[cur][i];
        if(v==fa)continue;
        getf(v,cur);
    }
    int mul=1;
    int cnt=0;
    int num=0;
    for(int i=0;i<ljb[cur].size();i++){//
        int v=ljb[cur][i];
        if(v==fa)continue;
        cnt++;
        if(!f[v][0]){
            num++;
            continue;
        }
        mul=mul*f[v][0]%P;
    }
    if(!num){
        for(int i=0;i<ljb[cur].size();i++){
            int v=ljb[cur][i];
            if(v==fa)continue;
            f[cur][0]+=f[v][1]*mul%P*power(f[v][0])%P*power(power(2,cnt-1))%P;
            if(f[cur][0]>=P)f[cur][0]-=P;
        }
    }
    if(num==1){
        for(int i=0;i<ljb[cur].size();i++){
            int v=ljb[cur][i];
            if(v==fa)continue;
            if(f[v][0])continue;
            f[cur][0]+=f[v][1]*mul%P*power(power(2,cnt-1))%P;
            if(f[cur][0]>=P)f[cur][0]-=P;
        }
    }
    int all=1;
    for(int i=0;i<ljb[cur].size();i++){
        int v=ljb[cur][i];
        if(v==fa)continue;
        all=all*f[v][0]%P*power(2)%P;
    }
    all=all*2%P;
    f[cur][1]=all;
    mul=1;
    num=0;
    for(int i=0;i<ljb[cur].size();i++){
        int v=ljb[cur][i];
        if(v==fa)continue;
        if(!g[v]){
            num++;
            continue;
        }
        mul=mul*g[v]%P;
    }
    if(!num){
        for(int i=0;i<ljb[cur].size();i++){
            int v=ljb[cur][i];
            if(v==fa)continue;
            f[cur][1]+=f[v][1]*mul%P*power(g[v])%P;
            if(f[cur][1]>=P)f[cur][1]-=P;
        }
    }
    if(num==1){
        for(int i=0;i<ljb[cur].size();i++){
            int v=ljb[cur][i];
            if(v==fa)continue;
            if(g[v])continue;
            f[cur][1]+=f[v][1]*mul%P;
            if(f[cur][1]>=P)f[cur][1]-=P;
        }
    }
    if(num==0){
        int sum=0;
        for(int i=ljb[cur].size()-1;i>=0;i--){
            int v1=ljb[cur][i];
            if(v1==fa)continue;
            int val1=(f[v1][1]*power(g[v1])%P);
            g[cur]+=sum*val1;
            g[cur]%=P;
            sum+=(f[v1][1]*power(g[v1])%P)%P*power(2)%P*mul%P;
            if(sum>=P)sum-=P;
        }
    }
    if(num==1){
        for(int i=0;i<ljb[cur].size();i++){
            int v1=ljb[cur][i];
            if(v1==fa)continue;
            if(g[v1])continue;
            for(int j=0;j<ljb[cur].size();j++){
                int v2=ljb[cur][j];
                if(v1==fa||v2==fa)continue;
                if(v1==v2)continue;
                g[cur]+=f[v1][1]*f[v2][1]%P*power(2)%P*mul%P*power(g[v2])%P;
                if(g[cur]>=P)g[cur]-=P;
            }
        }
    }
    if(num==2){
        for(int i=0;i<ljb[cur].size();i++){
            int v1=ljb[cur][i];
            if(v1==fa)continue;
            if(g[v1])continue;
            for(int j=i+1;j<ljb[cur].size();j++){
                int v2=ljb[cur][j];
                if(v1==fa||v2==fa)continue;
                if(v1==v2)continue;
                if(g[v2])continue;
                g[cur]+=f[v1][1]*f[v2][1]%P*power(2)%P*mul%P;
                if(g[cur]>=P)g[cur]-=P;
            }
        }
    }
    return;
}
signed main(){
    scanf("%lld",&n);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%lld%lld",&u,&v);
        ljb[u].push_back(v);
        ljb[v].push_back(u);
    }
    getf(1,0);
    printf("%lld\n",(g[1]+f[1][0])%P);
    return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...