社区讨论
为什么能对呀?
P9233[蓝桥杯 2023 省 A] 颜色平衡树参与者 1已保存回复 0
讨论操作
快速查看讨论及其快照的属性,并进行相关操作。
- 当前回复
- 0 条
- 当前快照
- 1 份
- 快照标识符
- @mix7yqkg
- 此快照首次捕获于
- 2025/12/08 22:00 2 个月前
- 此快照最后确认于
- 2025/12/11 20:40 2 个月前
CPP
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 2e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, c[N], siz[N], a[N], son[N];
ll sum, ans[N], mx = 0;
unordered_map<int, ll> mp;
// a[c[i]] : 颜色 c[i] 出现的次数
// son[i] : 是否为重节点
vector<int> edges[N];
inline void solve(int u, int fa) {
// 求出子树大小 & 是否为重节点
int mx = 0, root = 0;
siz[u] = 1;
for (auto v : edges[u]) {
if (v == fa) continue;
solve(v, u);
siz[u] += siz[v];
if (siz[v] > mx) {
mx = siz[v];
root = v;
}
}
if (root)
son[root] = 1;
}
void del(int u){
if (mp[u] != 1)
mp[u]--;
else
mp.erase(u);
}
inline void clear(int u, int fa) {
// 清空
if (a[c[u]] > 0)
del(a[c[u]]); // 删除
--a[c[u]];
// 疑似错点:删除后应加上!
if (a[c[u]] > 0)
mp[a[c[u]]]++;
for (auto v : edges[u]) {
if (v == fa) continue;
clear(v, u);
}
}
inline void DFS(int u, int fa, int root) {
if (a[c[u]] > 0)
del(a[c[u]]);
++a[c[u]]; // 增加个数
if (a[c[u]] > 0)
mp[a[c[u]]]++;
for (auto v : edges[u]) {
if (v == fa || v == root) continue;
DFS(v, u, root);
}
}
inline void dfs(int u, int fa) {
// 轻节点
int root = 0;
for (auto v : edges[u]) {
if (v == fa) continue;
if (!son[v]) {
dfs(v, u);
clear(v, u);
sum = 0, mx = 0;
}else {
root = v;
}
}
if (root) dfs(root, u);
DFS(u, fa, root); // 不能遍历重儿子
ans[u] = mp.size() == 1;
}
int main() {
n = read();
for (int i = 1; i <= n; i++) {
c[i] = read();
int x = read();
if (i == 1) continue;
edges[i].pb(x);
edges[x].pb(i);
}
solve(1, -1);
dfs(1, -1);
// for (int i = 1; i <= n; i++)
// printf("%lld ", ans[i]);
// putchar('\n');
ll end_ans = 0;
for (int i = 1; i <= n; i++)
end_ans += ans[i];
printf("%lld\n", end_ans);
return 0;
}
回复
共 0 条回复,欢迎继续交流。
正在加载回复...