专栏文章

CF1085G Beautiful Matrix 题解

CF1085G题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@minx3ahw
此快照首次捕获于
2025/12/02 09:46
3 个月前
此快照最后确认于
2025/12/02 09:46
3 个月前
查看原文
好题。
首先用上“一般求小于某个东西的方案数”的策略:枚举前面那几行相等,然后枚举这一行不等。
我们运用它。我们可以钦定前 ii 行相同,注意 ii 的范围是 00n1n-1。如果我们能求出 fif_i 表示使得第 i+1i+1 行合法的方案数,那么答案就是:i=0n1fiDnni1\sum_{i=0}^{n-1}f_iD_n^{n-i-1}。其中 DnD_n 表示长度为 nn 的错排方案数。
接下来的问题是:如何求出 fif_i
首先我们发现,第 i+1i+1 行有如下限制:字典序小于原来的第 i+1i+1 行,与原来的第 ii 行没有重复的地方。
我们发现第 11 行只有第一个限制。所以我们可以通过计算排列的排名,得到 f0f_0。运用树状数组可以做到 O(nlogn)O(n\log n),详情请见 P5367
然后我们再次运用一开始提到的策略。我们钦定前 j1j-1 个数相同(接下来的讲解均在“前 ii 行相同“的条件下进行)。那么,我们枚举第 jj 个数是多少,然后看它是否合法。
你或许注意到了它是 O(n3)O(n^3) 的。先别急!我们慢慢细说。
然后我们会剩下来后面的数。注意到后面的数,有些是相互重复的,有些是不重复的。
假设我们前面选择了 xx 个数(也就是 jj),有 yy 组重复的数。请注意,这里的重复都指:存在 k,lk,l 满足 ai,k=ai+1,la_{i,k}=a_{i+1,l}
那么,后面的方案数为 sznx,n2x+ysz_{n-x,n-2x+y},表示后面有 nxn-x 的空位,其中有 n2x+yn-2x+y 组重复数。
问题来了,怎么求 szn,msz_{n,m} 呢?
我们发现,这可以容斥,那么 szn,m=i=0m(1)i(mi)(ni)!sz_{n,m}=\sum_{i=0}^{m}(-1)^i\binom{m}{i}(n-i)!,相信大家都能理解,不理解也没关系。
现在,我们成功做到了 O(n3)O(n^3)
我们需要优化。
先给出我们的暴力代码,以便后面叙述:
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
	inv[1]=1;
	jc[0]=ijc[0]=1;
	for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for(int i=1;i<=n;i++){
		jc[i]=jc[i-1]*i%mod;
		ijc[i]=ijc[i-1]*inv[i]%mod;
	}
	D[1]=0,D[2]=1;
	for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
	return;
}
int C(int n,int m){
	if(n<m)return 0;
	int fz=jc[n];
	int fm=ijc[m]*ijc[n-m]%mod;
	return fz*fm%mod;
}
int n;
struct BIT{
	int t[N];
	int lowbit(int x){return x&-x;}
	void add(int x,int y){
		for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
		return;
	}
	int query(int x){
		int res=0;
		for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
		return res;
	}
}T;
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
	int ans=0;
	for(int i=n;i>=1;i--){
		ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
		T.add(a[1][i],1);
	}
	return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
	int z=1;
	while(b){
		if(b&1)z=z*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return z;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	// freopen("B.in","r",stdin);
	// freopen("B.out","w",stdout);
	init(N-10);
	cin>>n;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=n;j++){
			cin>>a[i][j];
		}
	}
