社区讨论

为什么能对呀?

P9233[蓝桥杯 2023 省 A] 颜色平衡树参与者 1已保存回复 0

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
0 条
当前快照
1 份
快照标识符
@mix7yqkg
此快照首次捕获于
2025/12/08 22:00
2 个月前
此快照最后确认于
2025/12/11 20:40
2 个月前
查看原帖
CPP
#include <bits/stdc++.h>

using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 2e5 + 7;
const int P = 998244353;
int read() {
	int x = 0, f = 1;
	char ch = getchar();
	while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
	return x * f;
}
int n, c[N], siz[N], a[N], son[N];
ll sum, ans[N], mx = 0;
unordered_map<int, ll> mp;
// a[c[i]] : 颜色 c[i] 出现的次数 
// son[i] : 是否为重节点
vector<int> edges[N];
inline void solve(int u, int fa) {
	// 求出子树大小 & 是否为重节点
	int mx = 0, root = 0;
	siz[u] = 1;
	for (auto v : edges[u]) {
		if (v == fa) continue;
		solve(v, u);
		siz[u] += siz[v];
		if (siz[v] > mx) {
			mx = siz[v];
			root = v;
		}
	}
	if (root)
		son[root] = 1;
}
void del(int u){
	if (mp[u] != 1)
		mp[u]--;
	else
		mp.erase(u);
}
inline void clear(int u, int fa) {
	// 清空
	if (a[c[u]] > 0)
		del(a[c[u]]); // 删除
	--a[c[u]];
	// 疑似错点:删除后应加上!
	if (a[c[u]] > 0)
		mp[a[c[u]]]++;
	for (auto v : edges[u]) {
		if (v == fa) continue;
		clear(v, u);
	}
}
inline void DFS(int u, int fa, int root) {
	if (a[c[u]] > 0)
		del(a[c[u]]);
	++a[c[u]]; // 增加个数
	if (a[c[u]] > 0)
		mp[a[c[u]]]++;
	for (auto v : edges[u]) {
		if (v == fa || v == root) continue;
		DFS(v, u, root);
	}
}
inline void dfs(int u, int fa) {
	// 轻节点
	int root = 0;
	for (auto v : edges[u]) {
		if (v == fa) continue;
		if (!son[v]) {
			dfs(v, u);
			clear(v, u);
			sum = 0, mx = 0;
		}else {
			root = v;
		}
	}
	if (root) dfs(root, u);
	DFS(u, fa, root); // 不能遍历重儿子
	ans[u] = mp.size() == 1;
}

int main() {
	n = read();
	for (int i = 1; i <= n; i++) {
		c[i] = read();
		int x = read();
		if (i == 1) continue;
		edges[i].pb(x);
		edges[x].pb(i);
	}
	solve(1, -1);
	dfs(1, -1);
	// for (int i = 1; i <= n; i++)
		// printf("%lld ", ans[i]);
	// putchar('\n');
	ll end_ans = 0;
	for (int i = 1; i <= n; i++)
		end_ans += ans[i];
	printf("%lld\n", end_ans);
	return 0;
}

回复

0 条回复,欢迎继续交流。

正在加载回复...