专栏文章
形式幂级数复合/复合逆的 Kinoshita–Li 算法简洁实现
个人记录参与者 1已保存评论 0
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 0 条
- 当前快照
- 1 份
- 快照标识符
- @mioagqqw
- 此快照首次捕获于
- 2025/12/02 16:00 3 个月前
- 此快照最后确认于
- 2025/12/02 16:00 3 个月前
这篇文章只给出一个简洁的代码实现;若要学习算法本身,还是请看原论文(见参考文献 1)。
实现
这个实现的意义在于,FFT 可以是黑盒,我们只需要确定其输入和输出是什么,并不需要关心其实现(意味着可以优化 FFT 而不需要修改 Kinoshita–Li 算法的实现),并且在其他地方不会访问预处理的“单位根数组”。原始的实现可以看我在 OI-Wiki 的提交 https://github.com/OI-wiki/OI-wiki/blob/fef757921b0956907d58a99e9940541f5e1a5ee2/docs/math/code/poly/comp-rev/rev_1.cpp,这里的函数
KinoshitaLi() 几乎是函数式的实现。但是这样一来会有一个问题:我们做了很多次冗余的内存拷贝/移动。其实考虑下一层的运算就会发现,数据已经在我们想要的位置上了,只需要把高位“清零”即可。于是我接下来写了下面这份代码,基本上没有压行。这里的 FFT() 也是常见的 DIF-DIT 实现,即 FFT() 的输出为 bit-reversal 置换后的顺序,InvFFT() 的输入也为 bit-reversal 置换后的顺序,所以不需要显式地做置换,这个应该大家都会。
代码
// CXXFLAGS=-std=c++17 -Wall -Wextra
#include <algorithm>
#include <cassert>
#include <cstring>
#include <tuple>
#include <utility>
#include <vector>
// Shorthand names for the unsigned integer types used throughout this file.
using uint = unsigned;
using ull = unsigned long long;
// Modulus for all arithmetic below: 998244353 = 119 * 2^23 + 1.
// InvMod() relies on this being prime.
constexpr uint MOD = 998244353;
// Modular exponentiation: returns a^e reduced modulo MOD (1 when e == 0).
//
// Right-to-left binary exponentiation. Every product is widened to 64 bits
// before the reduction so that it cannot overflow.
constexpr uint PowMod(uint a, ull e) {
  uint acc = 1;
  while (true) {
    if ((e & 1) != 0) {
      acc = static_cast<uint>(static_cast<ull>(acc) * a % MOD);
    }
    e >>= 1;
    if (e == 0) {
      return acc;
    }
    // Square the base only when another bit of the exponent remains.
    a = static_cast<uint>(static_cast<ull>(a) * a % MOD);
  }
}
// Multiplicative inverse of a modulo MOD, computed as a^(MOD - 2).
// This is Fermat's little theorem, so it is only meaningful for a prime MOD
// and an argument that is not a multiple of MOD.
constexpr uint InvMod(uint a) {
  constexpr ull kFermatExponent = MOD - 2;
  return PowMod(a, kFermatExponent);
}
// Value used as a quadratic non-residue modulo MOD when deriving ZETA.
constexpr uint QUAD_NONRESIDUE = 3;
// Number of trailing zero bits of MOD - 1, i.e. MOD - 1 = odd * 2^LOG2_ORD.
constexpr int LOG2_ORD = __builtin_ctz(MOD - 1);
// QUAD_NONRESIDUE raised to the odd part of MOD - 1. Provided the base really
// is a quadratic non-residue, this is a primitive 2^LOG2_ORD-th root of unity.
constexpr uint ZETA = PowMod(QUAD_NONRESIDUE, (MOD - 1) >> LOG2_ORD);
// Modular inverse of ZETA; GetFFTRoot() builds the inverse table from it.
constexpr uint INV_ZETA = InvMod(ZETA);
std::pair<std::vector<uint>, std::vector<uint>> GetFFTRoot(int n) {
assert((n & (n - 1)) == 0);
if (n / 2 == 0) return {};
std::vector<uint> root(n / 2), inv_root(n / 2);
root[0] = inv_root[0] = 1;
for (int i = 0; (1 << i) < n / 2; ++i)
root[1 << i] = PowMod(ZETA, 1LL << (LOG2_ORD - i - 2)),
inv_root[1 << i] = PowMod(INV_ZETA, 1LL << (LOG2_ORD - i - 2));
for (int i = 1; i < n / 2; ++i)
root[i] = (ull)root[i - (i & (i - 1))] * root[i & (i - 1)] % MOD,
inv_root[i] = (ull)inv_root[i - (i & (i - 1))] * inv_root[i & (i - 1)] % MOD;
return {root, inv_root};
}
// Forward butterfly network applied in place to a[0..n), n a power of two.
//
// Stages run from block size n down to 2. Within a block the upper half is
// first multiplied by the block's twiddle factor root[start / block], then
// (lower, upper) becomes (lower + upper, lower - upper) modulo MOD.
// Entries are expected to be reduced (< MOD) on entry and stay reduced.
void Butterfly(uint a[], int n, const uint root[]) {
  assert((n & (n - 1)) == 0);
  for (int block = n; block >= 2; block /= 2) {
    const int half = block / 2;
    for (int start = 0; start < n; start += block) {
      const uint twiddle = root[start / block];
      uint *const lower = a + start;
      uint *const upper = lower + half;
      for (int t = 0; t < half; ++t) {
        const uint x = lower[t];
        const uint y = static_cast<uint>(static_cast<ull>(upper[t]) * twiddle % MOD);
        uint sum = x + y;
        if (sum >= MOD) sum -= MOD;
        uint diff = x + MOD - y;
        if (diff >= MOD) diff -= MOD;
        lower[t] = sum;
        upper[t] = diff;
      }
    }
  }
}
// Inverse butterfly network applied in place to a[0..n), n a power of two.
//
// Stages run from block size 2 up to n. Within a block (lower, upper) becomes
// (lower + upper, (lower - upper) * root[start / block]) modulo MOD. The 1/n
// scaling is not applied here; InvFFT() does that.
void InvButterfly(uint a[], int n, const uint root[]) {
  assert((n & (n - 1)) == 0);
  for (int block = 2; block <= n; block *= 2) {
    const int half = block / 2;
    for (int start = 0; start < n; start += block) {
      const uint twiddle = root[start / block];
      uint *const lower = a + start;
      uint *const upper = lower + half;
      for (int t = 0; t < half; ++t) {
        const uint x = lower[t];
        const uint y = upper[t];
        uint sum = x + y;
        if (sum >= MOD) sum -= MOD;
        lower[t] = sum;
        upper[t] = static_cast<uint>(static_cast<ull>(x + MOD - y) * twiddle % MOD);
      }
    }
  }
}
// Returns the smallest power of two that is >= n (1 for every n <= 1).
//
// n must not exceed 2^30: the next power of two would not fit in an int and
// the doubling below would overflow, which is undefined behaviour.
int GetFFTSize(int n) {
  assert(n <= (1 << 30));
  int len = 1;
  while (len < n) len *= 2;
  return len;
}
// Forward transform of a[0..n) in place; n must be a power of two.
// This is just the butterfly network, with no reordering pass: per the
// accompanying article the output is left in bit-reversed order.
void FFT(uint a[], int n, const uint root[]) {
  Butterfly(a, n, root);
}
// Inverse transform of a[0..n) in place; n must be a power of two.
// Runs the inverse butterfly network and then multiplies every entry by the
// modular inverse of n. Per the accompanying article the input is expected in
// the bit-reversed order produced by FFT().
void InvFFT(uint a[], int n, const uint root[]) {
  InvButterfly(a, n, root);
  const ull scale = InvMod(n);
  for (int idx = 0; idx < n; ++idx) {
    a[idx] = static_cast<uint>(a[idx] * scale % MOD);
  }
}
// Composition of formal power series: returns f(g(x)) mod x^n, computed with
// the Kinoshita-Li algorithm.
//
// Precondition (asserted): g is empty or has zero constant term.
// f and g are taken by value and reused as scratch buffers. The returned
// vector has exactly n entries.
std::vector<uint> FPSComp(std::vector<uint> f, std::vector<uint> g, int n) {
assert(empty(g) || g[0] == 0);
// All work is done modulo x^len, with len the power of two >= n.
const int len = GetFFTSize(n);
std::vector<uint> root, inv_root;
// The largest transform below has length d * n * 4 == len * 4.
tie(root, inv_root) = GetFFTRoot(len * 4);
// [y^(-1)] (f(y) / (-g(x) + y)) mod x^n in R[x]((y^(-1)))
//
// Recursive step. P and Q each hold d * n * 2 entries on entry (asserted).
// P is updated in place and carries the result back up; Q is taken by value
// because it is overwritten. The product d * n is the same at every level.
// NOTE(review): presumably both arrays are bivariate, stored as rows of 2n
// entries with the row index being the y-degree - confirm against the paper.
const auto KinoshitaLi = [&](auto &&KinoshitaLi, std::vector<uint> &P, std::vector<uint> Q,
int d, int n) {
assert((int)size(P) == d * n * 2);
assert((int)size(Q) == d * n * 2);
// Deepest level: P is left exactly as the caller laid it out.
if (n == 1) return;
// Double the buffer, put a 1 at the first new position, and transform.
Q.resize(d * n * 4);
Q[d * n * 2] = 1;
FFT(data(Q), d * n * 4, data(root));
// Multiply adjacent pairs of transform values; this halves the length.
// NOTE(review): presumably this evaluates Q(x, y) * Q(-x, y), adjacent
// entries of the bit-reversed output being values at opposite points - confirm.
std::vector<uint> V(d * n * 2);
for (int i = 0; i < d * n * 4; i += 2) V[i / 2] = (ull)Q[i] * Q[i + 1] % MOD;
InvFFT(data(V), d * n * 2, data(inv_root));
// The entry at index 0 must be 1. It is cleared because the next level sets
// its own 1 at index d * n * 2 again.
assert(V[0] == 1);
V[0] = 0;
// View V as d * 2 rows of n entries and zero the upper n / 2 of each row, so
// it matches the layout the next level expects (rows of 2 * (n / 2) entries).
for (int i = 0; i < d * 2; ++i)
std::memset(data(V) + i * n + n / 2, 0, sizeof(uint) * (n / 2));
KinoshitaLi(KinoshitaLi, P, std::move(V), d * 2, n / 2);
// Back from the recursion: transform P at half length and multiply each value
// into the pair of Q values it was paired with above, swapping the two
// partners (Q[i] receives P * Q[i + 1] and vice versa).
FFT(data(P), d * n * 2, data(root));
for (int i = 0; i < d * n * 4; i += 2) {
const uint u = Q[i];
Q[i] = (ull)P[i / 2] * Q[i + 1] % MOD;
Q[i + 1] = (ull)P[i / 2] * u % MOD;
}
InvFFT(data(Q), d * n * 4, data(inv_root));
// Keep rows d .. 2d - 1 of the product (rows of 2n entries): copy the lower n
// entries of each into row i of P and zero the upper n entries of that row.
for (int i = 0; i < d; ++i) {
uint *const u = data(P) + i * n * 2;
std::memcpy(u, data(Q) + (i + d) * (n * 2), sizeof(uint) * n);
std::memset(u + n, 0, sizeof(uint) * n);
}
};
f.resize(len * 2);
g.resize(len * 2);
// Spread f out so that f[i] sits at index 2 * i with a zero after it. That is
// the layout of the deepest level (n == 1, d == len, rows of 2 entries), which
// is the first level that actually reads P. Going downwards keeps the
// not-yet-moved source entries intact.
for (int i = len - 1; i >= 0; --i) f[i * 2] = f[i], f[i * 2 + 1] = 0;
// Q starts as -g in the lower len entries and zeros in the upper len entries.
for (int i = 0; i < len; ++i) g[i] = (g[i] != 0 ? MOD - g[i] : 0);
std::memset(data(g) + len, 0, sizeof(uint) * len);
KinoshitaLi(KinoshitaLi, f, std::move(g), 1, len);
// The answer is in the first len entries of f; keep the n requested.
f.resize(n);
return f;
}
// Power projection: returns the n values [x^(n-1)] (f * g^i) for i = 0..n-1.
//
// Precondition (asserted): g is empty or has zero constant term.
// f and g are taken by value and reused as scratch buffers.
std::vector<uint> PowProj(std::vector<uint> f, std::vector<uint> g, int n) {
assert(empty(g) || g[0] == 0);
// All work is done at length len, the power of two >= n.
const int len = GetFFTSize(n);
std::vector<uint> root, inv_root;
// The largest transform below has length d * n * 4 == len * 4.
tie(root, inv_root) = GetFFTRoot(len * 4);
// [x^(n-1)] (f(x) / (-g(x) + y)) in R[x]((y^(-1)))
//
// Iterative form of the recursion. P and Q each hold d * n * 2 entries on
// entry (asserted); both are modified in place.
const auto KinoshitaLi = [&](std::vector<uint> &P, std::vector<uint> &Q, int d, int n) {
assert((int)size(P) == d * n * 2);
assert((int)size(Q) == d * n * 2);
// Grow both to d * n * 4 entries: P's data moves to the upper half (zeros are
// prepended), Q's data stays in the lower half (zeros are appended).
P.insert(begin(P), d * n * 2, 0u);
Q.resize(d * n * 4);
// Second buffer; P and nextP are swapped at the end of every round.
std::vector<uint> nextP(d * n * 4);
// Each round doubles d and halves n, so d * n and all buffer sizes stay fixed.
for (; n > 1; d *= 2, n /= 2) {
Q[d * n * 2] = 1;
// P is transformed with the inverse table here and the forward table in
// InvFFT below, the opposite of Q.
// NOTE(review): presumably because P goes through the transposed transform - confirm.
FFT(data(P), d * n * 4, data(inv_root));
FFT(data(Q), d * n * 4, data(root));
// Results for P go to the upper half of nextP.
uint *const nP = data(nextP) + d * n * 2;
for (int i = 0; i < d * n * 4; i += 2) {
// Cross-multiply the adjacent pair and halve modulo MOD: if the residue is
// odd, add MOD first so that the division by 2 is exact.
if ((nP[i / 2] = ((ull)P[i] * Q[i + 1] + (ull)P[i + 1] * Q[i]) % MOD) & 1)
nP[i / 2] += MOD;
nP[i / 2] /= 2;
// Product of the adjacent pair, written in place. Index i / 2 has already
// been consumed, or is being read in this same statement when i == 0.
Q[i / 2] = (ull)Q[i] * Q[i + 1] % MOD;
}
InvFFT(nP, d * n * 2, data(root));
InvFFT(data(Q), d * n * 2, data(inv_root));
// The entry at index 0 must be 1. It is cleared because the next round sets
// its own 1 at index d * n * 2 again.
assert(Q[0] == 1);
Q[0] = 0;
// View both as d * 2 rows of n entries: zero the lower n / 2 of each row of
// the new P and the upper n / 2 of each row of Q.
for (int i = 0; i < d * 2; ++i) {
std::memset(nP + i * n, 0, sizeof(uint) * (n / 2));
std::memset(data(Q) + i * n + n / 2, 0, sizeof(uint) * (n / 2));
}
// The new P becomes current. Restore the round invariant: the lower half of P
// and the upper half of Q are zero.
P.swap(nextP);
std::memset(data(P), 0, sizeof(uint) * (d * n * 2));
std::memset(data(Q) + d * n * 2, 0, sizeof(uint) * (d * n * 2));
}
// Drop the zero lower half again so P is back to d * n * 2 entries.
P.erase(begin(P), begin(P) + d * n * 2);
};
// Shift f up by len - n places and cut or pad it to len entries, so the wanted
// coefficient x^(n-1) becomes x^(len-1). Then reverse it and prepend len zeros,
// giving 2 * len entries.
f.insert(begin(f), len - n, 0);
f.resize(len);
reverse(begin(f), end(f));
f.insert(begin(f), len, 0u);
// Q starts as -g in the lower len entries and zeros in the upper len entries.
g.resize(len * 2);
for (int i = 0; i < len; ++i) g[i] = (g[i] != 0 ? MOD - g[i] : 0);
std::memset(data(g) + len, 0, sizeof(uint) * len);
KinoshitaLi(f, g, 1, len);
// After the last round f holds len rows of 2 entries; the answer for index i
// is the second entry of row i.
for (int i = 0; i < len; ++i) f[i] = f[i * 2 + 1];
f.resize(n);
return f;
}
// Returns g(x)^e mod x^n for a power series g with constant term 1.
//
// Writes g = 1 + h and composes the binomial series sum_i C(e, i) y^i with h
// through FPSComp(). e is treated as an element of Z/MOD: the coefficients
// C(e, i) for i < MOD depend only on e modulo MOD.
//
// Precondition (asserted): g is non-empty and g[0] == 1.
// Returns n coefficients; an empty vector when n <= 0.
std::vector<uint> FPSPow1(std::vector<uint> g, uint e, int n) {
  assert(!empty(g) && g[0] == 1);
  // Without this guard n == 0 would write inv[1] and f[0] out of bounds.
  if (n <= 0) return {};
  if (n == 1) return std::vector<uint>{1u};
  // Reduce first: for e close to 2^32 the sum e + MOD + 1 below would wrap
  // around in 32-bit arithmetic and yield the wrong residue.
  e %= MOD;
  std::vector<uint> inv(n), f(n);
  inv[1] = f[0] = 1;
  // Linear-time table of modular inverses of 1..n-1.
  for (int i = 2; i < n; ++i) inv[i] = (ull)(MOD - MOD / i) * inv[MOD % i] % MOD;
  // f[i] = C(e, i) = C(e, i - 1) * (e - i + 1) / i.
  for (int i = 1; i < n; ++i) f[i] = (ull)f[i - 1] * (e + MOD + 1 - i) % MOD * inv[i] % MOD;
  g[0] = 0;
  // Neither buffer is needed afterwards, so hand them over without copying.
  return FPSComp(std::move(f), std::move(g), n);
}
// Compositional inverse: returns the series r with f(r(x)) = x, mod x^n.
//
// Approach: rescale f so that its linear coefficient becomes 1, obtain
// [x^(n-1)] f^i for all i with PowProj(), turn those into the coefficients of
// (r(x) / x)^(-(n-1)) via Lagrange inversion, raise that to the power
// -1 / (n - 1) with FPSPow1(), then undo the rescaling and the division by x.
//
// Precondition (asserted): f has at least two coefficients, f[0] == 0 and
// f[1] != 0. Returns n coefficients; an empty vector when n <= 0.
std::vector<uint> FPSRev(std::vector<uint> f, int n) {
  assert(size(f) >= 2 && f[0] == 0 && f[1] != 0);
  // Without this guard n == 0 would read f[1] and write inv[1] out of bounds.
  if (n <= 0) return {};
  if (n == 1) return std::vector<uint>{0u};
  f.resize(n);
  const uint invf1 = InvMod(f[1]);
  // Substitute x -> x / f[1]: coefficient i is scaled by f[1]^(-i), which
  // makes the linear coefficient equal to 1.
  uint invf1p = 1;
  for (int i = 0; i < n; ++i) {
    f[i] = (ull)f[i] * invf1p % MOD;
    invf1p = (ull)invf1p * invf1 % MOD;
  }
  // Linear-time table of modular inverses of 1..n-1.
  std::vector<uint> inv(n);
  inv[1] = 1;
  for (int i = 2; i < n; ++i) inv[i] = (ull)(MOD - MOD / i) * inv[MOD % i] % MOD;
  // proj[i] = [x^(n-1)] f^i; f is not needed afterwards, so move it.
  auto proj = PowProj(std::vector<uint>{1u}, std::move(f), n);
  // Lagrange inversion: (n - 1) / i * [x^(n-1)] f^i is the coefficient of
  // x^(n-1-i) in (r(x) / x)^(-(n-1)); reversing puts them in ascending order.
  for (int i = 1; i < n; ++i) proj[i] = (ull)proj[i] * (n - 1) % MOD * inv[i] % MOD;
  std::reverse(begin(proj), end(proj));
  // Exponent -1 / (n - 1) modulo MOD recovers r(x) / x to n - 1 terms.
  auto res = FPSPow1(std::move(proj), InvMod(MOD + 1 - n), n - 1);
  // Undo the x -> x / f[1] substitution, then multiply by x.
  for (int i = 0; i < n - 1; ++i) res[i] = (ull)res[i] * invf1 % MOD;
  res.insert(begin(res), 0u);
  return res;
}
更进一步减少冗余操作
其实可以发现我们的写法中仍然有一些冗余操作,例如在
FPSComp() 的 KinoshitaLi() 函数中,如果把结果原地存储在 Q 数组而不是拷贝到 P 数组,就可以减少一次 memcpy() 的调用;在递归的倒数第二层也可以进行特殊处理等等。不过现在这样的写法已经相当简洁了。
参考文献
- Yasunori Kinoshita, Baitian Li. Power Series Composition in Near-Linear Time. FOCS 2024: 2180-2185 url: https://arxiv.org/abs/2404.05177
- Alin Bostan, Ryuhei Mori. A Simple and Fast Algorithm for Computing the N-th Term of a Linearly Recurrent Sequence. SOSA 2021: 118-132 url: https://arxiv.org/abs/2008.08822
相关推荐
评论
共 0 条评论,欢迎与作者交流。
正在加载评论...