//这里的 sz 与前文的不同。这里的意义就是,当我们前面枚举的状态是(x,y)时的方案数。后面会修改的,这里请不要重视!!!
	for(int x=1;x<=n;x++){
		for(int y=0;y<=x;y++){
			if(n-2*x+y<0)continue;
			for(int i=0;i<=n-2*x+y;i++){
				sz[x][y]+=qpow(i)*C(n-2*x+y,i)*jc[n-x-i]%mod;
				sz[x][y]=(sz[x][y]%mod+mod)%mod;
			}
		}
	}
	f[0]=rk();
	for(int i=1;i<n;i++){
		int x=0,y=0;
		unordered_map<int,int>M1,M2;
		for(int j=1;j<=n;j++){
			x++;
			for(int k=1;k<a[i+1][j];k++){
				if(M2[k])continue;
				if(a[i][j]==k)continue;
				int xx=x,yy=y;
				if(M1[k])yy++;
				if(M2[a[i][j]])yy++;
				f[i]=(f[i]+sz[xx][yy])%mod;
			}
			M1[a[i][j]]++;
			M2[a[i+1][j]]++;
			y+=M2[a[i][j]];
			y+=M1[a[i+1][j]];
		}
	}
	for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
	cout<<ans<<"\n";
	return 0;
}
/*

*/
我们先修改后面对 fif_i 的处理。
我们发现,第 9797 行(可以粘贴方便查看)可以单独提出来,提到枚举的外面。
那么,我们发现,yyyy 无非就两种状态:yyy+1y+1
我们可以统计几种情况需要加一,然后用总的个数减掉,就得出了不需要加一的方案。
我的实现有点丑陋,大家真的仅供参考,仅供参考……
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
	inv[1]=1;
	jc[0]=ijc[0]=1;
	for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for(int i=1;i<=n;i++){
		jc[i]=jc[i-1]*i%mod;
		ijc[i]=ijc[i-1]*inv[i]%mod;
	}
	D[1]=0,D[2]=1;
	for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
	return;
}
int C(int n,int m){
	if(n<m)return 0;
	int fz=jc[n];
	int fm=ijc[m]*ijc[n-m]%mod;
	return fz*fm%mod;
}
int n;
struct BIT{
	int t[N];
	void clear(){
		for(int i=1;i<=n;i++)t[i]=0;
		return;
	}
	int lowbit(int x){return x&-x;}
	void add(int x,int y){
		for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
		return;
	}
	int query(int x){
		int res=0;
		for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
		return res;
	}
}T,T1,T2,T3;//T1是第i行的出现个数,T2是第i+1行的,T3就是两行都出现的
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
	int ans=0;
	for(int i=n;i>=1;i--){
		ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
		T.add(a[1][i],1);
	}
	return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
	int z=1;
	while(b){
		if(b&1)z=z*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return z;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	// freopen("B.in","r",stdin);
	// freopen("B.out","w",stdout);
	init(N-10);
	cin>>n;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=n;j++){
			cin>>a[i][j];
		}
	}
//这里改回了正确的sz(前面描述的表示状态)
	for(int x=0;x<=n;x++){
		for(int y=0;y<=n;y++){
			for(int i=0;i<=y;i++){
				sz[x][y]+=qpow(i)*C(y,i)*jc[x-i]%mod;
				sz[x][y]=(sz[x][y]%mod+mod)%mod;
			}
		}
	}
	f[0]=rk();
