专栏文章

形式幂级数复合/复合逆的 Kinoshita–Li 算法简洁实现

个人记录参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mioagqqw
此快照首次捕获于
2025/12/02 16:00
3 个月前
此快照最后确认于
2025/12/02 16:00
3 个月前
查看原文
这篇文章是给出一个简洁的代码实现,学习算法还是看原论文(见参考文献 1)。

实现

这个实现的意义在于,FFT 可以是黑盒,我们只需要确定其输入和输出是什么,并不需要关心其实现(意味着可以优化 FFT 而不需要修改 Kinoshita–Li 算法的实现),并且在其他地方不会访问预处理的“单位根数组”。原始的实现可以看我在 OI-Wiki 的提交 https://github.com/OI-wiki/OI-wiki/blob/fef757921b0956907d58a99e9940541f5e1a5ee2/docs/math/code/poly/comp-rev/rev_1.cpp,这里的函数 KinoshitaLi() 几乎是函数式的实现。但是这样一来会有一个问题:我们做了很多次冗余的内存拷贝/移动,因为其实考虑下一层的运算,其实内存已经在我们想要的位置了,只需要把高位“清零”,于是我接下来写了下面这个代码,基本上没有压行。这里的 FFT() 也是常见的 DIF-DIT 实现即 FFT() 输出为 bit-reversal 置换,InvFFT() 输入为 bit-reversal 置换,所以不需要显式的做置换,这个应该大家都会。
代码C
// CXXFLAGS=-std=c++17 -Wall -Wextra
#include <algorithm>
#include <cassert>
#include <cstring>
#include <tuple>
#include <utility>
#include <vector>

using uint         = unsigned;
using ull          = unsigned long long;
constexpr uint MOD = 998244353;

constexpr uint PowMod(uint a, ull e) {
    for (uint res = 1;; a = (ull)a * a % MOD) {
        if (e & 1) res = (ull)res * a % MOD;
        if ((e /= 2) == 0) return res;
    }
}

constexpr uint InvMod(uint a) { return PowMod(a, MOD - 2); }

constexpr uint QUAD_NONRESIDUE = 3;
constexpr int LOG2_ORD         = __builtin_ctz(MOD - 1);
constexpr uint ZETA            = PowMod(QUAD_NONRESIDUE, (MOD - 1) >> LOG2_ORD);
constexpr uint INV_ZETA        = InvMod(ZETA);

std::pair<std::vector<uint>, std::vector<uint>> GetFFTRoot(int n) {
    assert((n & (n - 1)) == 0);
    if (n / 2 == 0) return {};
    std::vector<uint> root(n / 2), inv_root(n / 2);
    root[0] = inv_root[0] = 1;
    for (int i = 0; (1 << i) < n / 2; ++i)
        root[1 << i]               = PowMod(ZETA, 1LL << (LOG2_ORD - i - 2)),
                  inv_root[1 << i] = PowMod(INV_ZETA, 1LL << (LOG2_ORD - i - 2));
    for (int i = 1; i < n / 2; ++i)
        root[i]     = (ull)root[i - (i & (i - 1))] * root[i & (i - 1)] % MOD,
        inv_root[i] = (ull)inv_root[i - (i & (i - 1))] * inv_root[i & (i - 1)] % MOD;
    return {root, inv_root};
}

void Butterfly(uint a[], int n, const uint root[]) {
    assert((n & (n - 1)) == 0);
    for (int i = n; i >= 2; i /= 2)
        for (int j = 0; j < n; j += i)
            for (int k = j; k < j + i / 2; ++k) {
                const uint u = a[k];
                a[k + i / 2] = (ull)a[k + i / 2] * root[j / i] % MOD;
                if ((a[k] += a[k + i / 2]) >= MOD) a[k] -= MOD;
                if ((a[k + i / 2] = u + MOD - a[k + i / 2]) >= MOD) a[k + i / 2] -= MOD;
            }
}

void InvButterfly(uint a[], int n, const uint root[]) {
    assert((n & (n - 1)) == 0);
    for (int i = 2; i <= n; i *= 2)
        for (int j = 0; j < n; j += i)
            for (int k = j; k < j + i / 2; ++k) {
                const uint u = a[k];
                if ((a[k] += a[k + i / 2]) >= MOD) a[k] -= MOD;
                a[k + i / 2] = (ull)(u + MOD - a[k + i / 2]) * root[j / i] % MOD;
            }
}

