专栏文章
题解:P11363 [NOIP2024] 树的遍历
P11363题解参与者 3已保存评论 2
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 2 条
- 当前快照
- 1 份
- 快照标识符
- @miqwzynh
- 此快照首次捕获于
- 2025/12/04 12:07 3 个月前
- 此快照最后确认于
- 2025/12/04 12:07 3 个月前
一个关键性质是,一个点连接的所有边在最终的树上会形成一条链,且链与链之间的结构不互相影响。注意这并不意味着你可以把每个点的合法链方案乘起来,因为一个方案是否合法与你选定的起点有关。博主因为这个问题虚空调试 1h 并最终直接导致爆炸。
对于 的情况,我们发现链的端点一定是来源方向的边,其他边任意。方案数就显然了:。
对于 的情况,假设两个关键边为 ,先分别计算以每个边为起点的树的数量和(其实就是 的答案乘 ),然后我们需要考虑哪些树是算重的。由于我们要求链的起点是来源方向的边,因此 之间的所有点的链的两端都固定了;其它点由于两个关键边对应点上的同一条边,答案仍然不变。于是算重的树的个数为 ,其中 是 之间的点的集合。
对于 更大的情况,我们发现,若一棵树被边集 里的关键边同时算重,则 内的边一定形成在同一条链内。否则一定会存在一个点,要求其至少三条邻边都在链的端点。这是不合法的。基于这个性质,我们只需要枚举任意两个关键边,使得它们之间没有关键边,然后计算同时被这两个关键边算到的方案数,从总答案中减去这些方案即可。
不难发现,对于一棵这样的树,他会在初始时被算到 次,然后被减去 次,故这样的容斥是正确的。
对于第二部分的计算,我们可以设 表示 子树内无选定的关键边;有一个选定的关键边;有两个选定的关键边的方案数。设 是 的儿子转移如下:
- 对于有一个边的情况,考察 是否是关键边:
- 若不是,则
- 否则,由于我们要求两个关键边之间没有其他边,则选定的关键边只能是 ,有 。
- 对于有两个边的情况,分下面几种情况讨论:
- 在某个儿子内就已经完成了合并:。
- 两个不同的子树内各选了一个:此时另设 表示选了多少个的方案数,根据每条边是否是关键边以及选不选讨论,做背包合并即可,有 。
- 选了某个关键边 ,另一个在 的子树内:此时有 。
通过预处理 可以做到 ,瓶颈在求逆元。总复杂度 。
CPP#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;i++)
#define repp(i,j,k) for(int i=j;i>=k;i--)
#define pii pair<int,int>
#define mp make_pair
#define fir first
#define sec second
#define ls(x) (x<<1)
#define rs(x) ((x<<1)|1)
#define lowbit(i) (i&-i)
#define int long long
#define qingbai 666
using namespace std;
typedef long long ll;
const int N=1e5+5,inf=(ll)1e18+7,mo=1e9+7;
void read(int &p){
int w=1,x=0;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-')w=-1;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
p=w*x;
}
int T;
int n,m;
vector<pii>e[N];
int deg[N],jc[N],qj[N],f[N][3],g[3];
int quick_power(int base,int x){
int res=1;
while(x){
if(x&1)res*=base,res%=mo;
base*=base,base%=mo;
x>>=1;
}
return res;
}
bool imp[N];
void dfs(int x,int p){
int prod=1,sum2=0;
for(auto j:e[x])
if(j.fir!=p)dfs(j.fir,x),prod*=f[j.fir][0],prod%=mo;
rep(i,0,2)
f[x][i]=g[i]=0;
g[0]=1;
f[x][0]=prod*jc[deg[x]-1]%mo;
for(auto j:e[x]){
if(j.fir==p)continue;
int inv0=quick_power(f[j.fir][0],mo-2);
if(deg[x]>=2){
if(imp[j.sec])f[x][1]++,f[x][1]%=mo;
else f[x][1]+=f[j.fir][1]*inv0%mo,f[x][1]%=mo;
}
repp(k,2,0){
g[k]=g[k]*f[j.fir][0]%mo;
if(k){
if(imp[j.sec])g[k]+=g[k-1]*f[j.fir][0]%mo;
else g[k]+=g[k-1]*f[j.fir][1]%mo;
}
g[k]%=mo;
}
if(imp[j.sec])f[x][2]+=f[j.fir][1]*inv0%mo,f[x][2]%=mo;
sum2+=f[j.fir][2]*inv0%mo,sum2%=mo;
}
f[x][2]*=prod*jc[deg[x]-1]%mo,f[x][2]%=mo;
if(deg[x]>=2){
f[x][1]*=prod*jc[deg[x]-2]%mo,f[x][1]%=mo;
f[x][2]+=g[2]*jc[deg[x]-2]%mo;
}
if(e[x].size()>=2||x==1)f[x][2]+=sum2*prod%mo*jc[deg[x]-1],f[x][2]%=mo;
}
void solve(){
read(n),read(m);
rep(i,1,n)
deg[i]=0,e[i].clear(),imp[i]=0;
rep(i,1,n-1){
int x,y;
read(x),read(y);
e[x].push_back(mp(y,i)),e[y].push_back(mp(x,i));
deg[x]++,deg[y]++;
}
rep(i,1,m){
int x;
read(x),imp[x]=1;
}
int ans=1;
rep(i,1,n)
ans*=jc[deg[i]-1],ans%=mo;
ans*=m,ans%=mo;
dfs(1,0);
ans=ans+mo-f[1][2],ans%=mo;
printf("%lld\n",ans);
}
int cid;
signed main(){
jc[0]=qj[0]=1;
rep(i,1,100000)
jc[i]=jc[i-1]*i%mo,qj[i]=quick_power(jc[i],mo-2);
read(cid),read(T);
while(T--)
solve();
return 0;
}
相关推荐
评论
共 2 条评论,欢迎与作者交流。
正在加载评论...