//打注释的代码块就是之前写的暴力
	for(int i=1;i<n;i++){
		int x=0,y=0;
		T1.clear();T2.clear();T3.clear();
		unordered_map<int,int>M1,M2;
		for(int j=1;j<=n;j++){
			x++;
			if(M2[a[i][j]])y++;
			int total=(a[i+1][j]-1)-T2.query(a[i+1][j]-1)-(a[i][j]<a[i+1][j]&&!M2[a[i][j]]);
			int toty_add=T1.query(a[i+1][j]-1)-T3.query(a[i+1][j]-1);//加上1,减去重复,就是+1的次数
			int toty=total-toty_add;//不加一的次数
			f[i]=(f[i]+sz[n-x][n-2*x+y]*toty%mod+sz[n-x][n-2*x+y+1]*toty_add%mod)%mod;
			// if(i==2&&j==2)cerr<<total<<"\n";
			// for(int k=1;k<a[i+1][j];k++){
			// 	if(M2[k])continue;
			// 	if(a[i][j]==k)continue;
			// 	int xx=x,yy=y;
			// 	if(M1[k])yy++;
			//	if(M2[a[i][j]])yy++;
			// 	f[i]=(f[i]+sz[xx][yy])%mod;
			// }
			T1.add(a[i][j],1);
			M1[a[i][j]]++;
			T2.add(a[i+1][j],1);
			M2[a[i+1][j]]++;
			// y+=M2[a[i][j]];
			y+=M1[a[i+1][j]];
			if(M1[a[i][j]]&&M2[a[i][j]])T3.add(a[i][j],1);
			if(M1[a[i+1][j]]&&M2[a[i+1][j]])T3.add(a[i+1][j],1);
		}
	}
	for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
	cout<<ans<<"\n";
	return 0;
}
/*

*/
我们就差最后一步了!
我们发现,容斥是没前途的,考虑朴素转移(真的有人是先想到容斥,才想到朴素转移吗……)。
有一个经典的套路:这种关于序列大小的 DP,可以通过“加一个新的数”来进行转移。
我们插入第 ii 个数,有两种情况。
第一种:插入数之后,多一个重复的数对:我们在 mm 个数中选择一个,故方案为 m×szn1,m1m\times sz_{n-1,m-1}
第二种:插入数之后,没有变化:我们在 nmn-m 个没限制的数中选一个,故方案为 (nm)szn1,m(n-m)sz_{n-1,m}
所以,szn,m=m×szn1,m1+(nm)szn1,msz_{n,m}=m\times sz_{n-1,m-1}+(n-m)sz_{n-1,m}。时间复杂度 O(n2logn)O(n^2\log n),瓶颈在于对于 ff 的计算。
代码有点丑:
CPP
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2010;
const int val=1000000;
const int mod=998244353;
int inv[N];
int jc[N],ijc[N];
int D[N];
void init(int n){
	inv[1]=1;
	jc[0]=ijc[0]=1;
	for(int i=2;i<=n;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	for(int i=1;i<=n;i++){
		jc[i]=jc[i-1]*i%mod;
		ijc[i]=ijc[i-1]*inv[i]%mod;
	}
	D[0]=1;
	D[1]=0,D[2]=1;
	for(int i=3;i<=n;i++)D[i]=(i-1)*(D[i-1]+D[i-2])%mod;
	return;
}
int C(int n,int m){
	if(n<m)return 0;
	int fz=jc[n];
	int fm=ijc[m]*ijc[n-m]%mod;
	return fz*fm%mod;
}
int n;
struct BIT{
	int t[N];
	void clear(){
		for(int i=1;i<=n;i++)t[i]=0;
		return;
	}
	int lowbit(int x){return x&-x;}
	void add(int x,int y){
		for(int i=x;i<=n;i+=lowbit(i))t[i]+=y;
		return;
	}
	int query(int x){
		int res=0;
		for(int i=x;i>=1;i-=lowbit(i))res+=t[i];
		return res;
	}
}T,T1,T2,T3;
int ans;
int a[N][N];
int sz[N][N];
int f[N];
int rk(){
	int ans=0;
	for(int i=n;i>=1;i--){
		ans=(ans+jc[n-i]*T.query(a[1][i]-1)%mod)%mod;
		T.add(a[1][i],1);
	}
	return ans;
}
int qpow(int x){return x%2==1?-1:1;}
int ksm(int a,int b){
	int z=1;
	while(b){
		if(b&1)z=z*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return z;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	// freopen("B.in","r",stdin);
	// freopen("B.out","w",stdout);
	init(N-10);
	cin>>n;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=n;j++){
			cin>>a[i][j];
		}
	}
	for(int x=0;x<=n;x++){
		sz[x][x]=D[x];
		for(int y=0;y<x;y++){
			sz[x][y]=y*(y==0?0:sz[x-1][y-1])+(x-y)*sz[x-1][y];
			sz[x][y]%=mod;
		}
	}
	f[0]=rk();
	for(int i=1;i<n;i++){
		int x=0,y=0;
		T1.clear();T2.clear();T3.clear();
		unordered_map<int,int>M1,M2;
		for(int j=1;j<=n;j++){
			x++;
			if(M2[a[i][j]])y++;
			int total=(a[i+1][j]-1)-T2.query(a[i+1][j]-1)-(a[i][j]<a[i+1][j]&&!M2[a[i][j]]);
			int toty_add=T1.query(a[i+1][j]-1)-T3.query(a[i+1][j]-1);//加上1,减去重复,就是+1的次数
			int toty=total-toty_add;//不加一的次数
			f[i]=(f[i]+sz[n-x][n-2*x+y]*toty%mod+sz[n-x][n-2*x+y+1]*toty_add%mod)%mod;
			// if(i==2&&j==2)cerr<<total<<"\n";
			// for(int k=1;k<a[i+1][j];k++){
			// 	if(M2[k])continue;
			// 	if(a[i][j]==k)continue;
			// 	int xx=x,yy=y;
			// 	if(M1[k])yy++;
			//	if(M2[a[i][j]])yy++;
			// 	f[i]=(f[i]+sz[xx][yy])%mod;
			// }
			T1.add(a[i][j],1);
			M1[a[i][j]]++;
			T2.add(a[i+1][j],1);
			M2[a[i+1][j]]++;
			// y+=M2[a[i][j]];
			y+=M1[a[i+1][j]];
			if(M1[a[i][j]]&&M2[a[i][j]])T3.add(a[i][j],1);
			if(M1[a[i+1][j]]&&M2[a[i+1][j]])T3.add(a[i+1][j],1);
		}
	}
	for(int i=0;i<n;i++)ans=(ans+f[i]*ksm(D[n],n-i-1)%mod)%mod;
	cout<<ans<<"\n";
	return 0;
}
/*

*/

评论

0 条评论,欢迎与作者交流。

正在加载评论...