社区讨论

求卡常

P5282【模板】快速阶乘算法 · 参与者 2 · 已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@lo159qcc
此快照首次捕获于
2023/10/22 15:25
2 年前
此快照最后确认于
2023/11/02 14:57
2 年前
查看原帖
rt.
写的 MTT,理论复杂度是 $\mathcal{O}(\sqrt{n}\log n)$,如果复杂度假了也欢迎指出。
样例开 O2 本机大概跑 1.8s。
CPP
# include <cstdio>
# include <cmath>
# include <cstring>
# include <cstdlib>
# include <algorithm>

# define int long long

// Minimal complex number used by the FFT routines: r is the real part,
// i the imaginary part.
struct Complex {
  double r, i;
  // Default value is 0 + 0i, so no instance is left with indeterminate members.
  constexpr Complex() : r(0), i(0) {}
  constexpr Complex(double R, double I) : r(R), i(I) {}
  // Complex conjugate; const so it also works on const objects/references.
  Complex conj() const { return Complex(r, -i); }
  // Division by an integer scalar (used to normalize the inverse FFT).
  Complex operator/ (const int& A) const { return Complex(r / A, i / A); }
  Complex operator+ (const Complex& A) const { return Complex(r + A.r, i + A.i); }
  Complex operator- (const Complex& A) const { return Complex(r - A.r, i - A.i); }
  Complex operator* (const Complex& A) const { return Complex(r * A.r - i * A.i, r * A.i + A.r * i); }
};

// Modulus of the current test case; read from input in main() before use.
int MOD;

// Size bound for the polynomial/FFT buffers; the arrays below hold MAXN << 1 entries.
constexpr int MAXN = (1 << 18) + 233;
const double Pi = std::acos(-1.0);
// Imaginary unit; FFT2() uses it to pack two real sequences into one complex array.
const Complex I(0, 1);

// fac[i] = i! mod MOD and ifac[i] = (i!)^{-1} mod MOD, filled by main() for i <= 2^16.
int fac[MAXN << 1], ifac[MAXN << 1];

// Transform state written by init(): rev = bit-reversal permutation of [0, len),
// bit = log2(len), len = current power-of-two transform length.
int rev[MAXN << 1], bit, len;

// Twiddle table written by init(): wn[k] = cos(Pi*k/len) + i*sin(Pi*k/len), k < len.
Complex wn[MAXN << 1];

