首先考虑什么样的矩阵是合法的。
因为染黑的一定是行或列的前缀,这说明如果
(i,j) 被染黑那么
(i,j) 要在
i 行选取的前缀中或在
j 列选取的前缀中,即
(k,j)(k≤i) 或
(i,k)(k≤j) 中至少有一种全为黑点。
接下来根据这个结论来计数。
考虑现在在矩阵后加了一行,原矩阵要满足哪些性质。
首先这一行前缀的
1 都不用考虑,因为可以就当做是这一行选的前缀的位置;而对于剩余的
1,就只能由其对应的那一列来保证合法,那么就需要满足这些列在原矩阵中全为
1。
一个很暴力的想法就是记
fn,s 代表前
n 行集合
s 对应的列全为
1 的方案数,同理有
g 表示和,不过这样对于
m 是指数级的复杂度显然非常劣。
上面的这个办法是考虑每次选后再判断状态是否满足条件,考虑另一种想法:强制状态满足条件。
具体来说,对于剩余的
1,强制要求在原矩阵中这些列都为
1,那么会发现原矩阵的方案正好能够一一对应原矩阵去掉这些列的方案数,这是因为为
1 的列并不会截断前缀
1,也不会影响剩余
1。
于是考虑设
fi,j 表示一个
i 行
j 列的合法矩阵的数量,
gi,j 表示其
1 的和。
边界情况就为
fi,j=1,gi,j=0(ij=0)。
转移考虑枚举前缀
1 数量,再枚举剩余
1 数量(特殊处理一下全为
1 的情况,因为其余情况前缀
1 后必定会跟一个
0):
fi,j=fi−1,j+a=0∑j−1b=0∑j−a−1(bj−a−1)fi−1,j−bgi,j=gi−1,j+fi−1,j×j+a=0∑j−1b=0∑j−a−1(bj−a−1)(fi−1,j−b×(ib+a)+gi−1,j−b)
这样的复杂度是
O(nm3),考虑进一步优化这个求和式:
===a=0∑j−1b=0∑j−a−1(bj−1−a)fi−1,j−bb=0∑fi−1,j−ba=0∑j−b−1(bj−a−1)b=0∑fi−1,j−ba=0∑j−1(ba)b=0∑fi−1,j−b(b+1j)
===a=0∑j−1b=0∑j−a−1(bj−a−1)(fi−1,j−b×(ib+a)+gi−1,j−b)b=0∑j−1(fi−1,j−b×ib+gi−1,j−b)a=0∑j−b−1(bj−a−1)+b=0∑j−1fi−1,j−ba=0∑j−b−1(bj−a−1)×ab=0∑j−1(fi−1,j−b×ib+gi−1,j−b)(b+1j)+b=0∑j−1fi−1,j−ba=0∑j−1(bj−a−1)(1a)b=0∑j−1(fi−1,j−b×ib+gi−1,j−b)(b+1j)+b=0∑j−1fi−1,j−b(b+2j)
于是可以做到
O(nm2),又因为限制了
nm 且
n,m 交换不影响答案,考虑小的那一维作为
m,时间复杂度
O(nm×min{n,m}),不劣于
O(nmnm)。
CPP#include <bits/stdc++.h>
#include <atcoder/modint>
using mint = atcoder::modint998244353;
constexpr int maxn = 2e5 + 10;
mint fac[maxn], ifac[maxn];
inline mint binom(int n, int m) {
return n < m || m < 0 ? (mint)0 : (fac[n] * ifac[n - m] * ifac[m]);
}
mint f[maxn], g[maxn];
mint lf[maxn], lg[maxn];
int main() {
int n, m;
scanf("%d%d", &n, &m);
if (n < m) std::swap(n, m);
fac[0] = 1;
for (int i = 1; i <= m; i++) fac[i] = fac[i - 1] * i;
ifac[m] = fac[m].inv();
for (int i = m; i >= 1; i--) ifac[i - 1] = ifac[i] * i;
for (int i = 0; i <= m; i++) f[i] = 1, g[i] = 0;
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= m; j++) lf[j] = f[j], lg[j] = g[j], f[j] = g[j] = 0;
f[0] = 1, g[0] = 0;
for (int j = 1; j <= m; j++) {
f[j] = lf[j], g[j] = lg[j] + lf[j] * j;
for (int b = 0; b <= j - 1; b++) {
f[j] += binom(j, b + 1) * lf[j - b];
g[j] += binom(j, b + 1) * (lf[j - b] * b * i + lg[j - b]);
g[j] += binom(j, b + 2) * lf[j - b];
}
}
}
return printf("%d\n", g[m].val()), 0;
}