专栏文章

题解:P11539 [Code+#5] 方案计数

P11539题解参与者 7已保存评论 6

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
6 条
当前快照
1 份
快照标识符
@min3t5g0
此快照首次捕获于
2025/12/01 20:06
3 个月前
此快照最后确认于
2025/12/01 20:06
3 个月前
查看原文
好题啊。提供一种由老师讲解的思路。

思路

由于最终得到一个排列的要求并不好找性质,所以我们考虑从排列得到原本的有序序列,这个是等价于原问题的。
我们设一个位置是合法的断点,即是一个合法的题面函数中生成的 mm,他需要满足 plpmp_l\sim p_mpm+1prp_{m + 1} \sim p_r 的值域不交。因为当我们分完这个断点之后,一定会先遍历完其中一边的所有值,这些值就是我们最终有序序列的一段等长前缀的值。又因为有序序列的一段前缀和去掉这个对应前缀的后缀的值域不交,所以我们选取的断点也需要满足这个条件。对于判断断点,我们可以直接使用 ST 表维护一段区间的最大最小值即可。
考虑如何计算答案。我们设 fl,rf_{l,r} 表示区间 lrl\sim r 这一段区间的答案。如果 l=rl=rfl,r=1f_{l,r}=1。否则 fl,r=jSfl,j×fj+1,rf_{l,r}=\sum_{j\in S} f_{l,j} \times f_{j+1,r}。此处的 SS 表示区间 lrl\sim r 所有合法的断点的集合。此时我们可以得到一种枚举断点然后递归的 n3n^3 做法。
n3n^3 代码CPP
#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N = 5e5 + 5, mod = 998244353;

int n, fac[N << 1], inv[N << 1], c[N], mx[N][20], mn[N][20], a[N];

__inline__ int power(int x, int y) {
	int w = 1;
	
	while(y) {
		if(y & 1) {
			w = w * x % mod;
		}
		
		x = x * x % mod;
		y >>= 1;
	}
	
	return w;
}

__inline__ int getmn(int l, int r) {// 这里是 st 表 check 断点
	int lg = __lg(r - l + 1);
	
	return min(mn[l][lg], mn[r - (1 << lg) + 1][lg]);
}

__inline__ int getmx(int l, int r) {
	int lg = __lg(r - l + 1);
	
	return max(mx[l][lg], mx[r - (1 << lg) + 1][lg]);
}

__inline__ bool check(int l, int mid, int r) {
	if(a[l] > a[r]) {
		return getmn(l, mid) > getmx(mid + 1, r);
	}
	
	return getmx(l, mid) < getmn(mid + 1, r);
}

int dfs(int l, int r) {
    
	if(l == r) {
		return 1;
	}
	
	int sum = 0;
	
	for(int i = l; i < r; i ++) {
		if(check(l, i, r)) {// 如果有合法断点直接递归
			sum = (sum + dfs(l, i) * dfs(i + 1, r) % mod) % mod;
		}
	}
	
	return sum;
}

__inline__ int C(int n, int m) {
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	cin >> n;
	
	for(int i = 1; i <= n; i ++) {
		cin >> a[i];
		mx[i][0] = mn[i][0] = a[i];
	}
	
	for(int i = 1; (1 << i) <= n; i ++) {
		for(int j = 1; j + (1 << i) - 1 <= n; j ++) {
			mx[j][i] = max(mx[j][i - 1], mx[j + (1 << i - 1)][i - 1]);
			mn[j][i] = min(mn[j][i - 1], mn[j + (1 << i - 1)][i - 1]);
		}
    }
	
	fac[0] = inv[0] = 1;
	
	for(int i = 1; i <= n * 2; i ++) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = power(fac[i], mod - 2);
	}
	
	for(int i = 1; i <= n; i ++) {
		c[i] = (C(2 * i, i) - C(2 * i, i - 1) + mod) % mod;
	}
	
	cout << dfs(1, n);
	
	return 0;
}
考虑优化。拆贡献,首先你会发现对于你递归的一段区间如果他包含了原本大区间的断点那么这段小区间的断点就是大区间的断点,此时我们的贡献就变成了对于所有的小区间的值再乘上分割区间的方案数。对于这个方案数就是经典的二叉树方案数计数,方案数为卡特兰数的第段数项。对于小区间如果包含大区间断点就不会产生新断点的证明:你会发现,由于你需要满足被断点隔开的任意两个子区间的值不交,因此你的小区间的值域肯定是单调递增或者是单调递减的,所以在你包含原本大区间的断点的情况下是无法产生新断点的,否则原本大区间中也一定会包含这个断点。
此时,我们找到大区间的所有断点后只需要对于每个小区间分别递归了,复杂度优化为 n2n^2
n2n^2 代码CPP
#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N = 5e5 + 5, mod = 998244353;

