专栏文章

题解:AT_abc395_g [ABC395G] Minimum Steiner Tree 2

AT_abc395_g题解参与者 4已保存评论 5

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
5 条
当前快照
1 份
快照标识符
@miq10ljp
此快照首次捕获于
2025/12/03 21:12
3 个月前
此快照最后确认于
2025/12/03 21:12
3 个月前
查看原文

前言

之前没怎么正经学习这个算法,这次考到了就无语了。所以经过我的痛定思痛啊,参考了很多文章,打算写一篇详细的题解来讲解最小斯坦纳树以及本题相关做法,如果你已经掌握了模板题的做法请略过前面的介绍部分。

什么是最小斯坦纳树

最小斯坦纳树问题与最小生成树问题其实非常相似,把问题形式化如下:
给你一个无向连通图 G=(V,E)G=(V,E),以及一个包含 kk 个节点的点集 SS,包含点集 SS 的联通子图就是斯坦纳树。但一般在算法竞赛中我们都只在意包含点集 SS 的最小联通子图,也就是最小斯坦纳树
先预告一下,这个问题是 NP 困难问题,所以没有正经的多项式解法,但是存在多项式的求近似解做法。

求解问题

这里先给出一张无向图,加粗的点为点集 S={2,4,5,7}S=\{2,4,5,7\},要求求出点集 SS 的最小斯坦纳树。
其中最小斯坦纳树的值为 1111,整棵树长这样:
你会发现这棵树不仅包含了点集 SS,还包含了图上另外一个点 66。所以我们发现,最小斯坦纳树上不一定只有 SS 中的点。像点 66 这样添加后能使答案最小化的点我们称之为斯坦纳点
接下来我们先介绍两个简单的性质:
  1. 最小斯坦纳树只包含点集 SS 与斯坦纳点,这个从定义就能说明,所以不多赘述。
  2. 答案子图一定是一颗树
【证明】 如果答案子图 GG' 包含至少一个环,那么把环上边权最大的那一条边删除也能使联通,所以 GG' 中不包括环,证毕。这也是它的名字由来。
所以实际上我们考虑自下而上确定这颗树,具体的,我们设 fsf_{s} 表示联通状态包括 ss 的最小代价。
你可能想当然的以为:“转移很显然啊,枚举 SS 中的点集与 ss 中的点计算最短路转移即可。”
这样的转移是对的吗?显然是错的。我们没有考虑斯坦点带来的收益,导致没有枚举到正确状态使答案变大,这里还是图例解释。
刚刚这一张图如果按照刚刚的 dp 设计那么答案应该是 1111。因为它会钦定 14,47,151-4,4-7,1-5 这三条路径的答案。而最小斯坦纳树只需要考虑把点 66 加入集合即可,取到最小边权和为 4+1+2+3=104+1+2+3=10
所以我们需要改进刚刚的状态设计,不妨直接记一个点,表示树根,设 fi,sf_{i,s} 表示以 ii 为根节点且包含 ss 的树的最小边权和,其中如果 ii 不属于 SS 那么它就是一个斯坦纳点。
考虑怎么转移,因为我们不知道哪个点作为斯坦点能使答案变优,所以我们只能枚举一个 jj,考虑当前状态是从 jjii 转移而来。
fi,s=min(fi,s,fj,s+dis(i,j))f_{i,s}=\min(f_{i,s},f_{j,s}+dis(i,j))
这样的转移并不完备,因为它构造的树没有枚举树的形态,也就是说这样转移只会类似一条链,因为它把 jj 换根移动到了 ii,而没有考虑儿子更复杂的情况。
为了充分的枚举到所有状态集合。我们考虑枚举 ss 的子集即可。为什么这样是对的?可以这样理解:实际上 ss 的子集 TTsTs-T 可以看作是两个以 ii 为根的子图,只不过我们把它们的边取了一个并集而已,也就枚举到了所有集合。
fi,s=min(fi,s,fi,T+fj,sT)f_{i,s}=\min(f_{i,s},f_{i,T}+f_{j,s-T})
如果你真的看懂了斯坦纳树是怎么构造的,那么一定挥发两种转移是有先后顺序的,因为我们在换根的时候一定要保证当前状态最优,所以我们需要先枚举子集。也就是说枚举子集之后才去换根转移。并且实际上以最小斯坦纳树上的哪个点为根都是最优解。

实现以及细节

