专栏文章

P2664 线段树合并题解

P2664题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mioyc00f
此快照首次捕获于
2025/12/03 03:09
3 个月前
此快照最后确认于
2025/12/03 03:09
3 个月前
查看原文
O(nlogn)\mathcal O(n\log n) 做法。
先考虑一个点起始怎么做,再换根。显然可以 DFS,但是不够数据结构化,不好换根。考虑每种颜色统计有多少点使 ss 到该点经过该种颜色。
使用线段树合并。点 uu 线段树上 Ti=jT_i=j 表示 uu 子树上所有点中,有 jj 个点到 uu 的路径中经过颜色 ii。此时点 uu 在其子树内的答案为 Ti\sum T_i。合并时有(szusz_u 表示以 uu 为根的子树的大小):
Tui={Tvi(vu的儿子)icuszui=cu{T_u}_i=\begin{cases} \sum {T_v}_i(v 是 u 的儿子) && i\neq c_u\\ sz_u && i = c_u \end{cases}
即如图:
uu 开头的路径中,其余颜色出现的次数不变,而 uu 到子树内任一点都要经过 cuc_u
现在求出了以某一点为根的答案,考虑换根。考察如图的情况:
已知 ff 为根时的答案,从 ff 转移到 uu 的过程中,有:
Ti={Tiicuicfni=cunszu+Tucfi=cfT^\prime_i=\begin{cases}T_i && i \neq c_u 且i\neq c_f\\ n && i =c_u\\ n - sz_u + {{T_u}_c}_f && i=c_f \end{cases}
第三种情况中,TuT_u 为线段树合并得到的 uu 的线段树。此时意为,ff 以外的部分,路径必经过 cfc_f,而以内的部分已经求过,线段树合并时将这个值记录即可。
线段树合并 O(nlogn)\mathcal O(n \log n),而每次转移是 O(logn)\mathcal O(\log n) 的,总复杂度 O(nlogn)\mathcal O(n \log n)
CPP
#include<bits/stdc++.h>
#define debug(x) cerr << #x << ": " << x << '\n';
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define lop(i, a, b) for (int i = (a); i < (b); i++)
#define dwn(i, a, b) for (int i = (a); i >= (b); i--)
#define elif else if
#define iosfst ios::sync_with_stdio(0);cin.tie(0), cout.tie(0)
#define pb push_back
#define _if (
#define _then ?(
#define _els ):
#define _end )
#define intt long long
using namespace std;
#define N 100005
bool mem1;
vector<int>g[N];
int n, fa[N], a[N], sz[N];intt val[N], ans[N];
namespace sgt{
    int tp, ls[N * 32], rs[N * 32], fa[N * 32];intt tr[N * 32];
    int rt[N * 32];
    inline void pushup(int p) {
        tr[p] = tr[ls[p]] + tr[rs[p]];
    }
    inline void upd(int p, int l, int r, int x, intt v) {
        if(l == r) {
            tr[p] += v;
            return ;
        }
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) {
                ls[p] = ++ tp;
                fa[tp] = p;
            }
            upd(ls[p], l, mid, x, v);
        }
        else {
            if(!rs[p]) {
                rs[p] = ++ tp;
                fa[tp] = p;
            }
            upd(rs[p], mid + 1, r, x, v);
        }
        pushup(p);
    }
    inline void set(int p, int l, int r, intt x, intt v) {
        if(l == r) {
            tr[p] = v;
            return ;
        }
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) {
                ls[p] = ++ tp;
                fa[tp] = p;
            }
            set(ls[p], l, mid, x, v);
        }
        else {
            if(!rs[p]) {
                rs[p] = ++ tp;
                fa[tp] = p;
            }
            set(rs[p], mid + 1, r, x, v);
        }
        pushup(p);
        //cout << "Set " << x << " " << v << "\n";
        //cout << p << " " << l << " " << r << " " << tr[p] << "\n";
    }
    inline int merge(int p1, int p2) {
        if(!p1 || !p2) return p1 | p2;
        ls[p1] = merge(ls[p1], ls[p2]);
        rs[p1] = merge(rs[p1], rs[p2]);
        tr[p1] += tr[p2];
        return p1;
    }
    inline int qry(int p, int l, int r, intt x) {
        if(l == r) 
            return tr[p];
        int mid = (l + r) >> 1;
        if(x <= mid) {
            if(!ls[p]) 
                return 0;
            return qry(ls[p], l, mid, x);
        }
        else {
            if(!rs[p]) 
                return 0;
            return qry(rs[p], mid + 1, r, x);
        }
    }
}
bool mem2;

void dfs(int u, int f) {
    fa[u] = f;
    sz[u] = 1;
    for(auto v : g[u])
        if(v != f)
            dfs(v, u),
            sgt::rt[u] = sgt::merge(sgt::rt[u], sgt::rt[v]),
            sz[u] += sz[v];
    if(sgt::rt[u] == 0)
        sgt::rt[u] = ++sgt::tp;
    sgt::set(sgt::rt[u], 1, 1000000, a[u], sz[u]);  
    if(f)
        val[u] = sgt::qry(sgt::rt[u], 1, 1000000, a[f]);  
}

void dfs2(int u, int f) {
    //cout << "DFS2 " << u << " " << f << "\n";
    if(u == 1) {
        for(auto v : g[u])
            dfs2(v, u);
        return ;
    }
    intt rbp = sgt::qry(sgt::rt[1], 1, 1000000, a[u]);
    intt rdx = sgt::qry(sgt::rt[1], 1, 1000000, a[f]);
    sgt::set(sgt::rt[1], 1, 1000000, a[u], n);
    sgt::set(sgt::rt[1], 1, 1000000, a[f], n - sz[u] + val[u]);
    ans[u] = sgt::tr[sgt::rt[1]];
    for(auto v: g[u])
        if(f != v)
            dfs2(v, u);
    sgt::set(sgt::rt[1], 1, 1000000, a[u], rbp);
    sgt::set(sgt::rt[1], 1, 1000000, a[f], rdx);
}

signed main() {
    iosfst;
    cin >> n;
    rep(i, 1, n) cin >> a[i];
    lop(i, 1, n) {
        int x, y;
        cin >> x >> y;
        g[x].pb(y);
        g[y].pb(x);
    }
    dfs(1, 0);
    ans[1] = sgt::tr[sgt::rt[1]];
    dfs2(1, 0);
    rep(i, 1, n)
        cout << ans[i] << "\n";
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...