void init(int n) {
	bit = 0, len = 1;
  for (; len < n; len <<= 1, ++bit);
  for (int i = 1; i < len; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << bit - 1;
  for (int i = 0; i < len; ++i) wn[i] = Complex(std::cos(Pi / len * i), std::sin(Pi / len * i));
}

// In-place iterative radix-2 FFT of a[0..len).
// Relies on the global rev[] / wn[] tables, so init() must have been called
// for this same power-of-two len.
// sgn == true : transform using the twiddles wn[].
// sgn == false: transform using conjugated twiddles, then divide every entry
//               by len (the inverse of the sgn == true transform).
void FFT(Complex* a, int len, bool sgn) {
  for (int i = 0; i < len; ++i) {
    const int j = rev[i];
    if (i < j) std::swap(a[i], a[j]);
  }
  for (int half = 1; half < len; half <<= 1) {
    const int step = len / half;  // wn[] stride for butterflies of this size
    for (int base = 0; base < len; base += 2 * half) {
      Complex* lo = a + base;
      Complex* hi = lo + half;
      for (int k = 0; k < half; ++k) {
        const Complex w = sgn ? wn[step * k] : wn[step * k].conj();
        const Complex u = lo[k];
        const Complex v = w * hi[k];
        lo[k] = u + v;
        hi[k] = u - v;
      }
    }
  }
  if (!sgn) {
    for (int i = 0; i < len; ++i) a[i] = a[i] / len;
  }
}

// Transforms the pair (a, b) with a single complex FFT of length len.
// sgn == true : a and b are packed as a + b*i and transformed together; the two
//               spectra are then separated using conjugate symmetry. This
//               assumes both inputs are real (imaginary parts 0), as mul() sets up.
// sgn == false: inverse transform of a + i*b; only the .r fields of a and b are
//               overwritten (with the real results), their .i fields are left as-is.
void FFT2(Complex* a, Complex* b, int len, bool sgn) {
  static Complex f[MAXN << 1], P[MAXN << 1];
  if (sgn) {
    for (int k = 0; k < len; ++k) f[k] = a[k] + b[k] * I;
    FFT(f, len, sgn);
    for (int k = 0; k < len; ++k) {
      const Complex mirror = f[k == 0 ? 0 : len - k].conj();
      a[k] = (f[k] + mirror) * Complex(0.5, 0);
      b[k] = (mirror - f[k]) * Complex(0, 0.5);
    }
    return;
  }
  for (int k = 0; k < len; ++k) P[k] = a[k] + I * b[k];
  FFT(P, len, sgn);
  for (int k = 0; k < len; ++k) {
    a[k].r = P[k].r;
    b[k].r = P[k].i;
  }
}

// Multiplies polynomial a (n coefficients) by polynomial b (m coefficients)
// modulo MOD and stores the n + m - 1 product coefficients back into a.
// b is only read. Coefficients are expected to be reduced into [0, MOD) already.
// Small products use schoolbook multiplication. Larger ones split every
// coefficient as hi * 2^15 + lo, compute the three partial convolutions with
// four complex FFTs, and recombine them as c1 * 2^30 + c2 * 2^15 + c3 mod MOD.
// On the FFT path a[n + m - 1 .. len) is zeroed too, so a must hold at least
// len entries (len = smallest power of two >= n + m - 1).
// NOTE(review): exactness of the double-precision rounding near MOD ~ 2^31 and
// len = 2^18 is assumed, not proven here.
void mul(int* a, int* b, int n, int m) {
  if (n <= 0 || m <= 0) return;  // empty factor: nothing meaningful to compute
  if (n * m <= 1000) {
    // For n, m >= 1: (n - 1)(m - 1) >= 0  =>  n + m - 1 <= n * m <= 1000,
    // so 1000 slots always suffice (a 400-slot buffer overflows e.g. for n = 1, m = 1000).
    static int c[1000];
    const int out = n + m - 1;
    memset(c, 0, sizeof(int) * out);
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < m; ++j) {
        c[i + j] += a[i] * b[j] % MOD;
        c[i + j] -= c[i + j] >= MOD ? MOD : 0;
      }
    }
    memcpy(a, c, sizeof(int) * out);
    return;
  }
  static Complex c1[MAXN << 1], c2[MAXN << 1], c3[MAXN << 1];
  static Complex a0[MAXN << 1], a1[MAXN << 1], b0[MAXN << 1], b1[MAXN << 1];
  init(n + m - 1);  // sets the global len, rev[] and wn[]
  // The transforms only touch [0, len), so clearing beyond it is wasted work.
  for (int i = 0; i < len; ++i) a0[i] = a1[i] = b0[i] = b1[i] = Complex(0, 0);
  for (int i = 0; i < n; ++i) a0[i].r = a[i] >> 15, a1[i].r = a[i] & 0x7fff;
  for (int i = 0; i < m; ++i) b0[i].r = b[i] >> 15, b1[i].r = b[i] & 0x7fff;
  FFT2(a0, a1, len, 1), FFT2(b0, b1, len, 1);
  for (int i = 0; i < len; ++i) {
    c1[i] = a0[i] * b0[i];                  // hi * hi
    c2[i] = a1[i] * b0[i] + a0[i] * b1[i];  // cross terms
    c3[i] = a1[i] * b1[i];                  // lo * lo
  }
  FFT2(c1, c2, len, 0), FFT(c3, len, 0);
  for (int i = 0; i < n + m - 1; ++i)
    a[i] = (((int)round(c1[i].r) % MOD << 30) % MOD + ((int)round(c2[i].r) % MOD << 15) % MOD + ((int)round(c3[i].r) % MOD)) % MOD;
  for (int i = n + m - 1; i < len; ++i) a[i] = 0;
}

// Returns x^k mod MOD by binary exponentiation. k must be non-negative
// (a negative k never reaches 0 under k >>= 1).
// The base is reduced into [0, MOD) first, so negative bases still produce a
// canonical result in [0, MOD) instead of a negative residue.
int fpow(int x, int k) {
  x %= MOD;
  if (x < 0) x += MOD;
  int res = 1 % MOD;
  for (; k; k >>= 1) {
    if (k & 1) res = res * x % MOD;
    x = x * x % MOD;
  }
  return res;
}

