专栏文章
P14251 [集训队互测 2025] Everlasting Friends?
P14251题解参与者 1已保存评论 0
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 0 条
- 当前快照
- 1 份
- 快照标识符
- @mingfk43
- 此快照首次捕获于
- 2025/12/02 02:00 3 个月前
- 此快照最后确认于
- 2025/12/02 02:00 3 个月前
考虑 。枚举连通块在 的根 ,只保留 两端都在 子树内的边。发现 中每条边覆盖 的一条祖孙链,且每条边至少被覆盖一次。
然后有一个比较深刻的观察是,把选连通块看成断一些边,一条边能断当且仅当它只被覆盖一次,感性理解就是如果被覆盖两次,那么断了这条边会导致那两个上端点不连通。
那么考虑 DP, 表示以 为根的连通块数量,那么 ,其中 为 这条边被覆盖的次数。答案即为 。时间复杂度 。
考虑优化。固然可以 DDP 优化到 ,但是有更简单的方法。
只做一次 DFS,递归到 时,找到它在 中的所有边 。设 往下的边为 ,相当于是把 上 路径上的边全部设成不可断开($$)。
从上往下设,那么考虑到一条原本可以断开的边 时, 对 的贡献原本是 现在变成 ,那么只需令 即可。除 的问题可以将每个数表示成 解决(集合还在追我()。时间复杂度 ,其中 。
考虑 。固定连通块在 上的根 和 上的根 。可以归纳证明连通块 在 中也是连通块。
并且可能的 最多有一个,因为设初始连通块为 在 中路径的点集,连通块不断拓展的过程中,若 且 不在连通块且 在 中与连通块中的点 相邻,那么 一定在 或 中是 的祖先,从而一定要被加入连通块。
通过上述过程可以观察出结论: 只可能是 中 子树与 中 子树的交。
于是问题转化成有多少对 ,满足:
- 在 上是 祖先;
- 在 上是 祖先;
- 中 子树与 中 子树的交在 和 上都是连通块。
数连通块自然考虑点减边容斥。枚举 ,给每个 设一个权值 (初始若 在 中 子树内则 ,否则 )。若一个点同时在两棵子树内,对 有 的贡献,对于 中 子树内的一条边 ,若 都在 子树内,对 有 的贡献,对于 中 子树内的一条边 同理。 一定 ,若 说明 合法。
那么考虑在 上做线段树合并,对 的修改相当于若干个在 上的链加,统计答案相当于查 在 上到根的路径的最小值和最小值个数。
时空复杂度均为 ,感受一下空间很难卡满,加上垃圾回收后可以通过。
代码
CPP#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
const int maxn = 200100;
const int logn = 20;
const int maxm = 16000100;
const int inf = 0x3f3f3f3f;
const ll mod = 998244353;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, type;
int fa[maxn], p[maxn], pa[maxn];
vector<int> G[maxn], G1[maxn], G2[maxn];
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
namespace Sub1 {
ll ans;
struct node {
ll x, y;
node(ll _x = 0, ll _y = 0) : x(_x), y(_y) {}
} f[maxn];
inline node operator + (const node &a, const node &b) {
if (a.y < b.y) {
return a;
} else if (a.y > b.y) {
return b;
} else {
ll x = (a.x + b.x) % mod;
if (x == 0) {
return node(1, a.y + 1);
} else {
return node(x, a.y);
}
}
}
inline node operator * (const node &a, const node &b) {
return node(a.x * b.x % mod, a.y + b.y);
}
inline node operator / (const node &a, const node &b) {
return node(a.x * qpow(b.x, mod - 2) % mod, a.y - b.y);
}
void dfs(int u) {
f[u] = node(1, 0);
for (int v : G1[u]) {
dfs(v);
int w = p[v];
for (int x = find(w); x != v; x = find(x)) {
f[v] = f[v] / (f[x] + node(1, 0)) * f[x];
fa[x] = pa[x];
}
f[u] = f[u] * (f[v] + node(1, 0));
}
ans = (ans + (f[u].y ? 0 : f[u].x)) % mod;
}
void solve() {
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
dfs(n);
printf("%lld\n", ans);
}
}
int st1[logn][maxn], st2[logn][maxn], dfn1[maxn], dfn2[maxn], tim;
inline int get1(int i, int j) {
return dfn1[i] < dfn1[j] ? i : j;
}
inline int get2(int i, int j) {
return dfn2[i] < dfn2[j] ? i : j;
}
inline int qlca1(int x, int y) {
if (x == y) {
return x;
}
x = dfn1[x];
y = dfn1[y];
if (x > y) {
swap(x, y);
}
++x;
int k = __lg(y - x + 1);
return get1(st1[k][x], st1[k][y - (1 << k) + 1]);
}
inline int qlca2(int x, int y) {
if (x == y) {
return x;
}
x = dfn2[x];
y = dfn2[y];
if (x > y) {
swap(x, y);
}
++x;
int k = __lg(y - x + 1);
return get2(st2[k][x], st2[k][y - (1 << k) + 1]);
}
void dfs(int u, int t) {
dfn1[u] = ++tim;
st1[0][tim] = t;
for (int v : G1[u]) {
dfs(v, u);
}
}
int sz[maxn], son[maxn], dep[maxn], top[maxn];
int dfs2(int u, int f, int d) {
fa[u] = f;
sz[u] = 1;
dep[u] = d;
int mx = -1;
for (int v : G2[u]) {
sz[u] += dfs2(v, u, d + 1);
if (sz[v] > mx) {
son[u] = v;
mx = sz[v];
}
}
return sz[u];
}
void dfs3(int u, int tp) {
top[u] = tp;
dfn2[u] = ++tim;
st2[0][tim] = fa[u];
if (!son[u]) {
return;
}
dfs3(son[u], tp);
for (int v : G2[u]) {
if (!dfn2[v]) {
dfs3(v, v);
}
}
}
inline pii operator + (const pii &a, const pii &b) {
if (a.fst < b.fst) {
return a;
} else if (a.fst > b.fst) {
return b;
} else {
return mkp(a.fst, a.scd + b.scd);
}
}
namespace SGT {
int ls[maxm], rs[maxm], tag[maxm], nt, stk[maxm], top;
pii a[maxm];
inline void init() {
for (int i = 0; i < maxm; ++i) {
a[i] = pii(inf, 0);
}
}
inline void pushup(int x) {
a[x] = a[ls[x]] + a[rs[x]];
a[x].fst += tag[x];
}
inline void pushtag(int x, int y) {
if (!x) {
return;
}
a[x].fst += y;
tag[x] += y;
}
inline void delnode(int x) {
a[x] = pii(inf, 0);
ls[x] = rs[x] = tag[x] = 0;
if (top + 1 < maxm) {
stk[++top] = x;
}
}
inline int newnode() {
assert(nt + 1 < maxm);
return top ? stk[top--] : (++nt);
}
void update(int &rt, int l, int r, int ql, int qr, int x) {
if (!rt) {
rt = newnode();
}
if (ql <= l && r <= qr) {
pushtag(rt, x);
return;
}
int mid = (l + r) >> 1;
if (ql <= mid) {
update(ls[rt], l, mid, ql, qr, x);
}
if (qr > mid) {
update(rs[rt], mid + 1, r, ql, qr, x);
}
pushup(rt);
}
void modify(int &rt, int l, int r, int x) {
if (!rt) {
rt = newnode();
}
if (l == r) {
a[rt].fst -= inf;
a[rt].scd = 1;
return;
}
int mid = (l + r) >> 1;
(x <= mid) ? modify(ls[rt], l, mid, x) : modify(rs[rt], mid + 1, r, x);
pushup(rt);
}
int merge(int u, int v, int l, int r) {
if (!u || !v) {
return u | v;
}
tag[u] += tag[v];
if (l == r) {
bool fl = (a[u].fst > 1e9) && (a[v].fst > 1e9);
if (a[u].fst >= inf) {
a[u].fst -= inf;
}
if (a[v].fst >= inf) {
a[v].fst -= inf;
}
a[u].fst = a[u].fst + a[v].fst + (fl ? inf : 0);
a[u].scd |= a[v].scd;
delnode(v);
return u;
}
int mid = (l + r) >> 1;
ls[u] = merge(ls[u], ls[v], l, mid);
rs[u] = merge(rs[u], rs[v], mid + 1, r);
pushup(u);
delnode(v);
return u;
}
pii query(int rt, int l, int r, int ql, int qr) {
if (!rt) {
return pii(inf, 0);
}
if (ql <= l && r <= qr) {
return a[rt];
}
int mid = (l + r) >> 1;
pii res(inf, 0);
if (ql <= mid) {
res = res + query(ls[rt], l, mid, ql, qr);
}
if (qr > mid) {
res = res + query(rs[rt], mid + 1, r, ql, qr);
}
res.fst += tag[rt];
return res;
}
}
vector<int> vc[maxn];
ll ans;
int rt[maxn];
inline void update(int &rt, int x, int y) {
while (x) {
SGT::update(rt, 1, n, dfn2[top[x]], dfn2[x], y);
x = fa[top[x]];
}
}
inline pii query(int rt, int x) {
pii res(inf, 0);
while (x) {
res = res + SGT::query(rt, 1, n, dfn2[top[x]], dfn2[x]);
x = fa[top[x]];
}
return res;
}
void dfs4(int u) {
for (int v : G1[u]) {
dfs4(v);
rt[u] = SGT::merge(rt[u], rt[v], 1, n);
}
SGT::modify(rt[u], 1, n, dfn2[u]);
update(rt[u], u, 2);
for (int v : G1[u]) {
int w = qlca2(u, v);
update(rt[u], w, -1);
}
for (int v : vc[u]) {
update(rt[u], v, -1);
}
pii p = query(rt[u], u);
if (p.fst == 2) {
ans += p.scd;
}
}
void solve() {
scanf("%lld%lld", &type, &n);
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
for (int i = 1; i <= n; ++i) {
for (int j : G[i]) {
if (j < i && find(i) != find(j)) {
int k = find(j);
p[k] = j;
fa[k] = i;
pa[k] = i;
G1[i].pb(k);
}
}
}
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
for (int i = n; i; --i) {
for (int j : G[i]) {
if (j > i && find(j) != find(i)) {
int k = find(j);
fa[k] = i;
G2[i].pb(k);
}
}
}
if (type == 1) {
Sub1::solve();
return;
}
dfs(n, 0);
tim = 0;
dfs2(1, 0, 1);
dfs3(1, 1);
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
st1[j][i] = get1(st1[j - 1][i], st1[j - 1][i + (1 << (j - 1))]);
st2[j][i] = get2(st2[j - 1][i], st2[j - 1][i + (1 << (j - 1))]);
}
}
for (int i = 1; i <= n; ++i) {
for (int j : G2[i]) {
vc[qlca1(i, j)].pb(i);
}
}
SGT::init();
dfs4(n);
printf("%lld\n", ans % mod);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}
相关推荐
评论
共 0 条评论,欢迎与作者交流。
正在加载评论...