专栏文章
题解:P14254 分割(divide)
P14254题解参与者 10已保存评论 10
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 10 条
- 当前快照
- 1 份
- 快照标识符
- @minkis31
- 此快照首次捕获于
- 2025/12/02 03:54 3 个月前
- 此快照最后确认于
- 2025/12/02 03:54 3 个月前
感觉没到蓝,场上想的时间甚至比第一题短。
solution
引理 1
子树在原树深度上的深度集合是一个连续区间。
证明
设 的原树深度为 ,子树中原树最深的深度为 。那么子树中出现过的原树深度集合:
子树中任一节点的原树深度不小于 (因为从根到该节点经过 ),不大于 。对于区间中任意整数 满足 ,沿着从 到达某个深度为 的节点的路径上必然存在深度为 的节点,因此区间内每个整数都出现,集合为闭区间。
引理 2
所有被选节点的原树深度必须相等。
证明
序列要求 。由条件:
两边的最小元素(即区间左端点)相等。左端点分别是 与 。因此:
又因非降序 ,可推出所有 必然相等。设共同深度为 。
因此我们可以按深度 独立地统计:只考虑原树中深度恰为 的节点,把所有合法序列的项都限制在该层。
对深度为 的每个节点 ,令子树中原树最深的深度为 。
按上面的引理,每个被选节点对应的 都是区间 。
把深度为 的所有节点按 从小到大排序。设该层共有 个节点。
固定某个节点 且令其被放在序列的第 1 位。
令 。要使得:
必须满足:
- 对于 ,它们对应的 都不能小于 ,否则交集上界会小于 。因此其他 个被选节点必须从 的节点中选。
- 根所在子树在去掉这些 条边后,仍然有一个深度不小于 ,即剩余部分的最大深度 。当且仅当并非层内所有 的节点都被选掉时,根所在剩余部分才含有深度 。换句话说,若把层内所有 的节点全部包含在选集里,那么剩下的树里没有深度 的节点,导致 ,因此该种选择非法。
把层内节点按 排序并分组。对于某个具体的 :
- 设 为该组中 的节点数。
- 设 为层内满足 的节点总数。
考虑把首位 选为该组中某个节点。其余 个位置必须从剩下的 个 节点中 有序不重复地选出。方案数为:
但若 ,则根所在部分不含深度 ,不满足条件。
当 时 ,此类全部被选掉的序列数是 ,需要剔除。
因此,固定 对合法有序序列贡献:
把该层上所有不同的 值的贡献累加,得到该深度 的总贡献。对所有 累加即为全树结果。
预处理阶乘逆元后如果用
sort 是 的。换成基数排序可以做到 。code
CPP#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline void out(int x){
if(x==0){putchar('0');return;}
int len=0,k1=x,c[10005];
if(k1<0)k1=-k1,putchar('-');
while(k1)c[len++]=k1%10+'0',k1/=10;
while(len--)putchar(c[len]);
}
const int N=1e6+5,mod=998244353;
int fa[N],dep[N],cnt[N],maxd[N],pos[N],m[N],inv[N],fac[N];
int addmod(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
int submod(int a,int b){a-=b;if(a<0)a+=mod;return a;}
int mulmod(int a,int b){return (a*b)%mod;}
int qpow(int a,int b){
int ans=1;
while(b){
if(b&1)ans=ans*a%mod;
b>>=1,a=a*a%mod;
}return ans;
}
void init(int n){
inv[0]=inv[1]=1;fac[0]=fac[1]=1;
for(int i=2;i<=n;i++)fac[i]=fac[i-1]*i%mod;
inv[n]=qpow(fac[n],mod-2);
for(int i=n-1;i>=2;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int c(int n,int m){
if(n<m)return 0;
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
// freopen("divide6.in","r",stdin);
// freopen("divide.out","w",stdout);
int n=read(),k=read(),maxn=1;dep[1]=1;
for(int i=2;i<=n;i++)fa[i]=read();init(n+1);
for(int i=2;i<=n;i++){
dep[i]=dep[fa[i]]+1;
maxn=max(maxn,dep[i]);
}for(int i=1;i<=n;i++)maxd[i]=dep[i];
for(int i=n;i>=2;i--)maxd[fa[i]]=max(maxd[fa[i]],maxd[i]);
// for(int i=2;i<=n;i++)cout<<dep[i]<<" ";puts("");
for(int i=2;i<=n;i++)cnt[dep[i]]++;int tot=0;
for(int d=1;d<=maxn;d++)pos[d]=tot,tot+=cnt[d];
vector<int>cur(pos,pos+maxn+1);int ans=0;
for(int i=2;i<=n;i++)m[cur[dep[i]]++]=maxd[i];
// cout<<maxn<<"\n";
for(int d=2;d<=maxn;d++){
int mm=cnt[d];
if(mm<k)continue;
int l=pos[d],r=l+mm;
sort(m+l,m+r);int pre=0,idx=l;
while(idx<r){
int j=idx,val=m[idx];
while(j<r&&m[j]==val)j++;
int a=j-idx,lcnt=pre,g=mm-lcnt,b=g-a;
if(k<g){
int t1=submod(c(g-1,k-1),c(b,k-1));
t1=mulmod(a,t1);int t2=0;
if((k-1)==b&&a>=2)t2=a%mod;
int add=addmod(t1%mod,t2);
add=mulmod(add,fac[k-1]);ans=addmod(ans,add);
}pre+=a;idx=j;
}
}cout<<ans;
return 0;
}
相关推荐
评论
共 10 条评论,欢迎与作者交流。
正在加载评论...