专栏文章

P3714 [BJOI2017] 树的难题

P3714题解参与者 2已保存评论 1

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
1 条
当前快照
1 份
快照标识符
@miqdfhk0
此快照首次捕获于
2025/12/04 02:59
3 个月前
此快照最后确认于
2025/12/04 02:59
3 个月前
查看原文
本题存在单 log\log 做法。
我们考虑点分治,选择重心 xx 作为根。
假设每条边有一个权值,权值定义为每条边的权值和,我们的做法是依次遍历所有儿子的子树,计算出子树内 g(d)g(d) 表示深度为 dd 的最大权值和是多少,同时维护前面的 f(d)f(d) 表示前面的子树中深度为 dd 的最大是多少。
计算答案可以从大到小扫一遍 gg,然后将查询 [Li,Ri][L-i,R-i]ff 的最大值,可以用单调队列维护。
注意到每次统计答案的时间复杂度和 g,fg,f 的长度相关,我们可以将所有儿子按照子树深度从小到大排序,然后从左往右扫描,设 did_i 表示排序后第 ii 子树的深度,则我们的时间复杂度是 O((di+di+1))O(\sum(d_i+d_{i+1})),而由于 didi+1d_i \le d_{i+1},所以实际上就是 O(di)O(\sum d_i),而 di\sum d_i 的上界是当前分治的节点个数,所以这样扫描使得每一轮的时间复杂度是 O(n)O(n),点分治就是 O(nlogn)O(n \log n)
注意到还需要将子树按照深度排序,这样看似是 O(nlog2n)O(n \log^2 n) 的,实际上这是 O(nlogn)O(n \log n)
我们每次需要排序的长度是子树的个数,而每个子树都会继续分治。
而我们知道分治的总次数是 O(n)O(n) 的,所以实际上这些排序的长度和也是 O(n)O(n) 的,所以总时间复杂度是 O(nlogn)O(n \log n) 的。
好了,现在就把点分治部分讲完了,回到原问题。
我们不用更改上面的框架,我们思考如何处理这个颜色段。
首先,我们还是可以算出从 xx 到每个点的路径的权值。
但现在问题在于两个路径的合并有两种:从 xx 出发的两条路径,如果第一条边颜色相同,还要减去这个颜色的权值一次。
我们考虑现在有一个从 xx 出发的路径 AA,他希望前面找一条路径合并使得权值尽量大,那么只能是以下两种:
  • 权值最大的异色路径
  • 权值最大的同色路径
这是后可以不用线段树等数据结构,这种问题有一个很简单的方法:记录最大值和次大值。
我们考虑记录前面所有路径的最大值和次大值,要求最大值和次大值的颜色不能相同。
那么如果 AA 和最大值不同色,那么它肯定选择和最大值合并,这是显然的。
如果 AA 和最大值同色,那么它一定和次大值不同色,那么最大值就是权值最大的同色路径,次大值就是权值最大的异色路径,我们比较以下而这即可。
所以我们用 ff 数组来记录最大值和次大值即可,剩下部分没有区别,只是记录的信息变了一下。
这样就能做到严格 O(nlogn)O(n \log n) 了。这个记录最大值和次大值的技巧实际上和树形 dpdp 求直径有点像。
代码实现和重建计划,Freezing with Style这两个问题几乎一样。
CPP
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2e5 + 5;
const int inf = -2e9;
typedef pair<int, int> pii;
typedef pair<pii, pii> pp;
#define debug(x) cout << #x << "=" << x << endl
#define mp(x, y) make_pair(x, y)
#define fi first
#define se second
#define all(x) x.begin(), x.end()

