专栏文章

题解:AT_abc406_e [ABC406E] Popcount Sum 3

AT_abc406_e题解参与者 4已保存评论 4

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
4 条
当前快照
1 份
快照标识符
@mip9phsw
此快照首次捕获于
2025/12/03 08:27
3 个月前
此快照最后确认于
2025/12/03 08:27
3 个月前
查看原文

题解:AT_abc406_e [ABC406E] Popcount Sum 3

题意:求 N\le N 中满足其恰好有 KK 二进制位为 1 的所有数之和。
首先考虑一个特殊情况,若没有 N\le N 的限制的答案是多少呢?我们记 F(i,j)F(i, j)ii 个二进制位中选 jj 个位填 1 的数的和,不难发现:
F(i,j)=(2i1)×(i1j1)F(i, j) = (2 ^ i - 1) \times \dbinom{i - 1}{j - 1}
现实意义很简单,对于每个位 kk,其如果选定为 1,则方案数为在 i1i - 1 位中选定 j1j - 1 位为 1 的方案数,其他情况不会产生贡献。因而单个位贡献为 2k×(i1j1)2 ^ k \times \dbinom{i - 1}{j - 1}。因而提公因式后等比数列求和可得:
F(i,j)=k=0i12k×(i1j1)=(2i1)×(i1j1)F(i, j) = \sum _ {k = 0} ^ {i - 1} {2 ^ k \times \dbinom{i - 1}{j - 1}} = (2 ^ i - 1) \times \dbinom{i - 1}{j - 1}
现在回到题目。加上了 N\le N 的限制,其实本可以像数位 DP 那样做,但是似乎这种思路会好想一些?
我们记cnt(n,k)cnt(n, k)sum(n,k)sum(n, k) 分别为 n\le n 中满足 kk 个二进制位为 1 的方案数、数之和。答案即为 sum(N,K)sum(N, K)
考虑如何用已知的 F(n,k)F(n, k) 求出 sum(n,k)sum(n, k)
首先,如果 nn 的最高位 ii 选了 1,后面的答案可以拆成第 ii 位的贡献和第 1 位到第 i1i - 1 位的贡献,即 sum(n2i,k1)+2i×cnt(n2i,k1)sum(n - 2 ^ i, k - 1) + 2^i \times cnt(n - 2 ^ i, k - 1),否则若不选 1,后面 ii 位随便选,答案为 F(i,k)F(i, k)
cnt(n,k)cnt(n, k) 的方法同理可得 cnt(n,k)=cnt(n2i,k1)+(ik)cnt(n, k) = cnt(n - 2 ^ i, k - 1) + \dbinom{i}{k}
实现方法有递归和递推,我觉得递归会好实现一些,cntcntsumsum 可以存一块。此外,由于不明原因,此代码过程量不开 __int128 无法 AC,欢迎 dalao 们指处错误。
CPP
#include <bits/stdc++.h>
using namespace std;
#define int unsigned long long int
const int mod = 998244353;
int qpow(int x, int y, const int mod) {
    int res = 1;
    for (; y; y >>= 1) {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
    } return res;
}
int fac[65], invfac[65];
void init() {
    fac[0] = 1; for (int i = 1; i <= 63; i++) fac[i] = fac[i - 1] * i % mod;
    for (int i = 0; i <= 63; i++) invfac[i] = qpow(fac[i], mod - 2, mod);
}
int comb(int n, int m) { return m > n ? 0 : (__int128)fac[n] * invfac[m] % mod * invfac[n - m] % mod; }
int chose(int i, int j) { return (__int128)((1ULL << i) - 1) * comb(i - 1, j - 1) % mod; }
struct node { int sum, cnt; };
node solve(int n, int k, int t) { // t 记录的是当前位数
    if (k > t + 1) return {0, 0};
    else if (n == 0) return {0, !k};
    if (n >> t) {
        node x = solve(n ^ (1ULL << t), k - 1, t - 1);
        return {(int)((chose(t, k) + x.sum) % mod + (__int128)x.cnt * (1ULL << t) % mod) % mod,
                (x.cnt + comb(t, k)) % mod};
    } else return solve(n, k, t - 1); // 还不是最高位
}
main() {
    ios::sync_with_stdio(false), cin.tie(0);
    init();
    int t, n, k;
    for (cin >> t; t; t--) {
        cin >> n >> k;
        int w = 63; while (!(n >> w & 1)) w--; // 求最高位
        cout << solve(n, k, w).sum % mod << '\n';
    }
    return 0;
}

评论

4 条评论,欢迎与作者交流。

正在加载评论...