int GetFFTSize(int n) {
    int len = 1;
    while (len < n) len *= 2;
    return len;
}

void FFT(uint a[], int n, const uint root[]) { Butterfly(a, n, root); }

void InvFFT(uint a[], int n, const uint root[]) {
    InvButterfly(a, n, root);
    const uint invn = InvMod(n);
    for (int i = 0; i < n; ++i) a[i] = (ull)a[i] * invn % MOD;
}

std::vector<uint> FPSComp(std::vector<uint> f, std::vector<uint> g, int n) {
    assert(empty(g) || g[0] == 0);
    const int len = GetFFTSize(n);
    std::vector<uint> root, inv_root;
    tie(root, inv_root) = GetFFTRoot(len * 4);
    // [y^(-1)] (f(y) / (-g(x) + y)) mod x^n in R[x]((y^(-1)))
    const auto KinoshitaLi = [&](auto &&KinoshitaLi, std::vector<uint> &P, std::vector<uint> Q,
                                 int d, int n) {
        assert((int)size(P) == d * n * 2);
        assert((int)size(Q) == d * n * 2);
        if (n == 1) return;
        Q.resize(d * n * 4);
        Q[d * n * 2] = 1;
        FFT(data(Q), d * n * 4, data(root));
        std::vector<uint> V(d * n * 2);
        for (int i = 0; i < d * n * 4; i += 2) V[i / 2] = (ull)Q[i] * Q[i + 1] % MOD;
        InvFFT(data(V), d * n * 2, data(inv_root));
        assert(V[0] == 1);
        V[0] = 0;
        for (int i = 0; i < d * 2; ++i)
            std::memset(data(V) + i * n + n / 2, 0, sizeof(uint) * (n / 2));
        KinoshitaLi(KinoshitaLi, P, std::move(V), d * 2, n / 2);
        FFT(data(P), d * n * 2, data(root));
        for (int i = 0; i < d * n * 4; i += 2) {
            const uint u = Q[i];
            Q[i]         = (ull)P[i / 2] * Q[i + 1] % MOD;
            Q[i + 1]     = (ull)P[i / 2] * u % MOD;
        }
        InvFFT(data(Q), d * n * 4, data(inv_root));
        for (int i = 0; i < d; ++i) {
            uint *const u = data(P) + i * n * 2;
            std::memcpy(u, data(Q) + (i + d) * (n * 2), sizeof(uint) * n);
            std::memset(u + n, 0, sizeof(uint) * n);
        }
    };
    f.resize(len * 2);
    g.resize(len * 2);
    for (int i = len - 1; i >= 0; --i) f[i * 2] = f[i], f[i * 2 + 1] = 0;
    for (int i = 0; i < len; ++i) g[i] = (g[i] != 0 ? MOD - g[i] : 0);
    std::memset(data(g) + len, 0, sizeof(uint) * len);
    KinoshitaLi(KinoshitaLi, f, std::move(g), 1, len);
    f.resize(n);
    return f;
}