int n, m, L, R;
struct Edge {
	int to, val;
	Edge (int _to = 0, int _val = 0) :
		to(_to), val(_val) {}
};
vector<Edge> e[N];
int c[N] = {0};
bool vis[N] = {false};
int sz[N] = {0};
int getsz(int x, int pr) {
	sz[x] = 1;
	for (auto i: e[x])
		if (i.to != pr && !vis[i.to])
			sz[x] += getsz(i.to, x);
	return sz[x];
}
int totsz, mxsz, rt;
void getrt(int x, int pr) {
	int msz = totsz - sz[x];
	for (auto i: e[x])
		if (i.to != pr && !vis[i.to]) {
			getrt(i.to, x);
			msz = max(msz, sz[i.to]);
		}
	if (mxsz > msz)
		mxsz = msz, rt = x;
}
int getd(int x, int pr) {
	int mxd = 0;
	for (auto i: e[x])
		if (i.to != pr && !vis[i.to]) 
			mxd = max(mxd, getd(i.to, x) + 1);
	return mxd;
}

pp f[N];
int g[N] = {0};
void add(pp &x, pii y) {
	if (y.fi > x.fi.fi) {
		if (x.fi.se != y.se)
			x.se = x.fi, x.fi = y;
		else
			x.fi = y;
	}
	else if (x.se.fi < y.fi && x.fi.se != y.se)	
		x.se = y;
}

void upd(int x, int pr, int d, int tv, int cl, int cfr) {
	g[d] = max(g[d], tv);
	for (auto i: e[x])
		if (i.to != pr && !vis[i.to])
			upd(i.to, x, d + 1, tv + (cfr != i.val) * c[i.val], cl, i.val);
}

int ans = 0;

int q[N] = {0};
int cal(pp x, int cl) {
	if (x.fi.se != cl)
		return x.fi.fi;
	return max(x.fi.fi - c[cl], x.se.fi);
}
void calc(int _n, int _m, int cl) {
/*	debug(_n);
	for (int i = 0; i <= _n; i++)
		printf("[(%d, %d), (%d, %d)], ", f[i].fi.fi, f[i].fi.se, f[i].se.fi, f[i].se.se);
	printf("\n");
	debug(_m);
	for (int i = 1; i <= _m; i++)
		printf("%d, ", g[i]);
	printf("\n");*/
	int l = 0, r = 0;
	for (int i = _m, j = 0; i >= 1; i--) {
		while (j <= _n && j <= R - i) {
			while (l < r && cal(f[q[r - 1]], cl) < cal(f[j], cl))
				r--;
			q[r++] = j++;
		}
		while (l < r && q[l] < L - i)
			l++;
		if (l < r)
			ans = max(ans, g[i] + cal(f[q[l]], cl));
	}
}

void slv(int x) {
	getsz(x, 0);
	totsz = sz[x], mxsz = 2e9, rt = 0;
	getrt(x, 0);
	x = rt;
	vector<pair<int, pii> > res;
	for (auto i: e[x])
		if (!vis[i.to])
			res.push_back(mp(getd(i.to, x) + 1, mp(i.to, i.val)));
	sort(all(res));
	int len = 0;
	f[0] = mp(mp(0, 0), mp(0, 0));
	
//	debug(x);
	
	for (auto j: res) {
		int mxd = j.fi, u = j.se.fi, w = j.se.se;
		for (int i = 1; i <= mxd; i++)
			g[i] = inf;
	//	debug(u);
		upd(u, x, 1, c[w], w, w);
		calc(len, mxd, w);
		for (int i = 1; i <= mxd; i++) {
			if (i <= len)
				add(f[i], mp(g[i], w));
			else
				f[i] = mp(mp(g[i], w), mp(inf, 0));
		}
		len = mxd;
	}
	
	vis[x] = true;
	for (auto i: res)
		slv(i.se.fi);
}

int main() {
	scanf("%d%d%d%d", &n, &m, &L, &R);
	for (int i = 1; i <= m; i++)
		scanf("%d", &c[i]);
	for (int i = 1, u, v, w; i < n; i++) {
		scanf("%d%d%d", &u, &v, &w);
		e[u].push_back(Edge(v, w));
		e[v].push_back(Edge(u, w));
	}
	ans = -2e9;
	slv(1);
	printf("%d\n", ans);
	return 0;
} 

评论

1 条评论,欢迎与作者交流。

正在加载评论...