专栏文章

【洛谷日报#185】浅谈虚树

算法·理论参与者 44已保存评论 51

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
51 条
当前快照
1 份
快照标识符
@mhz5ryx5
此快照首次捕获于
2025/11/15 01:55
4 个月前
此快照最后确认于
2025/11/29 05:25
3 个月前
查看原文

虚树的概念

虚树,是对于一棵给定的节点数为 nn 的树 TT,构造一棵新的树 TT' 使得总结点数最小包含指定的某几个节点和他们的LCA

虚树解决的问题

利用虚树,可以对于指定多组点集 SS 的询问进行每组 O(S(logn+logS)+f(S))O(|S|(\log n+\log |S|)+f(|S|)) 的回答,其中 f(x)f(x) 指的是对于一棵 xx 个点的树单组询问这个问题的时间复杂度。可以看到,这个复杂度基本上(除了那个 logn\log n 以外)与 nn 无关了。这样,对于多组询问的回答就可以省去每次询问都遍历一整棵树的 O(n)O(n) 复杂度了。

前置芝士

dfs\text{dfs} 序的性质以及 lca\text{lca}

例题

这里推荐 [SDOI2011]消耗战 作为例题(虽然许多博客都以它作为板子),因为消耗战有一些不用考虑的情况,不宜作为模板。
题意:给定一棵树,多组询问,每组询问给定 kk 个点,你可以删掉不同于kk 个点的 mm 个点,使得这 kk 个点两两不连通,要求最小化 mm,如果不可能输出 1-1。询问之间独立。
数据范围 n105,k105n\leq10^5,\sum k\leq10^5
一看到这种k105\sum k\leq10^5 的题很可能就是虚树了。

构造方法

先预处理整棵树 lca 和 dfs 序,接下来是对于每组询问的构造。
虚树的构建是一个增量算法,要首先将指定的这 kk 个点按照 dfs 序排序,然后按照顺序一一加入。可以强行先加入根节点以方便后面的处理。
虚树构建时会开一个栈 stack[]\text{stack}[],这个栈本质上和 dfs 递归时系统自动开的一个栈原理是一样的,也就是说这个栈保存了从根出发的一条不连续的路径(只储存原路径上按照之前加入的询问点要加入虚树的那些点,按照深度从小到大储存)。当加入 a[k]a[k] 后,满足 stack[1]=root,stack[top]=a[k],stack[x]\text{stack}[1]=\text{root},\text{stack}[\text{top}]=a[k],\text{stack}[x]stack[x1]\text{stack}[x-1] 后代。虚树上 uvu\rightarrow v 边的连接时间都是在 vv 被弹出栈时。
考虑如何加入一个新的节点 xx,设 z=lca(x,stack[top])z=lca(x,\text{stack}[\text{top}]),分两类讨论
  1. z=stack[top]z=\text{stack}[\text{top}],也就是 xxstack[top]\text{stack}[\text{top}] 的子树内节点。这时直接把 xx 入栈就好了。
  2. zstack[top]z\neq \text{stack}[\text{top}]
这种情况中,xx 一定不是 stack[top]\text{stack}[\text{top}] 子树内节点。如图
这是原树上的情况。这时,“......”指代的那些节点以及 stack[top-1],stack[top]\text{stack[\text{top}-1]},\text{stack[\text{top}]} 都应弹出栈外(相当于回溯了,开始访问 zz另一棵子树)(注意这里 stack[top1]stack[top]\text{stack}[\text{top}-1]\rightarrow\text{stack}[\text{top}]zxz\rightarrow x 在原树上不一定直接相连,这里未画出中间结点)
那我们不断弹出 stack[top]\text{stack}[\text{top}] ,直到 dep[stack[top1]]<dep[z]\text{dep}[\text{stack}[\text{top}-1]]<\text{dep}[z] ,这时“......”表示的点全部弹完。弹 stack[top]\text{stack}[\text{top}] 时都要在虚树上连一条 stack[top1]stack[top]\text{stack}[\text{top}-1]\rightarrow\text{stack}[\text{top}] 的边。
注意弹完时可能 stack[top]z\text{stack}[\text{top}]\neq z,我们需要把 zz 补充进虚树中来维护这个,直接加进栈即可。
插入完所有点之后要完全回溯,也就是把栈内节点都弹出,也要连 stack[top1]stack[top]\text{stack}[\text{top}-1]\rightarrow\text{stack}[\text{top}] 边。
代码如下,实现起来有一些差别
CPP
inline void ins(int x)
{
	if (tp==0)
	{
		st[tp=1]=x;
		return;
	}
	ance=lca(st[tp],x);
	while ((tp>1)&&(dep[ance]<dep[st[tp-1]]))
	{
		add(st[tp-1],st[tp]);
		--tp;
	}
	if (dep[ance]<dep[st[tp]]) add(ance,st[tp--]);
	if ((!tp)||(st[tp]!=ance)) st[++tp]=ance;
	st[++tp]=x;
}

