社区讨论

TLE求助

CF161DDistance in Tree参与者 3已保存回复 3

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
3 条
当前快照
1 份
快照标识符
@mi86d8n0
此快照首次捕获于
2025/11/21 09:21
4 个月前
此快照最后确认于
2025/11/21 09:21
4 个月前
查看原帖
RT,深夜蒟蒻,在线求助(虽然发完就去睡了)\text{RT,深夜蒟蒻,在线求助(虽然发完就去睡了)}
CPP
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
using namespace std;
void read(int &x) {
    x = 0;
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = x * 10 + (c ^ '0'), c = getchar();
}
const int N = 50001;
int head[N], nxt[N << 1], to[N << 1], cnt, vis[N], siz[N], mxpt[N], rt, dis[N],
    s[N], t;
int n, k;
void add_edge(int s, int t) {
    to[++cnt] = t;
    nxt[cnt] = head[s];
    head[s] = cnt;
}
int sum;
void getrt(int u, int fa) {
    siz[u] = 1, mxpt[u] = 0;
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        mxpt[u] = max(siz[v], mxpt[u]);
    }
    mxpt[u] = max(sum - siz[u], mxpt[u]);
    if (mxpt[u] < mxpt[rt]) rt = u;
}
void getdis(int u, int fa) {
    s[++t] = dis[u];
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        dis[v] = dis[u] + 1;
        getdis(v, u);
    }
}
int solve(int u, int d) {
    dis[u] = d;
    t = 0;
    getdis(u, 0);
    sort(s + 1, s + 1 + t);
    int l = 1, ans = 0;
    while (l < cnt && s[l] + s[t] < k) l++;
    while (l < cnt && k - s[l] >= s[l]) {
        // k-sl is val needed,if sl > it then all vals left in s > it due to
        // sort(s).
        int L = lower_bound(s + l, s + 1 + t, k - s[l]) - s;
        int R = upper_bound(s + l, s + 1 + t, k - s[l]) - s;
        if (L < R) ans += R - L;
        l++;
    }
    return ans;
}
long long ans;
void divide(int u) {
    vis[u] = 1;
    ans += solve(u, 0);
    for (int i = head[u]; i; i = nxt[i]) {
        int v = to[i];
        if (vis[v]) continue;
        ans -= solve(v, 1);
        rt = 0, sum = siz[v];
        getrt(v, u), divide(rt);
    }
}
int main() {
    read(n), read(k);
    for (int i = 1, a, b; i < n; i++)
        read(a), read(b), add_edge(a, b), add_edge(b, a);
    mxpt[rt] = n, sum = n;
    getrt(1, 0), divide(rt);
    printf("%I64d", ans);
}

回复

3 条回复,欢迎继续交流。

正在加载回复...