专栏文章

题解:AT_arc197_d [ARC197D] Ancestor Relation

AT_arc197_d题解参与者 2已保存评论 2

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
2 条
当前快照
1 份
快照标识符
@mipdbcsm
此快照首次捕获于
2025/12/03 10:08
3 个月前
此快照最后确认于
2025/12/03 10:08
3 个月前
查看原文
通过精细化的实现,其实可以做到 O(n2)O(n^2) 的复杂度。
首先对于有解的情况,不难发现以下两个性质:
  1. 将树根据“分叉点”拆成若干条链,则矩阵中两行相同当且仅当结点在同一条链上;
  2. 一条链是另一条链的祖先当且仅当矩阵中对应元素为 11 且前者所在行的 11 的个数大于后者。
根据这两个性质,可以通过矩阵重建出树的结构,步骤如下:
首先使用字典树或基数排序将矩阵各行按照字典序进行排序,用时 O(n2)O(n^2)
然后按照字典序枚举每个结点,如果当前结点对应的矩阵中的行与上一个结点完全相同,则属于同一条链,否则新开一条链。通过这一步得到了所有链的划分,每条链用时 O(n)O(n),总用时 O(n2)O(n^2)
最后,对每条链,找出其父链,父链根据以下两个特征唯一确定,每条链用时 O(n)O(n),总用时 O(n2)O(n^2)
  1. 矩阵中对应元素为 11
  2. 所在行的 11 的个数比当前链大且尽可能小。
为了判断无解的情况,只需按照上面的步骤在链之间连边,如果连成的不是一颗合法的树(例如有环或不连通或结点 11 所在的链不是根)则无解,如果是合法的树就重算一遍矩阵,如果与输入矩阵不一致也无解,显然这样的判断是充分的,且可以在 O(n2)O(n^2) 时间内完成。
考虑如何统计答案,对于结点 11 所在的链,除了结点 11 之外,其余结点的顺序可以任意排列,对于其他链,所有结点的顺序都可以任意排列,因此答案呈现为若干个阶乘的乘积,通过预处理阶乘可以用 O(n)O(n) 时间完成。
核心代码如下:
CPP
int main() {
	int t = read();
	while (t--) {
		int n = read();
		vector a(n, vector<bool>(n));
		for (int i = 0; i < n; i++)
			for (int j = 0; j < n; j++)
				a[i][j] = read() == 1;
		vector<int> sorted(n);
		iota(sorted.begin(), sorted.end(), 0);
		for (int v = n - 1; v >= 0; v--) {
			vector<int> zero, one;
			for (auto u : sorted)
				(a[u][v] ? one : zero).push_back(u);
			int cnt0 = zero.size(), cnt1 = n - cnt0;
			for (int i = 0; i < cnt1; i++)
				sorted[i] = one[i];
			for (int i = 0; i < cnt0; i++)
				sorted[i + cnt1] = zero[i];
		}
		vector<vector<int>> chain;
		vector<int> in_chain(n, -1);
		for (int r = 0; r < n; r++) {
			int i = sorted[r];
			if (r == 0 || a[i] != a[sorted[r - 1]])
				chain.emplace_back();
			in_chain[i] = chain.size() - 1;
			chain.back().push_back(i);
		}
		if (in_chain[0] != 0) {
			cout << "0\n";
			continue;
		}
		int m = chain.size();
		vector<int> deg(n);
		for (int i = 0; i < n; i++)
			deg[i] = count(a[i].begin(), a[i].end(), true);
		vector b(m, vector<bool>(m));
		for (int i = 0; i < n; i++) {
			int p = -1, mindeg = n + 1;
			for (int j = 0; j < n; j++) {
				if (!a[i][j])
					continue;
				if (deg[j] > deg[i] && deg[j] < mindeg) {
					p = in_chain[j];
					mindeg = deg[j];
				}
			}
			if (p != -1)
				b[in_chain[i]][p] = b[p][in_chain[i]] = true;
		}
		int ecnt = 0;
		for (auto& bb : b)
			ecnt += count(bb.begin(), bb.end(), true);
		ecnt /= 2;
		if (ecnt != m - 1) {
			cout << "0\n";
			continue;
		}
		vector vis(m, false);
		function<void(int)> dfs1 = [&](int u) {
			vis[u] = true;
			for (int v = 0; v < m; v++)
				if (b[u][v] && !vis[v])
					dfs1(v);
			};
		dfs1(0);
		if (count(vis.begin(), vis.end(), true) != m) {
			cout << "0\n";
			continue;
		}
		vector child(m, vector<int>());
		function<void(int, int)> dfs2 = [&](int u, int p) {
			child[u].push_back(u);
			for (int v = 0; v < m; v++)
				if (b[u][v] && v != p) {
					dfs2(v, u);
					for (auto ch : child[v])
						child[u].push_back(ch);
				}
			};
		dfs2(0, -1);
		vector c(n, vector<bool>(n));
		for (int i = 0; i < m; i++) {
			for (auto u : chain[i])
				for (auto j : child[i])
					for (auto v : chain[j])
						c[u][v] = c[v][u] = true;
		}
		bool fail = false;
		for (int i = 0; i < n; i++)
			if (a[i] != c[i]) {
				fail = true;
				break;
			}
		if (fail) {
			cout << "0\n";
			continue;
		}
		m998 res = fac[chain[0].size() - 1];
		for (int i = 1; i < m; i++)
			res = res * fac[chain[i].size()];
		cout << res << '\n';
	}
	return 0;
}

评论

2 条评论,欢迎与作者交流。

正在加载评论...