正确性

对于任意指定两点 a,ba,blca\text{lca},在排序后的数组中都存在连续的两点 u,v(dfn[u]dfn[v])u,v(dfn[u]\leq dfn[v]) 分别属于 lca\text{lca} 的两棵子树,此时这 vv 加入时按照上面的操作必定会把 lca\text{lca} 加入栈,所以应当加入的点都加入了。对于非 lca\text{lca} 点,按照上面操作是不会计算出这个点的,更不会被加入虚树。所以,加入的点恰为所需要的点。

复杂度

每个指定点进栈出栈一次,这部分 O(k)O(\sum k)。排序和求 lca 为 O((klogk+klogn))O(\sum (k\log k+k\log n))

构建完成后的使用

以例题为例介绍虚树的使用。首先特判掉无解情况(即一个点和他的父亲都被指定)。构造好虚树后,我们给真正被指定的点的 siz 设置成 11 (因为有一些加入的点实际上只是 lca\text{lca},要区分开来),然后 dfs\text{dfs} 这棵虚树。
以下所说的节点 uu 有 siz 均表示有一个指定节点可以到达 uu (在执行了下面的删点之后)
对于一个被指定的点 uu ,如果存在孩子 vv 有 siz,那么意味着 uvu\rightarrow v 不删点就出现了连通,所以 uvu\rightarrow v 上随便去掉一个点就可以了(这里要 ++ans) ,如果孩子没有 siz 那就不用处理了。
对于一个未被指定的点 uu,统计有多少个孩子 vv 有 siz 。如果只有一个,把 uu 设置成有 siz 的就好了(相当于看上面的情况决定是否处理)。如果超过一个,那把 uu 删掉就好了(这里也要 ++ans )。
事实上难点完全在于建虚树,后面部分非常容易想到。
算法总复杂度 O(n+klog2k)O(n+\sum k\log _2k),如果用非 O(1)O(1) 的 lca 要多一个 O(lognk)O(\log n\sum k),如果用倍增/ST 表 lca\text{lca} 要多一个 nlog2nn\log _2n
注意在每组询问最后一遍 dfs\text{dfs} 虚树时,要把边清空,具体只需要修改头指针(对于 vector 直接 erase)就可以了。如果对每个询问暴力 memset0 会导致复杂度退化为 O(nq)O(nq)
完整代码如下
CPP
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int N=1e5+2,M=2e5+2;
int a[N],lj[M],nxt[M],fir[N],dfn[N],top[N],hc[N],siz[N],f[N],dep[N],st[N];
int n,m,q,i,x,y,c,bs,tp,ance,ans;
inline void read(int &x)
{
	c=getchar();
	while ((c<48)||(c>57)) c=getchar();
	x=c^48;c=getchar();
	while ((c>=48)&&(c<=57))
	{
		x=x*10+(c^48);
		c=getchar();
	}
}
inline void add(int x,int y)
{
	lj[++bs]=y;
	nxt[bs]=fir[x];
	fir[x]=bs;
}
void dfs1(int x)
{
	siz[x]=1;
	int i;
	for (i=fir[x];i;i=nxt[i]) if (lj[i]!=f[x])
	{
		dep[lj[i]]=dep[f[lj[i]]=x]+1;
		dfs1(lj[i]);
		siz[x]+=siz[lj[i]];
		if (siz[hc[x]]<siz[lj[i]]) hc[x]=lj[i];
	}
}
void dfs2(int x)
{
	dfn[x]=++bs;
	if (hc[x])
	{
		int i;
		top[hc[x]]=top[x];
		dfs2(hc[x]);
		for (i=fir[x];i;i=nxt[i]) if ((lj[i]!=f[x])&&(lj[i]!=hc[x])) dfs2(top[lj[i]]=lj[i]);
	}
}
inline int lca(int x,int y)
{
	while (top[x]!=top[y]) if (dep[top[x]]>dep[top[y]]) x=f[top[x]]; else y=f[top[y]];
	if (dep[x]<dep[y]) return x; return y;
}
void qs(int l,int r)//按照dfs序排序
{
	int i=l,j=r,m=dfn[a[l+r>>1]];
	while (i<=j)
	{
		while (dfn[a[i]]<m) ++i;
		while (dfn[a[j]]>m) --j;
		if (i<=j) swap(a[i++],a[j--]);
	}
	if (i<r) qs(i,r);
	if (l<j) qs(l,j);
}
inline void ins(int x)
{
	if (tp==0)
	{
		st[tp=1]=x;
		return;
	}
	ance=lca(st[tp],x);//相当于z
	while ((tp>1)&&(dep[ance]<dep[st[tp-1]]))
	{
		add(st[tp-1],st[tp]);
		--tp;
	}
	if (dep[ance]<dep[st[tp]]) add(ance,st[tp--]);
	if ((!tp)||(st[tp]!=ance)) st[++tp]=ance;
	st[++tp]=x;
}//增量构建
void dfs3(int x)
{
	int i;
	if (siz[x]) for (i=fir[x];i;i=nxt[i])
	{
		dfs3(lj[i]);
		if (siz[lj[i]])
		{
			siz[lj[i]]=0;
			++ans;
		}
	}
	else
	{
		for (i=fir[x];i;i=nxt[i])
		{
			dfs3(lj[i]);
			siz[x]+=siz[lj[i]];
			siz[lj[i]]=0;
		}
		if (siz[x]>1)
		{
			++ans;siz[x]=0;
		}
	}
	fir[x]=0;//这里清空
}//对每组询问的解决
int main()
{
	read(n);
	for (i=1;i<n;i++)
	{
		read(x);read(y);
		add(x,y);add(y,x);
	}
	bs=0;
	dfs1(dep[1]=1);
	dfs2(top[1]=1);
	memset(fir+1,0,n<<2);
	memset(siz+1,0,n<<2);
	read(q);
	bs=0;
	while (q--)
	{
		x=1;
		read(m);
		for (i=1;i<=m;i++)
		{
			read(a[i]);
			siz[a[i]]=1;
		}
		for (i=1;i<=m;i++) if (siz[f[a[i]]])
		{
			puts("-1");
			x=0;
			break;
		}//特判无解
		if (!x)
		{
			while (m) siz[a[m--]]=0;//清空打过的标记
			continue;
		}
		ans=0;
		qs(1,m);
		if (a[1]!=1) st[tp=1]=1;//先行添加根节点
		for (i=1;i<=m;i++) ins(a[i]);
		if (tp) while (--tp) add(st[tp],st[tp+1]);//回溯
		dfs3(1);
		siz[1]=bs=0;
		printf("%d\n",ans);
	}
}

