专栏文章

P14251 [集训队互测 2025] Everlasting Friends?

P14251题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mingfk43
此快照首次捕获于
2025/12/02 02:00
3 个月前
此快照最后确认于
2025/12/02 02:00
3 个月前
查看原文
考虑 tp=1tp = 1。枚举连通块在 TmaxT_{\max} 的根 xx,只保留 TT 两端都在 xx 子树内的边。发现 TT 中每条边覆盖 TmaxT_{\max} 的一条祖孙链,且每条边至少被覆盖一次。
然后有一个比较深刻的观察是,把选连通块看成断一些边,一条边能断当且仅当它只被覆盖一次,感性理解就是如果被覆盖两次,那么断了这条边会导致那两个上端点不连通。
那么考虑 DP,fuf_u 表示以 uu 为根的连通块数量,那么 fu=vsonu(fv+[dv=1])f_u = \prod\limits_{v \in son_u} (f_v + [d_v = 1]),其中 dvd_v(u,v)(u, v) 这条边被覆盖的次数。答案即为 fxf_x。时间复杂度 O(n2)O(n^2)

考虑优化。固然可以 DDP 优化到 O(npolylog(n))O(n \operatorname{polylog}(n)),但是有更简单的方法。
只做一次 DFS,递归到 uu 时,找到它在 TT 中的所有边 (u,v),u>v(u, v), u > v。设 uu 往下的边为 uwu \to w,相当于是把 TmaxT_{\max}uvu \to v 路径上的边全部设成不可断开($$)。
从上往下设,那么考虑到一条原本可以断开的边 xyx \to y 时,fyf_yfwf_w 的贡献原本是 fy+1f_y + 1 现在变成 fyf_y,那么只需令 fwfw×fyfy+1f_w \gets f_w \times \frac{f_y}{f_y + 1} 即可。除 00 的问题可以将每个数表示成 a×0ba \times 0^b 解决(集合还在追我()。时间复杂度 O(n(logn+logP))O(n (\log n + \log P)),其中 P=998244353P = 998244353

考虑 tp=2tp = 2。固定连通块在 TmaxT_{\max} 上的根 xxTminT_{\min} 上的根 yy。可以归纳证明连通块 SSTT 中也是连通块。
并且可能的 SS 最多有一个,因为设初始连通块为 xyx \to yTT 中路径的点集,连通块不断拓展的过程中,若 y<z<xy < z < xzz 不在连通块且 zzTT 中与连通块中的点 ww 相邻,那么 zz 一定在 TmaxT_{\max}TminT_{\min} 中是 ww 的祖先,从而一定要被加入连通块。
通过上述过程可以观察出结论:SS 只可能是 TmaxT_{\max}xx 子树与 TminT_{\min}yy 子树的交。

于是问题转化成有多少对 (x,y)(x, y),满足:
  • yyTminT_{\min} 上是 xx 祖先;
  • xxTmaxT_{\max} 上是 yy 祖先;
  • TmaxT_{\max}xx 子树与 TminT_{\min}yy 子树的交在 TmaxT_{\max}TminT_{\min} 上都是连通块。
数连通块自然考虑点减边容斥。枚举 xx,给每个 yy 设一个权值 valyval_y(初始若 yyTmaxT_{\max}xx 子树内则 valy=2val_y = 2,否则 valy=+val_y = +\infty)。若一个点同时在两棵子树内,对 valyval_y22 的贡献,对于 TmaxT_{\max}xx 子树内的一条边 (u,v)(u, v),若 u,vu, v 都在 yy 子树内,对 valyval_y1-1 的贡献,对于 TminT_{\min}yy 子树内的一条边 (u,v)(u, v) 同理。valyval_y 一定 2\ge 2,若 valy=2val_y = 2 说明 yy 合法。
那么考虑在 TmaxT_{\max} 上做线段树合并,对 valyval_y 的修改相当于若干个在 TminT_{\min} 上的链加,统计答案相当于查 xxTminT_{\min} 上到根的路径的最小值和最小值个数。
时空复杂度均为 O(nlog2n)O(n \log^2 n),感受一下空间很难卡满,加上垃圾回收后可以通过。
代码CPP
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

const int maxn = 200100;
const int logn = 20;
const int maxm = 16000100;
const int inf = 0x3f3f3f3f;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, type;
int fa[maxn], p[maxn], pa[maxn];
vector<int> G[maxn], G1[maxn], G2[maxn];

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

namespace Sub1 {
	ll ans;
	
	struct node {
		ll x, y;
		node(ll _x = 0, ll _y = 0) : x(_x), y(_y) {}
	} f[maxn];
	
	inline node operator + (const node &a, const node &b) {
		if (a.y < b.y) {
			return a;
		} else if (a.y > b.y) {
			return b;
		} else {
			ll x = (a.x + b.x) % mod;
			if (x == 0) {
				return node(1, a.y + 1);
			} else {
				return node(x, a.y);
			}
		}
	}
	
	inline node operator * (const node &a, const node &b) {
		return node(a.x * b.x % mod, a.y + b.y);
	}
	
	inline node operator / (const node &a, const node &b) {
		return node(a.x * qpow(b.x, mod - 2) % mod, a.y - b.y);
	}
	
	void dfs(int u) {
		f[u] = node(1, 0);
		for (int v : G1[u]) {
			dfs(v);
			int w = p[v];
			for (int x = find(w); x != v; x = find(x)) {
				f[v] = f[v] / (f[x] + node(1, 0)) * f[x];
				fa[x] = pa[x];
			}
			f[u] = f[u] * (f[v] + node(1, 0));
		}
		ans = (ans + (f[u].y ? 0 : f[u].x)) % mod;
	}
	
	void solve() {
		for (int i = 1; i <= n; ++i) {
			fa[i] = i;
		}
		dfs(n);
		printf("%lld\n", ans);
	}
}

int st1[logn][maxn], st2[logn][maxn], dfn1[maxn], dfn2[maxn], tim;

inline int get1(int i, int j) {
	return dfn1[i] < dfn1[j] ? i : j;
}

inline int get2(int i, int j) {
	return dfn2[i] < dfn2[j] ? i : j;
}

inline int qlca1(int x, int y) {
	if (x == y) {
		return x;
	}
	x = dfn1[x];
	y = dfn1[y];
	if (x > y) {
		swap(x, y);
	}
	++x;
	int k = __lg(y - x + 1);
	return get1(st1[k][x], st1[k][y - (1 << k) + 1]);
}

inline int qlca2(int x, int y) {
	if (x == y) {
		return x;
	}
	x = dfn2[x];
	y = dfn2[y];
	if (x > y) {
		swap(x, y);
	}
	++x;
	int k = __lg(y - x + 1);
	return get2(st2[k][x], st2[k][y - (1 << k) + 1]);
}

void dfs(int u, int t) {
	dfn1[u] = ++tim;
	st1[0][tim] = t;
	for (int v : G1[u]) {
		dfs(v, u);
	}
}

int sz[maxn], son[maxn], dep[maxn], top[maxn];

int dfs2(int u, int f, int d) {
	fa[u] = f;
	sz[u] = 1;
	dep[u] = d;
	int mx = -1;
	for (int v : G2[u]) {
		sz[u] += dfs2(v, u, d + 1);
		if (sz[v] > mx) {
			son[u] = v;
			mx = sz[v];
		}
	}
	return sz[u];
}

void dfs3(int u, int tp) {
	top[u] = tp;
	dfn2[u] = ++tim;
	st2[0][tim] = fa[u];
	if (!son[u]) {
		return;
	}
	dfs3(son[u], tp);
	for (int v : G2[u]) {
		if (!dfn2[v]) {
			dfs3(v, v);
		}
	}
}

inline pii operator + (const pii &a, const pii &b) {
	if (a.fst < b.fst) {
		return a;
	} else if (a.fst > b.fst) {
		return b;
	} else {
		return mkp(a.fst, a.scd + b.scd);
	}
}

namespace SGT {
	int ls[maxm], rs[maxm], tag[maxm], nt, stk[maxm], top;
	pii a[maxm];
	
	inline void init() {
		for (int i = 0; i < maxm; ++i) {
			a[i] = pii(inf, 0);
		}
	}
	
	inline void pushup(int x) {
		a[x] = a[ls[x]] + a[rs[x]];
		a[x].fst += tag[x];
	}
	
	inline void pushtag(int x, int y) {
		if (!x) {
			return;
		}
		a[x].fst += y;
		tag[x] += y;
	}
	
	inline void delnode(int x) {
		a[x] = pii(inf, 0);
		ls[x] = rs[x] = tag[x] = 0;
		if (top + 1 < maxm) {
			stk[++top] = x;
		}
	}
	
	inline int newnode() {
		assert(nt + 1 < maxm);
		return top ? stk[top--] : (++nt);
	}
	
	void update(int &rt, int l, int r, int ql, int qr, int x) {
		if (!rt) {
			rt = newnode();
		}
		if (ql <= l && r <= qr) {
			pushtag(rt, x);
			return;
		}
		int mid = (l + r) >> 1;
		if (ql <= mid) {
			update(ls[rt], l, mid, ql, qr, x);
		}
		if (qr > mid) {
			update(rs[rt], mid + 1, r, ql, qr, x);
		}
		pushup(rt);
	}
	
	void modify(int &rt, int l, int r, int x) {
		if (!rt) {
			rt = newnode();
		}
		if (l == r) {
			a[rt].fst -= inf;
			a[rt].scd = 1;
			return;
		}
		int mid = (l + r) >> 1;
		(x <= mid) ? modify(ls[rt], l, mid, x) : modify(rs[rt], mid + 1, r, x);
		pushup(rt);
	}
	
	int merge(int u, int v, int l, int r) {
		if (!u || !v) {
			return u | v;
		}
		tag[u] += tag[v];
		if (l == r) {
			bool fl = (a[u].fst > 1e9) && (a[v].fst > 1e9);
			if (a[u].fst >= inf) {
				a[u].fst -= inf;
			}
			if (a[v].fst >= inf) {
				a[v].fst -= inf;
			}
			a[u].fst = a[u].fst + a[v].fst + (fl ? inf : 0);
			a[u].scd |= a[v].scd;
			delnode(v);
			return u;
		}
		int mid = (l + r) >> 1;
		ls[u] = merge(ls[u], ls[v], l, mid);
		rs[u] = merge(rs[u], rs[v], mid + 1, r);
		pushup(u);
		delnode(v);
		return u;
	}
	
	pii query(int rt, int l, int r, int ql, int qr) {
		if (!rt) {
			return pii(inf, 0);
		}
		if (ql <= l && r <= qr) {
			return a[rt];
		}
		int mid = (l + r) >> 1;
		pii res(inf, 0);
		if (ql <= mid) {
			res = res + query(ls[rt], l, mid, ql, qr);
		}
		if (qr > mid) {
			res = res + query(rs[rt], mid + 1, r, ql, qr);
		}
		res.fst += tag[rt];
		return res;
	}
}

vector<int> vc[maxn];
ll ans;
int rt[maxn];

inline void update(int &rt, int x, int y) {
	while (x) {
		SGT::update(rt, 1, n, dfn2[top[x]], dfn2[x], y);
		x = fa[top[x]];
	}
}

inline pii query(int rt, int x) {
	pii res(inf, 0);
	while (x) {
		res = res + SGT::query(rt, 1, n, dfn2[top[x]], dfn2[x]);
		x = fa[top[x]];
	}
	return res;
}

void dfs4(int u) {
	for (int v : G1[u]) {
		dfs4(v);
		rt[u] = SGT::merge(rt[u], rt[v], 1, n);
	}
	SGT::modify(rt[u], 1, n, dfn2[u]);
	update(rt[u], u, 2);
	for (int v : G1[u]) {
		int w = qlca2(u, v);
		update(rt[u], w, -1);
	}
	for (int v : vc[u]) {
		update(rt[u], v, -1);
	}
	pii p = query(rt[u], u);
	if (p.fst == 2) {
		ans += p.scd;
	}
}

void solve() {
	scanf("%lld%lld", &type, &n);
	for (int i = 1; i <= n; ++i) {
		fa[i] = i;
	}
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	for (int i = 1; i <= n; ++i) {
		for (int j : G[i]) {
			if (j < i && find(i) != find(j)) {
				int k = find(j);
				p[k] = j;
				fa[k] = i;
				pa[k] = i;
				G1[i].pb(k);
			}
		}
	}
	for (int i = 1; i <= n; ++i) {
		fa[i] = i;
	}
	for (int i = n; i; --i) {
		for (int j : G[i]) {
			if (j > i && find(j) != find(i)) {
				int k = find(j);
				fa[k] = i;
				G2[i].pb(k);
			}
		}
	}
	if (type == 1) {
		Sub1::solve();
		return;
	}
	dfs(n, 0);
	tim = 0;
	dfs2(1, 0, 1);
	dfs3(1, 1);
	for (int j = 1; (1 << j) <= n; ++j) {
		for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
			st1[j][i] = get1(st1[j - 1][i], st1[j - 1][i + (1 << (j - 1))]);
			st2[j][i] = get2(st2[j - 1][i], st2[j - 1][i + (1 << (j - 1))]);
		}
	}
	for (int i = 1; i <= n; ++i) {
		for (int j : G2[i]) {
			vc[qlca1(i, j)].pb(i);
		}
	}
	SGT::init();
	dfs4(n);
	printf("%lld\n", ans % mod);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...