专栏文章

ABC401F题解

AT_abc401_f题解参与者 3已保存评论 2

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
2 条
当前快照
1 份
快照标识符
@miplz0iz
此快照首次捕获于
2025/12/03 14:10
3 个月前
此快照最后确认于
2025/12/03 14:10
3 个月前
查看原文
首先找到第一棵树的直径 d1d_1 和第二棵树的直径 d2d_2
不难发现,连接第一棵树的 ii 结点和第二棵树的 jj 结点后,f(i,j)f(i, j) 就等于 d1d_1d2d_2 和一条经过了边 (i,j)(i,j) 的路径的长度。这条路径一定可以被分成三部分:
  1. ii 到第一棵树中离 ii 最远一点的路径。
  2. jj 到第二棵树中离 jj 最远一点的路径。
  3. (i,j)(i, j)
则将这三部分的长度加起来就得到了 f(i,j)f(i, j)
我们记 ii 到第一棵树中离 ii 最远一点的路径长度为 aia_i,记 jj 到第二棵树中离 jj 最远一点的路径长度为 bjb_j,则可写出 f(i,j)=max(max(d1,d2),ai+bj+1)f(i, j) = \max(\max(d_1, d_2), a_i + b_j + 1)
计算直径的长度非常简单,先随便以一个点为起点,做一遍 bfs,找到离这个起点最远的点,这个点一定是直径的一个端点。然后再以这个端点为起点,做一遍 bfs,离这个端点最远的点一定是直径的另一个端点。
然后我们考虑如何计算出 aabb。在计算直径时,我们可以发现,端点中的一个一定是离这个点的最远的一个点,我们可以在 dfs 过程中记录这个值。注意两个端点都要当一次起点。
然后我们考虑计算所有的 f(i,j)f(i, j) 的值。发现如果 ai+bj+1<max(d1,d2)a_i + b_j + 1 < \max(d_1, d_2),则 f(i,j)=max(d1,d2)f(i, j) = \max(d_1, d_2)
可以将 aa 按从小到大排序,将 bb 从大到小排序。从 11n2n_2 枚举 jj,然后用一个变量记录第一个 ii 使得 ai+bj+1max(d1,d2)a_i + b_j + 1 \ge \max(d_1, d_2) 的值,则此时对于所有 kk(1k<i)(1 \le k < i),有 ak+bj+1<max(d1,d2)a_k + b_j + 1 < \max(d_1, d_2),则 f(i,j)=max(d1,d2)f(i, j) = \max(d_1, d_2)。而所有 kk(ikn1)(i \le k \le n_1),有 ak+bj+1max(d1,d2)a_k + b_j + 1 \ge \max(d_1, d_2),则 f(i,j)=ak+bj+1f(i, j) = a_k + b_j + 1。那么 bjb_j 对于所有 aa 的贡献就是 (i1)×max(d1,d2)+(n1i+1)×(bi)+i+k=in1ak(i-1) \times \max(d_1, d_2) + (n_1 - i + 1) \times (b_i) + i + \displaystyle\sum_{k=i}^{n_1} a_k
显然式子的最后一项可以用后缀和优化掉。并且 ii 对于 bjb_j 的减小是递增的,则可以用一个类似于滑动窗口的东西求出来。计算的时间复杂度是 O(n)O(n) 的。

代码

CPP
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 2e5+5;

int n1, n2;
int d1, d2, dmax;
int dis1[MAXN], dis2[MAXN];
int d11, d12, d21, d22;
int a[MAXN], b[MAXN];
vector<int> e1[MAXN], e2[MAXN];
ll ans;
ll sa[MAXN];
//e:当前 dfs 的树。
//d:起点到 u 的距离。
//maxd:上文中的 a 或者 b。
//c:求出的端点。
void dfs(int u, int fa, vector<int> *e, int *d, int *maxd, int *c)
{
	d[u] = d[fa]+1;
	maxd[u] = max(maxd[u], d[u]);
	if(d[u] > d[*c]) *c = u;
	for(auto v : e[u])
	{
		if(v == fa) continue;
		dfs(v, u, e, d, maxd, c);
	}
}

int main()
{
	scanf("%d", &n1);
	for(int i = 1;i < n1;i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		e1[u].push_back(v);
		e1[v].push_back(u);
	}
	dis1[0] = -1;
	dfs(1, 0, e1, dis1, a, &d11);
	dfs(d11, 0, e1, dis1, a, &d12);
	d1 = dis1[d12];
	dfs(d12, 0, e1, dis1, a, &d11);//记住第二个端点也要 dfs 一便。
	
	scanf("%d", &n2);
	for(int i = 1;i < n2;i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		e2[u].push_back(v);
		e2[v].push_back(u);
	}
	dis2[0] = -1;
	dfs(1, 0, e2, dis2, b, &d21);
	dfs(d21, 0, e2, dis2, b, &d22);
	d2 = dis2[d22];
	dfs(d22, 0, e2, dis2, b, &d21);

	dmax = max(d1, d2);

	sort(a + 1, a + n1 + 1); a[n1+1] = 0x3f3f3f3f; a[0] = INT_MIN;//a[n1+1] 赋值是为了防止下面 cur 不断的加。
	sort(b + 1, b + n2 + 1, greater<int>() );

	for(int i = n1;i >= 1;i--) sa[i] = sa[i+1] + a[i];
    //求出 a 的后缀和 sa.
	for(int i = 1, cur = 0;i <= n2;i++)
	{
		while(a[cur] + b[i] + 1 <= dmax) cur++;
        //类似于滑动窗口的东西
		ans += (ll)(cur - 1LL) * (ll)dmax + (ll)(n1 - cur + 1LL) * (ll)(b[i] + 1LL) + sa[cur];
	}
	printf("%lld\n", ans);

	return 0;
}

评论

2 条评论,欢迎与作者交流。

正在加载评论...