其他例题

很裸,注意分类讨论即可。注意 dfs 要同时预处理最值和次值,类似树形 dp 处理树的直径。
CPP
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int N=1e6+2,M=2e6+2,inf=1e9;
typedef long long ll;
ll ans1;
int lj[M],nxt[M],fir[N],siz[N],top[N],dfn[N],hc[N],f[N],st[N],a[N],dep[N],dp[N][5],lc[N],dc[N];
int n,m,i,j,x,y,c,bs,tp,q,ance,len,ans2,ans3;
bool ed[N];
inline void read(int &x)
{
	c=getchar();
	while ((c<48)||(c>57)) c=getchar();
	x=c^48;c=getchar();
	while ((c>=48)&&(c<=57))
	{
		x=x*10+(c^48);
		c=getchar();
	}
}
inline void add(int x,int y)
{
	lj[++bs]=y;
	nxt[bs]=fir[x];
	fir[x]=bs;
}
void dfs1(int x)
{
	int i;
	siz[x]=1;
	for (i=fir[x];i;i=nxt[i]) if (lj[i]!=f[x])
	{
		dep[lj[i]]=dep[f[lj[i]]=x]+1;
		dfs1(lj[i]);
		siz[x]+=siz[lj[i]];
		if (siz[lj[i]]>siz[hc[x]]) hc[x]=lj[i];
	}
}
void dfs2(int x)
{
	dfn[x]=++bs;
	if (hc[x])
	{
		int i;
		top[hc[x]]=top[x];
		dfs2(hc[x]);
		for (i=fir[x];i;i=nxt[i]) if ((lj[i]!=f[x])&&(lj[i]!=hc[x])) dfs2(top[lj[i]]=lj[i]);
	}
}
inline int lca(int x,int y)
{
	while (top[x]!=top[y]) if (dep[top[x]]<dep[top[y]]) y=f[top[y]]; else x=f[top[x]];
	if (dep[x]<dep[y]) return x;
	return y;
}
void qs(int l,int r)
{
	int i=l,j=r,m=dfn[a[l+r>>1]];
	while (i<=j)
	{
		while (dfn[a[i]]<m) ++i;
		while (dfn[a[j]]>m) --j;
		if (i<=j) swap(a[i++],a[j--]);
	}
	if (i<r) qs(i,r);
	if (l<j) qs(l,j);
}
inline void ins(int x)
{
	if (!tp)
	{
		st[tp=1]=x;
		return;
	}
	ance=lca(st[tp],x);
	while ((tp>1)&&(dep[st[tp-1]]>dep[ance]))
	{
		add(st[tp-1],st[tp]);
		--tp;
	}
	if (dep[st[tp]]>dep[ance]) add(ance,st[tp--]);
	if ((!tp)||(st[tp]!=ance)) st[++tp]=ance;
	st[++tp]=x;
}
void dfs3(int x)//0 1短/2 3长/0 2最值/1 3次值
{
	dp[x][0]=((siz[x]=ed[x])^1)*(dp[x][1]=inf);
	dp[x][2]=(ed[x]^1)*(dp[x][3]=-inf);lc[x]=dc[x]=0;
	int i;
	for (i=fir[x];i;i=nxt[i])
	{
		dfs3(lj[i]);
		siz[x]+=siz[lj[i]];
		ans1+=(ll)siz[lj[i]]*(len=dep[lj[i]]-dep[x])*(m-siz[lj[i]]);
		if (dp[x][0]>dp[lj[i]][0]+len)
		{
			if (dp[x][0]<dp[x][1]) dp[x][1]=dp[x][0];
			dp[x][0]=dp[dc[x]=lj[i]][0]+len;
		}
		else dp[x][1]=min(dp[x][1],dp[lj[i]][0]+len);
		if (dp[x][2]<dp[lj[i]][2]+len)
		{
			if (dp[x][2]>dp[x][3]) dp[x][3]=dp[x][2];
			dp[x][2]=dp[lc[x]=lj[i]][2]+len;
		}
		else dp[x][3]=max(dp[x][3],dp[lj[i]][2]+len);
	}
	ans2=min(ans2,dp[x][0]+dp[x][1]);
	ans3=max(ans3,dp[x][2]+dp[x][3]);
	fir[x]=0;
	ed[x]=0;
}
int main()
{
	read(n);
	for (i=1;i<n;i++)
	{
		read(x);read(y);
		add(x,y);add(y,x);
	}
	read(q);
	bs=0;
	dfs1(dep[1]=1);
	dfs2(top[1]=1);
	bs=0;
	memset(fir+1,0,n<<2);
	while (q--)
	{
		read(m);
		for (i=1;i<=m;i++)
		{
			read(a[i]);
			ed[a[i]]=1;
		}
		qs(1,m);
		if (a[1]!=1) st[tp=1]=1;
		for (i=1;i<=m;i++) ins(a[i]);
		while (--tp) add(st[tp],st[tp+1]);
		ans1=ans3=0;ans2=inf;
		dfs3(1);
		printf("%lld %d %d\n",ans1,ans2,ans3);
	}
}
后缀树+虚树,计算虚树上每条边贡献。据说也有 SA 解法。曾经想过把这题改一改再套一个点分治&树状数组
CPP
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
typedef long long ll;
const int S=5e5+2,N=1e6+2;
int c[N][27],dep[N],f[N],t[N],zd[N],lj[N],nxt[N],fir[N],len[N],siz[N],fa[N],s[S];
int top[N],hc[N],dfn[N],dy[N],st[N],a[N];
int n,m,q,i,j,x,ds=1,point=1,ad,r,edge,remain,bs,fbs,cc,tp,ance;
ll ans;
inline void read(int &x)
{
	cc=getchar();
	while ((cc<48)||(cc>57)) cc=getchar();
	x=cc^48;cc=getchar();
	while ((cc>=48)&&(cc<=57))
	{
		x=x*10+(cc^48);
		cc=getchar();
	}
}
inline void xadd(int x,int y)
{
	lj[++bs]=y;
	nxt[bs]=fir[x];
	fir[x]=bs;
}
inline void add(int x,int y)
{
	lj[++fbs]=y;
	nxt[fbs]=fir[x];
	fir[x]=fbs;
}
inline void add(int x,int y,int z)
{
	lj[++fbs]=y;
	len[bs]=z;
	nxt[fbs]=fir[x];
	fir[x]=bs;
}
inline void add(int a,int b,int cc,int d)
{
	c[a][s[cc]]=++bs;
	f[bs]=cc;t[bs]=d;
	zd[bs]=b;
}
void dfs1(int x)
{
	siz[x]=1;
	if (!fir[x]) {dy[n-dep[x]+1]=x;return;}
	int i;
	for (i=fir[x];i;i=nxt[i])
	{
		dep[lj[i]]=dep[f[lj[i]]=x]+len[i];
		dfs1(lj[i]);
		siz[x]+=siz[lj[i]];
		if (siz[lj[i]]>siz[hc[x]]) hc[x]=lj[i];
	}
}
void dfs2(int x)
{
	dfn[x]=++bs;
	if (hc[x])
	{
		top[hc[x]]=top[x];
		dfs2(hc[x]);
		int i;
		for (i=fir[x];i;i=nxt[i]) if (lj[i]!=hc[x]) dfs2(top[lj[i]]=lj[i]);
	}
}
bool cmp(int x,int y)
{
	return dfn[x]<dfn[y];
}
inline int lca(register int x,register int y)
{
	while (top[x]!=top[y]) if (dep[top[x]]<dep[top[y]]) y=f[top[y]]; else x=f[top[x]];
	if (dep[x]<dep[y]) return x; return y;
}
inline void ins(int x)
{
	if (!tp) {st[tp=1]=x;return;}
	ance=lca(st[tp],x);
	while ((tp>1)&&(dep[st[tp-1]]>dep[ance]))
	{
		xadd(st[tp-1],st[tp]);
		--tp;
	}
	if (dep[st[tp]]>dep[ance]) xadd(ance,st[tp--]);
	if ((!tp)||(st[tp]!=ance)) st[++tp]=ance;
	st[++tp]=x;
}
void dfs3(int x)
{
	int i;
	for (i=fir[x];i;i=nxt[i])
	{
		dfs3(lj[i]);ans+=(ll)siz[x]*siz[lj[i]]*dep[x];
		siz[x]+=siz[lj[i]];
	}
}
void dfs4(int x)
{
	int i;
	for (i=fir[x];i;i=nxt[i]) dfs4(lj[i]);
	siz[x]=fir[x]=0;
}
int main()
{
	read(n);read(q);
	cc=getchar();
	while ((cc<'a')||(cc>'z')) cc=getchar();
	s[1]=cc-97;
	for (i=2;i<=n;i++) s[i]=getchar()-97;fa[1]=1;
	s[++n]=26;
	for (i=1;i<=n;i++)
	{
		ad=0;++remain;
		while (remain)
		{
			if (r==0) edge=i;
			if ((j=c[point][s[edge]])==0)
			{
				fa[ad]=point;
				fa[++ds]=1;
				add(ad=point,ds,edge,n);
				add(point,s[edge]);
			}
			else
			{
				if ((t[j]!=n)&&(t[j]-f[j]+1<=r))
				{
					r-=t[j]-f[j]+1;
					edge+=t[j]-f[j]+1;
					point=zd[j];
					continue;
				}
				if (s[i]==s[f[j]+r])
				{
					++r;fa[ad]=point;break;
				}
				fa[fa[ad]=++ds]=1;add(ad=ds,zd[j],f[j]+r,t[j]);
				add(ds,s[f[j]+r]);zd[j]=ds;t[j]=f[j]+r-1;
				add(ds,s[i]);fa[++ds]=1;add(ds-1,ds,i,n);
			}
			--remain;
			if ((r)&&(point==1))
			{
				--r;
				edge=i-remain+1;
			} else point=fa[point];
		}
	}//ukk后缀树,sam的可以按照自己板子来建
	for (i=1;i<=ds;i++) for (j=fir[i];j;j=nxt[j])
	{
		x=c[i][lj[j]];
		lj[j]=zd[x];
		len[j]=t[x]-f[x]+1;
	}
	memset(f+1,0,bs<<2);bs=0;
	dfs1(1);dfs2(top[1]=1);
	memset(siz+1,0,ds<<2);
	memset(fir+1,0,ds<<2);bs=0;
	while (q--)
	{
		read(m);ans=bs=0;
		for (i=1;i<=m;i++)
		{
			read(a[i]);siz[a[i]=dy[a[i]]]=1;
		}
		sort(a+1,a+m+1,cmp);if (a[1]!=1) ins(1);
		for (i=1;i<=m;i++) if ((i==1)||(a[i]!=a[i-1])) ins(a[i]);
		while (--tp) xadd(st[tp],st[tp+1]);
		dfs3(1);dfs4(1);printf("%lld\n",ans);
	}
}

习题

评论

51 条评论,欢迎与作者交流。

正在加载评论...