社区讨论

求助RE 0pts

P14254分割(divide)参与者 2已保存回复 2

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
2 条
当前快照
1 份
快照标识符
@mhj0zik5
此快照首次捕获于
2025/11/03 18:57
4 个月前
此快照最后确认于
2025/11/03 18:57
4 个月前
查看原帖
CPP
#include<bits/stdc++.h>

#define int long long

using namespace std;

constexpr int mod = 998244353;

int n = 0 , k = 0 , ans = 0 , cnt = 0;

array<int , 2000100> head , dep , max_dep , fact , fact_inv , tot;

map<int , vector<int> > Map;

int quick_pow(int x , int y , int mod)
{
	int res = 1;
	
	while(y)
	{
		if(y & 1) res = (res * x) % mod;
		x = (x * x) % mod;
		y >>= 1;
	}
	
	return res;
}

int get_inv(int x) {return quick_pow(x , mod - 2 , mod);}

int A(int x , int y)
{
	if(x < y || y < 0 || x < 0) return 0;
	return fact[x] * fact_inv[x - y] % mod;
}

int add(int &x , int y) {((x += y) >= mod) && (x -= mod);}

struct Node{
	int to;
	int nxt;
};

array<Node , 4000100> edge;

void new_line(int a , int b)
{
	edge[++ cnt].to = b;
	edge[cnt].nxt = head[a];
	head[a] = cnt;
	
	return;
}

void dfs(int x , int fa)
{
	dep[x] = dep[fa] + 1;
	
	max_dep[x] = dep[x];
	
	for(int e = head[x];e;e = edge[e].nxt)
	{
		int to = edge[e].to;
		
		if(to == fa) continue;
		
		dfs(to , x);
		
		max_dep[x] = max(max_dep[x] , max_dep[to]);
	}
	
	return;
}

signed main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	
	fact[0] = fact_inv[0] = 1;
	
	for(int i = 1;i <= 1e6;++ i) fact[i] = fact[i - 1] * i % mod;
	
	fact_inv[1e6] = get_inv(fact[1e6]);
	
	for(int i = 1e6 - 1;i >= 1;-- i) fact_inv[i] = fact_inv[i + 1] * (i + 1) % mod;
	
	cin >> n >> k;
	
	for(int i = 2;i <= n;++ i)
	{
		int up = 0;
		
		cin >> up;
		
		new_line(up , i) , new_line(i , up);
	}
	
	dfs(1 , 0);
	
	for(int i = 2;i <= n;++ i) Map[dep[i]].emplace_back(max_dep[i]);
	
	for(auto &it : Map)
	{
		auto &num = it.second;
		
		sort(num.begin() , num.end());
		
		int it1 = 0;
		
		while(it1 < num.size())
		{
			int it2 = it1;
			
			while(it2 < num.size() && num[it2] == num[it1]) ++ it2;
			
			-- it2;
			
			int len = it2 - it1 + 1 , ex = num.size() - it1;
			
			if(ex >= k + 1)
			{
				add(ans , len * A(ex - 1 , k - 1) % mod);
				if(ex - len >= k) add(ans , mod - len * A(ex - len , k - 1) % mod);
			}
			
			it1 = it2 + 1;
		}
	}
	
	cout << ans << endl;
	
	return 0;
}

回复

2 条回复,欢迎继续交流。

正在加载回复...