专栏文章

【2】做题心得 - 2025 NOIP #65 - T3【数据结构】【矩阵快速幂】【线段树合并】

题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@min4sjsm
此快照首次捕获于
2025/12/01 20:34
3 个月前
此快照最后确认于
2025/12/01 20:34
3 个月前
查看原文
T3,随便可以推出一个矩阵快速幂的转移方案。
f=[x1y0]f1=[01y1xy]f=\begin{bmatrix} x & 1 \\ y & 0 \end{bmatrix}\\ f^{-1}=\begin{bmatrix} 0 & \frac{1}{y} \\ 1 & -\frac{x}{y} \end{bmatrix}
也就可以容易计算出两数相减的情况了,fx×fyf^x\times f^{-y} 是容易计算的。然后我们考虑如何计算答案。根据暴力容易得出一个树上差分的做法。你就发现这个东西可以线段树合并去做。使用两颗线段树维护两种矩阵应该就可以了。
CPP
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=5e5+10,M=2e7+10;
const ll P=998244353;
int n,x,y,sub;
ll a[N],ans[N];
vector<int>e[N];
struct mtr{
	ll a[2][2];
	mtr(){ a[0][0]=a[0][1]=a[1][0]=a[1][1]=0; }
} t[M],fib[2],pv; // 0+ 1-
ll qp(ll a,ll b){
	ll res=1;
	for(;b;b>>=1,a=(a*a)%P)	
		if(b&1) res=(res*a)%P;
	return res;
}
mtr add(mtr x,mtr y){
	mtr ans=mtr();
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)
		ans.a[i][j]=(x.a[i][j]+y.a[i][j])%P;
	return ans;
}
mtr mul(mtr x,mtr y){
	mtr ans=mtr();
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int k=0;k<2;k++)
		ans.a[i][j]=(ans.a[i][j]+x.a[i][k]*y.a[k][j]%P)%P;
	return ans;
}
mtr qp(mtr x,ll b){
	mtr ans=mtr();
	ans.a[0][0]=ans.a[1][1]=1;
	for(;b;b>>=1,x=mul(x,x)) if(b&1) ans=mul(ans,x);
	return ans;
}
int tot,mrd[M],top; 	
int rt[2][N],ls[M],rs[M];
void merge(int &u,int &v,int l,int r){
	if(!v) return;
	if(!u) return u=v, void(0);
	mrd[++top]=v;
	if(l==r) return t[u]=add(t[u],t[v]), void(0);
	int mid=(l+r)>>1;
	merge(ls[u],ls[v],l,mid);
	merge(rs[u],rs[v],mid+1,r);
	t[u]=add(t[u],t[v]);
}
void query(int u,int v,int l,int r,int uv){
	if(!u||!v) return;
	if(l==r) return;
	int mid=(l+r)>>1;
	(ans[uv]+=mul(mul(pv,t[ls[v]]),t[rs[u]]).a[0][1])%=P;
	query(ls[u],ls[v],l,mid,uv);
	query(rs[u],rs[v],mid+1,r,uv);
	return;
}
void update(int &p,int l,int r,int pos,mtr uv){
	if(!p){
		if(top) p=mrd[top--], t[p]=mtr(), ls[p]=rs[p]=0;
		else p=++tot;
	}
	if(l==r) return t[p]=add(t[p],uv), void(0);
	int mid=(l+r)>>1;
	if(pos<=mid) update(ls[p],l,mid,pos,uv);
	else         update(rs[p],mid+1,r,pos,uv);
	t[p]=add(t[ls[p]],t[rs[p]]);
}
int d[N],l,na[N];
void dfs(int p,int fa){
	update(rt[0][p],1,l,na[p],qp(fib[0],a[p]));
	update(rt[1][p],1,l,na[p],qp(fib[1],a[p]));
	for(auto v:e[p])if(v^fa){
		dfs(v,p);
		query(rt[0][p],rt[1][v],1,l,p);
		query(rt[0][v],rt[1][p],1,l,p);
		ans[p]=(ans[p]+ans[v])%P;
		merge(rt[0][p],rt[0][v],1,l);
		merge(rt[1][p],rt[1][v],1,l);
	}
	return;
}
int main(){
	freopen("sam.in","r",stdin);
	freopen("sam.out","w",stdout);
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin>>n>>x>>y>>sub;
	for(int i=1;i<=n;i++)
		cin>>a[i],
		d[i]=a[i];
	sort(d+1,d+n+1);
	l=unique(d+1,d+n+1)-d-1;
	for(int i=1;i<=n;i++)
		na[i]=lower_bound(d+1,d+l+1,a[i])-d;
	for(int i=1;i<n;i++){
		int u,v;
		cin>>u>>v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	fib[0].a[0][0]=x, fib[0].a[0][1]=1,
	fib[0].a[1][0]=y, fib[0].a[1][1]=0;
	fib[1].a[0][0]=0, fib[1].a[0][1]=qp(y,P-2),
	fib[1].a[1][0]=1, fib[1].a[1][1]=(P-x)*qp(y,P-2)%P;
	pv.a[0][0]=1, pv.a[0][1]=0,
	pv.a[1][0]=0, pv.a[1][1]=0;
	dfs(1,0);
	for(int i=1;i<=n;i++)
		cout<<ans[i]<<"\n";
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...