int n, fac[N << 1], inv[N << 1], c[N], mx[N][21], mn[N][21], a[N];

__inline__ int power(int x, int y) {
	int w = 1;
	
	while(y) {
		if(y & 1) {
			w = w * x % mod;
		}
		
		x = x * x % mod;
		y >>= 1;
	}
	
	return w;
}

__inline__ int getmn(int l, int r) {
	int lg = __lg(r - l + 1);
	
	return min(mn[l][lg], mn[r - (1 << lg) + 1][lg]);
}

__inline__ int getmx(int l, int r) {
	int lg = __lg(r - l + 1);
	
	return max(mx[l][lg], mx[r - (1 << lg) + 1][lg]);
}

__inline__ bool check(int l, int mid, int r) {
	if(a[l] > a[r]) {
		return getmn(l, mid) > getmx(mid + 1, r);
	}
	
	return getmx(l, mid) < getmn(mid + 1, r);
}

int dfs(int l, int r) {
//	cout << l << ' ' << r << '\n';
	
	if(l == r) {
		return 1;
	}
	
	int sum = 1, lst = l, cnt = 1;
	
	for(int i = l; i < r; i ++) {
		if(check(l, i, r)) {// 改变计算方式改为由所有小区间的贡献之积乘上一个卡特兰数
			sum = (sum * dfs(lst, i)) % mod;
			lst = i + 1;
			cnt ++;
		}
	}
	
	sum = (sum * dfs(lst, r)) % mod;// 别漏掉了这个区间。
	
	return sum * c[cnt - 1] % mod;
}

__inline__ int C(int n, int m) {
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

signed main() {
//	freopen("qwq.in", "r", stdin);
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	cin >> n;
	
	for(int i = 1; i <= n; i ++) {
		cin >> a[i];
		mx[i][0] = mn[i][0] = a[i];
	}
	
	for(int i = 1; (1 << i) <= n; i ++) {
		for(int j = 1; j + (1 << i) - 1 <= n; j ++) {
			mx[j][i] = max(mx[j][i - 1], mx[j + (1 << i - 1)][i - 1]);
			mn[j][i] = min(mn[j][i - 1], mn[j + (1 << i - 1)][i - 1]);
		}
    }
	
	fac[0] = inv[0] = 1;
	
	for(int i = 1; i <= n * 2; i ++) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = power(fac[i], mod - 2);
	}
	
	for(int i = 1; i <= n; i ++) {// 计算卡特兰数
		c[i] = (C(2 * i, i) - C(2 * i, i - 1) + mod) % mod;
	}
	
	cout << dfs(1, n);
	
	return 0;
}
到这里其实就已经可以在原题的水数据上通过了。但能更优。
此时我们的复杂度瓶颈在于找断点,其实你会发现,对于存在两个及以上区间的大区间,你只需要从两边同时往中间扫就可以用较短段的段长的时间复杂度内找到的。不难发现,这其实是启发式合并的复杂度。对于最大的那个子区间,由于你并没有遍历他的所有位置去寻找断点,因此你直接递归他复杂度加起来就是正确的了。并且考虑我们当前这个大区间,剩下的还没有找断点这个小区间的左右段点的大小关系和原本大区间左右端点的大小关系不同的时候就说明这个剩下的小区间内没有断点了,具体证明根据上面所说的我们的小区间的值域关系是单调的,如果我们这个小区间内还有断点断开后一定不会继续满足单调性,所以这个区间内部一定没有断点。
nlognn\log nCPP
#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N = 5e5 + 5, mod = 998244353;