这里我们需要枚举点集 SS 中的每个点 xx,初始化以 xx 为根的子树以及对应的联通状态为 00
然后正如刚才所说,枚举状态集合,先转移子集再转移换根。
时间复杂度 O(n×3k+mlogm×2k)O(n \times 3^k + m \log m \times 2^k)kk 是点集大小。
以刚才讲解为例,这里提供 P6192 【模板】最小斯坦纳树 的代码。
CPP
#include<bits/stdc++.h>
using namespace std;
const int N=1e2+5;
int u,v,w,n,m,k,a[N],f[N][(1<<11)+5],dis[N];
struct Point{
	int v,val;
};
struct cmp{
	bool operator()(Point x,Point y){
		return x.val>y.val;
	}
};
priority_queue <Point,vector<Point>,cmp> q;
vector <Point> e[N];
bool vis[N];
void dij(int state){
	for(int i = 0;i < n;i++)
		if(f[i][state]!=0x3f3f3f3f)
			q.push((Point){i,f[i][state]});
	memset(vis,false,sizeof(vis));
	while(!q.empty()){
		int head=q.top().v;
		q.pop();
		if(vis[head])continue;
		vis[head]=1;
		for(int i = 0;i < e[head].size();i++){
			int v=e[head][i].v,val=e[head][i].val;
			if(f[v][state]>f[head][state]+val){
				f[v][state]=f[head][state]+val;
				q.push((Point){v,f[v][state]});
			}
		}
	}
}
int main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n >> m >> k;
	for(int i = 1;i <= m;i++){
		cin >> u >> v >> w;u--,v--;
		e[u].push_back((Point){v,w});
		e[v].push_back((Point){u,w});
	}
	memset(f,0x3f,sizeof(f));
	for(int i = 0;i < k;i++){
		cin >> a[i];a[i]--;
		f[a[i]][1<<i]=0;
	}
	for(int i = 1;i < (1<<k);i++){
		for(int j = i&(i-1);j;j=(j-1)&i){
			if(j<(i^j))break;
			for(int now = 0;now < n;now++)f[now][i]=min(f[now][i],f[now][j]+f[now][i-j]);
		}
		dij(i);
	}
	int ans=2e9;
	for(int i = 0;i < n;i++)ans=min(ans,f[i][(1<<k)-1]);
	cout << ans;
	return 0;
}

这是我去年七月学的时候照着题解敲的,实际上有很多地方可以优化,比如可以预处理点对之间的最短路,方便换根时的转移。

本题做法

现在默认你至少会了模板题,现在这个问题就很简单了。
这个题实际上就是固定点集 SS 然后每次询问等价于往里面多丢两个数。
因为每次重新做复杂度爆炸,所以我们不妨预处理答案,直接枚举然后加入一个数 ii,暴力把 ii 计入状态。
然后跑 nn 次即可,时间复杂度 O(3k×n2+2k×n2)O(3^k \times n^2+2^k \times n^2)。这里预处理了最短路。
CPP
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=85,M=9,inf=1e18;
int n,m,K,a[N][N],x,y,f[N][N][1<<M];
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
	cin >> n >> K;
	for(int i = 1;i <= n;i++)for(int j = 1;j <= n;j++)cin >> a[i][j];
	for(int k = 1;k <= n;k++)
		for(int i = 1;i <= n;i++)
			for(int j = 1;j <= n;j++)
				a[i][j]=min(a[i][j],a[i][k]+a[k][j]);
	memset(f,0x3f,sizeof(f));
	for(int i = 1;i <= n;i++){
		for(int j = 1;j <= K;j++)f[i][j][1<<j-1]=0;
		f[i][i][1<<K]=0;
	}
	for(int state = 1;state < (1<<K+1);state++){
		for(int i = 1;i <= n;i++){
			for(int j = 1;j <= n;j++){
				for(int k = state&(state-1);k;k=(k-1)&state){
					if(k<(state^k))break;
					f[i][j][state]=min(f[i][j][state],f[i][j][k]+f[i][j][k^state]);
				}
			}
			for(int j = 1;j <= n;j++)for(int k = 1;k <= n;k++)f[i][j][state]=min(f[i][j][state],f[i][k][state]+a[k][j]);
		}
	}cin >> m;//
	while(m--){
		cin >> x >> y;
		cout << f[x][y][(1<<K+1)-1] << "\n";
	}
	return 0;
}

评论

5 条评论,欢迎与作者交流。

正在加载评论...