专栏文章

题解:P14254 分割(divide)

P14254题解参与者 10已保存评论 10

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
10 条
当前快照
1 份
快照标识符
@minkis31
此快照首次捕获于
2025/12/02 03:54
3 个月前
此快照最后确认于
2025/12/02 03:54
3 个月前
查看原文
感觉没到蓝,场上想的时间甚至比第一题短。

solution

引理 1
子树在原树深度上的深度集合是一个连续区间。
证明
uu 的原树深度为 dud_u,子树中原树最深的深度为 MuM_u。那么子树中出现过的原树深度集合:
Su=du,du+1,,MuS_u={d_u,d_u+1,\dots,M_u}
子树中任一节点的原树深度不小于 dud_u(因为从根到该节点经过 uu),不大于 MuM_u。对于区间中任意整数 xx 满足 duxMud_u\le x\le M_u,沿着从 uu 到达某个深度为 MuM_u 的节点的路径上必然存在深度为 xx 的节点,因此区间内每个整数都出现,集合为闭区间。
引理 2
所有被选节点的原树深度必须相等。
证明
序列要求 1<db1db2dbk1<d_{b_1}\le d_{b_2}\le\cdots\le d_{b_k}。由条件:
S1=i=2k+1Si,S_1=\bigcap_{i=2}^{k+1} S_i,
两边的最小元素(即区间左端点)相等。左端点分别是 db1d_{b_1}max(db2,,dbk,1)\max(d_{b_2},\dots,d_{b_k},1)。因此:
db1=max(db2,,dbk,1)d_{b_1}=\max(d_{b_2},\dots,d_{b_k},1)
又因非降序 db1db2dbkd_{b_1}\le d_{b_2}\le\cdots\le d_{b_k},可推出所有 dbid_{b_i} 必然相等。设共同深度为 D>1D>1
因此我们可以按深度 DD 独立地统计:只考虑原树中深度恰为 DD 的节点,把所有合法序列的项都限制在该层。
对深度为 DD 的每个节点 uu,令子树中原树最深的深度为 MuM_u
按上面的引理,每个被选节点对应的 SS 都是区间 [D,Mu][D, M_u]
把深度为 DD 的所有节点按 MuM_u 从小到大排序。设该层共有 mm 个节点。
固定某个节点 uu且令其被放在序列的第 1 位。
t=Mut=M_u。要使得:
S1=[D,t]=i=2k+1Si,S_1=[D,t]=\bigcap_{i=2}^{k+1} S_i,
必须满足:
  1. 对于 i=2,,ki=2,\dots,k,它们对应的 MbiM_{b_i} 都不能小于 tt,否则交集上界会小于 tt。因此其他 k1k-1 个被选节点必须从 MtM\ge t 的节点中选。
  2. 根所在子树在去掉这些 kk 条边后,仍然有一个深度不小于 tt,即剩余部分的最大深度 RtR\ge t。当且仅当并非层内所有 MtM\ge t 的节点都被选掉时,根所在剩余部分才含有深度 t\ge t。换句话说,若把层内所有 MtM\ge t 的节点全部包含在选集里,那么剩下的树里没有深度 t\ge t 的节点,导致 R<tR<t,因此该种选择非法。
把层内节点按 MM 排序并分组。对于某个具体的 tt
  • aa 为该组中 M=tM=t 的节点数。
  • GG 为层内满足 MtM\ge t 的节点总数。
考虑把首位 b1b_1 选为该组中某个节点。其余 k1k-1 个位置必须从剩下的 G1G-1MtM\ge t 节点中 有序不重复地选出。方案数为:
P(G1,k1)=(G1)(G2)(Gk+1)P(G-1,k-1)=(G-1)(G-2)\cdots(G-k+1)
但若 k=Gk=G,则根所在部分不含深度 t\ge t,不满足条件。
k=Gk=GP(G1,k1)=(G1)!P(G-1,k-1)=(G-1)!,此类全部被选掉的序列数是 a(G1)!a(G-1)!,需要剔除。
因此,固定 tt对合法有序序列贡献:
{aP(G1,k1),Gka(P(G1,k1)(G1)!)=0,G=k \begin{cases} aP(G-1,k-1), & G\ne k\\ a\bigl(P(G-1,k-1)-(G-1)!\bigr)=0, & G=k \end{cases}
把该层上所有不同的 tt 值的贡献累加,得到该深度 DD 的总贡献。对所有 D2D\ge2 累加即为全树结果。
预处理阶乘逆元后如果用 sortO(nlogn)O(n\log n) 的。换成基数排序可以做到 O(n)O(n)
codeCPP
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
    return s*w;
}
inline void out(int x){
    if(x==0){putchar('0');return;}
    int len=0,k1=x,c[10005];
    if(k1<0)k1=-k1,putchar('-');
    while(k1)c[len++]=k1%10+'0',k1/=10;
    while(len--)putchar(c[len]);
}
const int N=1e6+5,mod=998244353;
int fa[N],dep[N],cnt[N],maxd[N],pos[N],m[N],inv[N],fac[N];
int addmod(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
int submod(int a,int b){a-=b;if(a<0)a+=mod;return a;}
int mulmod(int a,int b){return (a*b)%mod;}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=ans*a%mod;
        b>>=1,a=a*a%mod;
    }return ans;
}
void init(int n){
    inv[0]=inv[1]=1;fac[0]=fac[1]=1;
    for(int i=2;i<=n;i++)fac[i]=fac[i-1]*i%mod;
    inv[n]=qpow(fac[n],mod-2);
    for(int i=n-1;i>=2;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int c(int n,int m){
    if(n<m)return 0;
    return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
    // freopen("divide6.in","r",stdin);
    // freopen("divide.out","w",stdout);
    int n=read(),k=read(),maxn=1;dep[1]=1;
    for(int i=2;i<=n;i++)fa[i]=read();init(n+1);
    for(int i=2;i<=n;i++){
        dep[i]=dep[fa[i]]+1;
        maxn=max(maxn,dep[i]);
    }for(int i=1;i<=n;i++)maxd[i]=dep[i];
    for(int i=n;i>=2;i--)maxd[fa[i]]=max(maxd[fa[i]],maxd[i]);
    // for(int i=2;i<=n;i++)cout<<dep[i]<<" ";puts("");
    for(int i=2;i<=n;i++)cnt[dep[i]]++;int tot=0;
    for(int d=1;d<=maxn;d++)pos[d]=tot,tot+=cnt[d];
    vector<int>cur(pos,pos+maxn+1);int ans=0;
    for(int i=2;i<=n;i++)m[cur[dep[i]]++]=maxd[i];
    // cout<<maxn<<"\n";
    for(int d=2;d<=maxn;d++){
        int mm=cnt[d];
        if(mm<k)continue;
        int l=pos[d],r=l+mm;
        sort(m+l,m+r);int pre=0,idx=l;
        while(idx<r){
            int j=idx,val=m[idx];
            while(j<r&&m[j]==val)j++;
            int a=j-idx,lcnt=pre,g=mm-lcnt,b=g-a;
            if(k<g){
                int t1=submod(c(g-1,k-1),c(b,k-1));
                t1=mulmod(a,t1);int t2=0;
                if((k-1)==b&&a>=2)t2=a%mod;
                int add=addmod(t1%mod,t2);
                add=mulmod(add,fac[k-1]);ans=addmod(ans,add);
            }pre+=a;idx=j;
        }
    }cout<<ans;
    return 0;
}

评论

10 条评论,欢迎与作者交流。

正在加载评论...