int n, fac[N << 1], inv[N << 1], c[N], mx[N][20], mn[N][20], a[N];

__inline__ int power(int x, int y) {
	int w = 1;
	
	while(y) {
		if(y & 1) {
			w = w * x % mod;
		}
		
		x = x * x % mod;
		y >>= 1;
	}
	
	return w;
}

__inline__ int getmn(int l, int r) {
	int lg = __lg(r - l + 1);
	
	return min(mn[l][lg], mn[r - (1 << lg) + 1][lg]);
}

__inline__ int getmx(int l, int r) {
	int lg = __lg(r - l + 1);
	
	return max(mx[l][lg], mx[r - (1 << lg) + 1][lg]);
}

__inline__ bool check(int l, int mid, int r) {
	if(a[l] > a[r]) {
		return getmn(l, mid) > getmx(mid + 1, r);
	}
	
	return getmx(l, mid) < getmn(mid + 1, r);
}

int dfs(int l, int r) {
    
	if(l == r) {
		return 1;
	}
	
	bool tag = (a[l] > a[r]);
	int cnt = 1, sum = 1;
	
	while((a[l] > a[r]) == tag) {// 当不同的时候说明不存在断点了。
        if(l == r) {
            break;
        }
        
        bool tag = 0;
        
		for(int i = 0; l + i <= r - i - 1; i ++) {
			if(check(l, l + i, r)) {// 如果找到了断点就递归,由于找到小区间的花费为较小区间的大小,且较小区间的大小不会超过大区间的一半,因此是启发式合并复杂度。
				tag = 1;
				sum *= dfs(l, l + i);
				sum %= mod;
				l = l + i + 1;
				cnt ++;
				break;
			}
			else if(check(l, r - i - 1, r)) {
				tag = 1;
				sum *= dfs(r - i, r);
				sum %= mod;
				r = r - i - 1;
				cnt ++;
				break;
			}
		}
		
		if(!tag) {// 没有找到断点但是当前子段长度大于等于 2 必然无解。
			cout << 0;
			exit(0);
		}
	}
    
	sum = sum * dfs(l, r) % mod * c[cnt - 1] % mod;
	return sum;
}

__inline__ int C(int n, int m) {
	return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	cin >> n;
	
	for(int i = 1; i <= n; i ++) {
		cin >> a[i];
		mx[i][0] = mn[i][0] = a[i];
	}
	
	for(int i = 1; (1 << i) <= n; i ++) {
		for(int j = 1; j + (1 << i) - 1 <= n; j ++) {
			mx[j][i] = max(mx[j][i - 1], mx[j + (1 << i - 1)][i - 1]);
			mn[j][i] = min(mn[j][i - 1], mn[j + (1 << i - 1)][i - 1]);
		}
    }
	
	fac[0] = inv[0] = 1;
	
	for(int i = 1; i <= n * 2; i ++) {
		fac[i] = fac[i - 1] * i % mod;
		inv[i] = power(fac[i], mod - 2);
	}
	
	for(int i = 1; i <= n; i ++) {
		c[i] = (C(2 * i, i) - C(2 * i, i - 1) + mod) % mod;
	}
	
	cout << dfs(1, n);
	
	return 0;
}

upd:没判无解被 hack 了,你会发现对于一个长度大于等于 22 的还能被分割的子段如果没有找到断点就是无解,判掉即可。

评论

6 条评论,欢迎与作者交流。

正在加载评论...