专栏文章

题解:P13241 「2.48sOI R1」格律树

P13241题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@miowvx0a
此快照首次捕获于
2025/12/03 02:28
3 个月前
此快照最后确认于
2025/12/03 02:28
3 个月前
查看原文
喜报:离 AC 差一个 +mod
好我们回到题目。首先这题是给出了一棵树和多组询问,每组询问给出了一些关键点关键点颜色颜色只有两种 0 或 1。其他的点可以任意染色,求使得每个关键点(注意不是所有节点)到根路径上的颜色序列不出现 101 这种情况的方案数。
我们先考虑没有询问怎么做。设计状态 dpi,0/1/2dp_{i,0/1/2} 表示以 ii 号节点为根的子树内颜色序列以 0110 结尾的方案总数。其中 dpi,2dp_{i,2} 这个状态需要注意一下,就是我的节点 ii 可能会有多个子节点,如果它的任意一个子节点到根的路径出现了 10 结尾,就需要把它统计进 dpi,2dp_{i,2} 里去。
考虑怎么转移。先看 dpi,0dp_{i,0}。因为我把 0 填到序列最后,是肯定不会产生非法情况的。所以它可以从前面的任意状态继承过来。于是 dpi,0=jsoni(dpj,0+dpj,1)dp_{i,0}= \prod_{j\in son_i} (dp_{j,0}+dp_{j,1})
接着看 dpi,1dp_{i,1}。如果我把 1 填到序列最后,就有可能会产生非法情况了。所以要把这种情况减去,相当于就是减去儿子节点以 10 结尾的情况,所以 dpi,1=jsoni(dpj,0+dpj,1dpj,2)dp_{i,1}=\prod_{j\in son_i} (dp_{j,0}+dp_{j,1}-dp_{j,2})
最后来看 dpi,2dp_{i,2}。可以发现因为刚才讲的 dpi,2dp_{i,2} 的统计对象,那么相当于 dpi,0dp_{i,0} 减去所有不含 10 的情况,那么 dpi,2=dpi,0jsonidpj,0dp_{i,2}=dp_{i,0}-\prod_{j\in son_i} dp_{j,0}
现在对于每个节点 ii,如果它的颜色是 0 就转移 dpi,0dp_{i,0}dpi,2dp_{i,2},如果它的颜色是 1 就转移 dpi,1dp_{i,1},否则如果它没有颜色,就两个都转移。
最后整颗树的答案就是 dp1,0+dp1,1dp_{1,0}+dp_{1,1}
好,现在我们已经知道了没有询问怎么做,现在考虑有多组询问该怎么做。
首先可以看见题目中询问的总点数是 O(n)O(n) 级别的,这引导我们往虚树上想,不会建虚树的可以上网搜一下,这题不卡常,二次排序可以过。
现在假设我们已经把虚树给建出来了,那么我们的转移可以和之前一样的做法,但是这时候会出现一个问题:我们从虚树上的一个点跳到他的父节点时会经过很多其他的点,要转移多次。但是这些点没有其他的儿子,因此转移是相同的,这引导我们往矩阵乘法上去想。
那么因为我们刚才提及了 3 种状态转移,是不是就要设计 3 种矩阵呢?其实并不需要,因为我们经过的这些节点肯定都是没有颜色的点,因此只要设计一种矩阵就行。
假设现在我们知道 dpx,0/1/2dp_{x,0/1/2},要往 xx 的父亲 fafa 转移,那么可以按照上面的文字设计出以下的矩阵转移:
[dpx,0dpx,1dpx,2]×[110111010]=[dpfa,0dpfa,1dpfa,2] \begin{bmatrix} dp_{x,0} & dp_{x,1} & dp_{x,2} \\ \end{bmatrix} \times \begin{bmatrix} 1 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & -1 & 0 \end{bmatrix} = \begin{bmatrix} dp_{fa,0} & dp_{fa,1} & dp_{fa,2} \\ \end{bmatrix}
然后我们对于 i=1,2,,ni=1,2,\dots,n 预处理出 [110111010]i\begin{bmatrix}1 & 1 & 0 \\1 & 1 & 1 \\0 & -1 & 0\end{bmatrix}^i
那么每次处理出 dpx,0/1/2dp_{x,0/1/2} 后把它乘上一个 [110111010]depthxdepthfa1\begin{bmatrix}1 & 1 & 0 \\1 & 1 & 1 \\0 & -1 & 0\end{bmatrix}^{depth_x-depth_fa-1} 就可以正常转移了。
最后的答案为 (dp1,0+dp1,1)×2不在关键点到根路径上的节点个数(dp_{1,0}+dp_{1,1})\times 2^{不在关键点到根路径上的节点个数}
给出代码:
CPP
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define N 1000010
const int mod=1e9+7;
int n,q,k,dep[N],dfn[N],idx,f[N][21],rt,col[N],m,sum[N][21],ans[N],bas2[N],tpp;
vector<int>to[N],v;
vector<int>too[N];
struct matrix{
	int a[3][3]={};
	friend matrix operator+(matrix a,matrix b){
		matrix c;
		for(int i=0;i<3;i++){
			for(int j=0;j<3;j++){
                c.a[i][j]=(a.a[i][j]+b.a[i][j])%mod;
			}
		}
		return c;
	}
	friend matrix operator*(matrix a,matrix b){
		matrix c;
		for(int i=0;i<3;i++){
			for(int j=0;j<3;j++){
				for(int k=0;k<3;k++){
					c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]%mod+mod)%mod;
				}
			}
		}
		return c;
	}
    void print(){
        for(int i=0;i<3;i++){
            for(int j=0;j<3;j++){
                cout<<a[i][j]<<" ";
            }
            cout<<'\n';
        }
    }
}dp[N],bas[N],base[3];
void ini(){
    bas[1].a[0][0]=1;
    bas[1].a[0][1]=1;
    bas[1].a[1][0]=1;
    bas[1].a[1][1]=1;
    bas[1].a[1][2]=1;
    bas[1].a[2][1]=-1;
    // bas[1].a[2][0]=1;
    
    base[2]=bas[1];

    base[0].a[0][0]=1;
    base[0].a[1][0]=1;
    base[0].a[1][2]=1;
    
    base[1].a[0][1]=1;
    base[1].a[1][1]=1;
    base[1].a[2][1]=-1;
}
bool comp(int x,int y){
    return dfn[x]<dfn[y];
}
void clr(){
    for(auto x:v){
        too[x].clear();
        col[x]=2;
    }
}
void init(int x,int fa){
    dfn[x]=++idx;
    dep[x]=dep[fa]+1;
    f[x][0]=fa;
    for(int i=1;i<=20;i++){
        f[x][i]=f[f[x][i-1]][i-1];
    }
    for(auto y:to[x]){
        if(y==fa){
            continue;
        }
        init(y,x);
    }
}
int lca(int x,int y){
    if(x==y){
        return x;
    }
    if(dep[x]<dep[y]){
        swap(x,y);
    }
    for(int i=19;i>=0;i--){
        if(dep[f[x][i]]>=dep[y]){
            x=f[x][i];
        }
    }
    if(x==y){
        return x;
    }
    for(int i=19;i>=0;i--){
        if(f[x][i]!=f[y][i]){
            x=f[x][i];
            y=f[y][i];
        }
    }
    return f[x][0];
}
int build(){
    sort(v.begin(),v.end(),comp);
    vector<int>vv;
    vv=v;
    for(int i=0;i<v.size()-1;i++){
        vv.push_back(lca(v[i],v[i+1]));
    }
    sort(vv.begin(),vv.end(),comp);
    v.clear();
    v.push_back(vv[0]);
    for(int i=1;i<vv.size();i++){
        if(vv[i]!=vv[i-1]){
            v.push_back(vv[i]);
        }
    }
    for(int i=0;i<v.size()-1;i++){
        int tmp=lca(v[i],v[i+1]);
        if(tmp!=v[i+1]){
            too[tmp].push_back(v[i+1]);
            too[v[i+1]].push_back(tmp);
        }
    }
    return v[0];
}
void dfs(int x,int fa){
    if(x!=rt){
        tpp+=dep[x]-dep[fa];
    }
    for(int i=0;i<3;i++){
        for(int j=0;j<3;j++){
            dp[x].a[i][j]=0;
        }
    }
    if(too[x].size()+(x==rt)<=1){
        dp[x].a[0][col[x]]=1;
        // cout<<x<<'\n';
        // dp[x].print();
        if(dep[fa]!=dep[x]-1){
            dp[x]=dp[x]*bas[dep[x]-dep[fa]-1];
        }
        // dp[x].print();
        return;
    }
    if(col[x]<2){
        dp[x].a[0][col[x]]=1;
        int tmp=1;
        for(auto y:too[x]){
            if(y==fa){
                continue;
            }
            dfs(y,x);
            matrix f=dp[y]/**base[col[x]]*/;
            tmp=tmp*f.a[0][0]%mod;
            // cout<<x<<" "<<y<<'\n';
            // base[col[x]].print();
            // f.print();
            dp[x].a[0][col[x]]=dp[x].a[0][col[x]]*((f.a[0][0]+f.a[0][1]-col[x]*f.a[0][2]+mod)%mod)%mod;
        }
        if(col[x]==0){
            dp[x].a[0][2]=(dp[x].a[0][0]-tmp+mod)%mod;
        }
    }
    else{
        dp[x].a[0][0]=dp[x].a[0][1]=1;
        int tmp=1;
        for(auto y:too[x]){
            if(y==fa){
                continue;
            }
            dfs(y,x);
            matrix f=dp[y]/**base[col[x]]*/;
            tmp=tmp*f.a[0][0]%mod;
            // cout<<x<<" "<<y<<'\n';
            // base[col[x]].print();
            // f.print();
            dp[x].a[0][0]=dp[x].a[0][0]*(f.a[0][0]+f.a[0][1])%mod;
            dp[x].a[0][1]=dp[x].a[0][1]*((f.a[0][0]+f.a[0][1]-f.a[0][2]+mod)%mod)%mod;
            // dp[x].print();
        }
        dp[x].a[0][2]=(dp[x].a[0][0]-tmp+mod)%mod;
    }
    // cout<<x<<'\n';
    // dp[x].print();
    if(dep[fa]!=dep[x]-1){
        dp[x]=dp[x]*bas[dep[x]-dep[fa]-1];
    }
    // dp[x].print();
}
signed main(){
    ini();
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        to[u].push_back(v);
        to[v].push_back(u);
    }
    bas2[0]=1;
    cin>>m>>q;
    init(1,0);
    int maxn=0;
    for(int i=1;i<=n;i++){
        maxn=max(maxn,dep[i]);
        col[i]=2;
        bas2[i]=bas2[i-1]*2%mod;
    }
    for(int i=2;i<=maxn;i++){
        bas[i]=bas[i-1]*bas[1];
    }
    for(int TEST=1;TEST<=m;TEST++){
        tpp=0;
        clr();
        v.clear();
        bool fl=0;
        cin>>k;
        for(int i=1;i<=k;i++){
            int x;
            cin>>x;
            cin>>col[x];
            v.push_back(x);
            if(x==1){
                fl=1;
            }
        }
        rt=build();
        dfs(rt,f[rt][0]);
        if(dep[rt]!=dep[1]){
            dp[rt]=dp[rt]*bas[dep[rt]-dep[1]];
        }
        tpp+=dep[rt];
        ans[TEST]=(dp[rt].a[0][0]+dp[rt].a[0][1])%mod*bas2[n-tpp]%mod;
    }
    int ANS=0;
    for(int i=1;i<=m;i++){
        ANS^=ans[i];
        if(i%q==0){
            cout<<ANS<<'\n';
            ANS=0;
        }
    }
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...