// Power Projection: [x^(n-1)] (fg^i) for i=0,..,n-1
std::vector<uint> PowProj(std::vector<uint> f, std::vector<uint> g, int n) {
    assert(empty(g) || g[0] == 0);
    const int len = GetFFTSize(n);
    std::vector<uint> root, inv_root;
    tie(root, inv_root) = GetFFTRoot(len * 4);
    // [x^(n-1)] (f(x) / (-g(x) + y)) in R[x]((y^(-1)))
    const auto KinoshitaLi = [&](std::vector<uint> &P, std::vector<uint> &Q, int d, int n) {
        assert((int)size(P) == d * n * 2);
        assert((int)size(Q) == d * n * 2);
        P.insert(begin(P), d * n * 2, 0u);
        Q.resize(d * n * 4);
        std::vector<uint> nextP(d * n * 4);
        for (; n > 1; d *= 2, n /= 2) {
            Q[d * n * 2] = 1;
            FFT(data(P), d * n * 4, data(inv_root));
            FFT(data(Q), d * n * 4, data(root));
            uint *const nP = data(nextP) + d * n * 2;
            for (int i = 0; i < d * n * 4; i += 2) {
                if ((nP[i / 2] = ((ull)P[i] * Q[i + 1] + (ull)P[i + 1] * Q[i]) % MOD) & 1)
                    nP[i / 2] += MOD;
                nP[i / 2] /= 2;
                Q[i / 2] = (ull)Q[i] * Q[i + 1] % MOD;
            }
            InvFFT(nP, d * n * 2, data(root));
            InvFFT(data(Q), d * n * 2, data(inv_root));
            assert(Q[0] == 1);
            Q[0] = 0;
            for (int i = 0; i < d * 2; ++i) {
                std::memset(nP + i * n, 0, sizeof(uint) * (n / 2));
                std::memset(data(Q) + i * n + n / 2, 0, sizeof(uint) * (n / 2));
            }
            P.swap(nextP);
            std::memset(data(P), 0, sizeof(uint) * (d * n * 2));
            std::memset(data(Q) + d * n * 2, 0, sizeof(uint) * (d * n * 2));
        }
        P.erase(begin(P), begin(P) + d * n * 2);
    };
    f.insert(begin(f), len - n, 0);
    f.resize(len);
    reverse(begin(f), end(f));
    f.insert(begin(f), len, 0u);
    g.resize(len * 2);
    for (int i = 0; i < len; ++i) g[i] = (g[i] != 0 ? MOD - g[i] : 0);
    std::memset(data(g) + len, 0, sizeof(uint) * len);
    KinoshitaLi(f, g, 1, len);
    for (int i = 0; i < len; ++i) f[i] = f[i * 2 + 1];
    f.resize(n);
    return f;
}

std::vector<uint> FPSPow1(std::vector<uint> g, uint e, int n) {
    assert(!empty(g) && g[0] == 1);
    if (n == 1) return std::vector<uint>{1u};
    std::vector<uint> inv(n), f(n);
    inv[1] = f[0] = 1;
    for (int i = 2; i < n; ++i) inv[i] = (ull)(MOD - MOD / i) * inv[MOD % i] % MOD;
    for (int i = 1; i < n; ++i) f[i] = (ull)f[i - 1] * (e + MOD + 1 - i) % MOD * inv[i] % MOD;
    g[0] = 0;
    return FPSComp(f, g, n);
}

std::vector<uint> FPSRev(std::vector<uint> f, int n) {
    assert(size(f) >= 2 && f[0] == 0 && f[1] != 0);
    if (n == 1) return std::vector<uint>{0u};
    f.resize(n);
    const uint invf1 = InvMod(f[1]);
    uint invf1p      = 1;
    for (int i = 0; i < n; ++i) f[i] = (ull)f[i] * invf1p % MOD, invf1p = (ull)invf1p * invf1 % MOD;
    std::vector<uint> inv(n);
    inv[1] = 1;
    for (int i = 2; i < n; ++i) inv[i] = (ull)(MOD - MOD / i) * inv[MOD % i] % MOD;
    auto proj = PowProj(std::vector<uint>{1u}, f, n);
    for (int i = 1; i < n; ++i) proj[i] = (ull)proj[i] * (n - 1) % MOD * inv[i] % MOD;
    reverse(begin(proj), end(proj));
    auto res = FPSPow1(proj, InvMod(MOD + 1 - n), n - 1);
    for (int i = 0; i < n - 1; ++i) res[i] = (ull)res[i] * invf1 % MOD;
    res.insert(begin(res), 0);
    return res;
}

更进一步减少冗余操作

其实可以发现我们的写法中仍然有一些冗余操作,例如在 FPSComp()KinoshitaLi() 函数中可以原地存储在 Q 数组而不是移动到 P 数组的话,可以减少一次 memcpy() 的调用,在递归的倒数第二层可以进行特殊处理等等,但是这样的写法已经是相当简洁了。

参考文献

  1. Yasunori Kinoshita, Baitian Li. Power Series Composition in Near-Linear Time. FOCS 2024: 2180-2185 url: https://arxiv.org/abs/2404.05177
  2. Alin Bostan, Ryuhei Mori. A Simple and Fast Algorithm for Computing the N-th Term of a Linearly Recurrent Sequence. SOSA 2021: 118-132 url: https://arxiv.org/abs/2008.08822

评论

0 条评论,欢迎与作者交流。

正在加载评论...