专栏文章

题解:AT_abc425_g [ABC425G] Sum of Min of XOR

AT_abc425_g题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@minqkoam
此快照首次捕获于
2025/12/02 06:44
3 个月前
此快照最后确认于
2025/12/02 06:44
3 个月前
查看原文

解法

看到异或想到 01-Trie。
考虑对于每一个 xx,求出最小的异或值。
我们可以贪心地做。假如我们已经知道了一个数,什么情况下另一个数与这个数的异或值最小呢?对了,就是在这两个数相等的时候。
所以我们在 01-Trie 上从上往下走,每次尽可能走与 xx 的这一二进制位一致的数即可。
但是这种方法的时间复杂度依旧过不去,怎么优化呢?
我们换种角度,从 01-Trie 节点的角度来考虑。
从上往下走,记录经过该点的 xx 数量 cntcnt 和深度(从大到小) depdep
延续之前的贪心做法,如果两个子节点都存在,那么将所有 xx 放进它的下一位对应的子节点里,否则全部放进唯一的子节点里,并且对答案产生贡献。
但有个问题,怎么求出下一位为 0011xx 个数呢?
很明显,由于每个 xx 都是连续的,所以经过这个节点的 xx 的下一位一定是从 0011 的,由于 00 是一定会先加满才会变成 11,所以下一位为 00 的个数为 min(cnt,2dep1)\min(cnt,2^{dep-1}),为 11 的个数为 cntmin(cnt,2dep1)cnt-\min(cnt,2^{dep-1})
可以发现,cntcnt2dep2^{dep} 的节点会经常出现,所以我们可以先预处理出所有这样的节点的答案。
这样时间复杂度就可以足够通过本题了。

代码

CPP
#include <bits/stdc++.h>
#define int long long
const int N = 2e5 + 5;
const int Mod = 1e9 + 7;
using namespace std;
int n, m;
int a[N];
struct node
{
    int ch[2] = {-1, -1};
    int cnt;
    int val;
} trie[N * 30];
int tot;
void insert(int x, int p)
{
    for (int i = 29; i >= 0; i--)
    {
        bool c = x & (1ll << i);
        if (trie[p].ch[c] == -1)
        {
            trie[p].ch[c] = ++tot;
            trie[tot].cnt = 1ll << i;
        }
        p = trie[p].ch[c];
    }
}
// 求走满情况下
void pre(int s)
{
    int ls = trie[s].ch[0];
    int rs = trie[s].ch[1];
    if (ls != -1 && rs != -1)
    {
        pre(ls);
        pre(rs);
        trie[s].val = trie[ls].val + trie[rs].val;
    }
    else if (ls != -1)
    {
        pre(ls);
        trie[s].val = trie[ls].val * 2 + trie[ls].cnt * trie[ls].cnt;
    }
    else if (rs != -1)
    {
        pre(rs);
        trie[s].val = trie[rs].val * 2 + trie[rs].cnt * trie[rs].cnt;
    }
}
int bfs()
{
    int ans = 0;
    queue<pair<int, int>> q;
    q.emplace(0, m);
    while (!q.empty())
    {
        auto [s, cnt] = q.front();
        q.pop();
        if (!cnt)
        {
            continue;
        }
        int ls = trie[s].ch[0];
        int rs = trie[s].ch[1];
        int nl = min(cnt, trie[s].cnt >> 1);
        int nr = cnt - nl;
        if (!nr)
        {
            if (ls != -1)
            {
                q.emplace(ls, nl);
            }
            else
            {
                q.emplace(rs, nl);
                ans += trie[rs].cnt * nl;
            }
        }
        else
        {
            if (ls != -1 && rs != -1)
            {
                ans += trie[ls].val;
                q.emplace(rs, nr);
            }
            else if (ls != -1)
            {
                q.emplace(ls, nr);
                ans += trie[ls].cnt * nr + trie[ls].val;
            }
            else if (rs != -1)
            {
                q.emplace(rs, nr);
                ans += trie[rs].cnt * nl + trie[rs].val;
            }
        }
    }
    return ans;
}
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
        insert(a[i], 0);
    }
    trie[0].cnt = 1ll << 30;
    pre(0);
    cout << bfs();
    return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...