// Modular inverse of x via Fermat's little theorem, x^(MOD - 2) mod MOD.
// Valid only for prime MOD; for MOD > 2, x == 0 (mod MOD) yields 0 rather than failing.
int inv(int x) {
  const int fermat_exponent = MOD - 2;
  return fpow(x, fermat_exponent);
}

// Sample table built by solve(). After solve(1, 1 << 16) in main(), for i in
// [0, 65536]: ans[i] = (65536*i + 1) * (65536*i + 2) * ... * (65536*i + 65536) mod MOD.
int ans[MAXN << 1];

void solve(int d, int m) {
	if (d == m) return;
  if (d == 1) ans[0] = 1, ans[1] = (m + 1) % MOD;
  static int S[MAXN << 1], P[MAXN << 1], Q[MAXN << 1];
	memset(S, 0, sizeof S), memset(P, 0, sizeof P);
  for (int i = 0; i <= 2 * d; ++i) S[i] = inv(i) % MOD;
  for (int i = 0; i <= d; ++i) P[i] = ans[i] * ifac[i] % MOD * ifac[d - i] % MOD * (d - i & 1 ? MOD - 1 : 1) % MOD;
	Q[1] = 1;
	for (int i = 0; i <= d; ++i) Q[1] = Q[1] * (d + 1 - i) % MOD;
	for (int i = 2; i <= d; ++i) Q[i] = Q[i - 1] * inv(i - 1) % MOD * (i + d) % MOD;
	mul(S, P, 2 * d + 1, 2 * d + 1);
	for (int i = d + 1; i <= 2 * d; ++i) ans[i] = S[i] * Q[i - d] % MOD;
	int x = d * inv(m) % MOD; d <<= 1, d = d;
	memset(S, 0, sizeof S), memset(P, 0, sizeof P);
  for (int i = 0; i <= 2 * d; ++i) S[i] = inv(i - d + x) % MOD;
  for (int i = 0; i <= d; ++i) P[i] = ans[i] * ifac[i] % MOD * ifac[d - i] % MOD * (d - i & 1 ? MOD - 1 : 1) % MOD;
  Q[0] = 1;
  for (int i = 0; i <= d; ++i) Q[0] = Q[0] * (x - i) % MOD;
  for (int i = 1; i <= d; ++i) Q[i] = Q[i - 1] * inv(i - 1 + x - d) % MOD * (i + x) % MOD;
  mul(S, P, 2 * d + 1, 2 * d + 1);
  for (int i = 0; i <= d; ++i) ans[i] = ans[i] * Q[i] % MOD * S[i + d] % MOD;
	solve(d, m);
}

// Reads T test cases of the form "n p" and prints n! mod p for each.
// Small n are multiplied out directly. Larger n use the table from
// solve(1, 2^16): ans[i] = (2^16*i + 1) * ... * (2^16*i + 2^16) mod p, so n!
// is the product of the first n / 2^16 table entries times the leftover tail.
// NOTE(review): assumes n < p < 2^31 with p prime; input bounds are not checked.
signed main() {
  // Below this bound a plain product loop is cheaper than building the
  // 2^16-block table, which costs several large FFT multiplications. It also
  // keeps small moduli (p <= 2^16, where fac[2^16] == 0) away from the table path.
  const int kDirectLimit = 1 << 20;
  const int kBlock = 1 << 16;

  int T;
  if (scanf("%lld", &T) != 1) return 0;
  while (T--) {
    int n;
    if (scanf("%lld %lld", &n, &MOD) != 2) break;
    int res = 1;
    int done = 0;  // invariant: res == done! mod MOD
    if (n >= kDirectLimit) {
      fac[0] = 1;
      for (int i = 1; i <= kBlock; ++i) fac[i] = fac[i - 1] * i % MOD;
      ifac[kBlock] = inv(fac[kBlock]);
      for (int i = kBlock - 1; i >= 0; --i) ifac[i] = ifac[i + 1] * (i + 1) % MOD;
      solve(1, kBlock);
      const int blocks = n / kBlock;  // full blocks covered by ans[0..blocks)
      for (int i = 0; i < blocks; ++i) res = res * ans[i] % MOD;
      done = blocks * kBlock;
    }
    for (int j = done + 1; j <= n; ++j) res = res * j % MOD;
    printf("%lld\n", res);
  }

  return 0;
}

回复

1 条回复,欢迎继续交流。

正在加载回复...