社区讨论
刚学3天萌新小白求卡常
P5205【模板】多项式开根 · 参与者 3 · 已保存回复 6
讨论操作
快速查看讨论及其快照的属性,并进行相关操作。
- 当前回复
- 6 条
- 当前快照
- 1 份
- 快照标识符
- @mkjf7kwp
- 此快照首次捕获于
- 2026/01/18 15:34 上个月
- 此快照最后确认于
- 2026/01/21 21:40 4 周前
只有35分
CPP#include<bits/stdc++.h>
#define ll long long
#define pii pair<int, int>
#define pll pair<ll, ll>
#define x first
#define y second
#define vc vector
#define vci vc<int>
#define vcl vc<ll>
#define psb push_back
using namespace std;
// Reads the next decimal integer from stdin.
// Skips every character up to the first digit; a '-' seen anywhere in the
// skipped prefix makes the result negative. Returns 0 if end of input is
// reached before any digit is found.
inline int read(){
    // getchar() returns int so that EOF stays distinguishable from every
    // valid character; keeping it in a char loses that distinction.
    int c = getchar();
    int sign = 1;
    // Stop at EOF as well as at a digit, otherwise this loop never ends once
    // the input is exhausted.
    while(c != EOF && (c < '0' || c > '9')){
        if(c == '-') sign = -1;
        c = getchar();
    }
    int value = 0;
    while(c >= '0' && c <= '9'){
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// Prints x to stdout in decimal, with a leading '-' for negative values and
// no trailing separator.
inline void write(int x){
    // Widen before negating: -x overflows int when x is INT_MIN.
    long long value = x;
    if(value < 0){
        putchar('-');
        value = -value;
    }
    // Collect digits least-significant first, then emit them in reverse.
    // 10 digits are enough for any 32-bit magnitude.
    char digits[12];
    int len = 0;
    do{
        digits[len++] = static_cast<char>('0' + value % 10);
        value /= 10;
    }while(value > 0);
    while(len > 0) putchar(digits[--len]);
}
// Modular exponentiation by repeated squaring: returns a^b mod md as a value
// in [0, md).
// a  - base; may be negative or larger than md, it is reduced first.
// b  - exponent, expected to be non-negative.
// md - modulus, must be positive and small enough that (md-1)^2 fits in a
//      long long.
long long ksm(long long a, long long b, long long md){
    // 1 % md instead of 1 so that the result is 0 when md == 1.
    long long ans = 1 % md;
    a %= md;
    // C++ '%' keeps the sign of the dividend; lift a negative base into [0, md).
    if(a < 0) a += md;
    while(b != 0){
        if(b % 2 != 0) ans = ans * a % md;
        a = a * a % md;
        b /= 2;
    }
    return ans;
}
// Pseudo-random engine shared by the whole program, seeded from the clock.
std::mt19937 rnd(static_cast<std::mt19937::result_type>(std::time(nullptr)));

// Returns a uniformly distributed integer in the closed range [l, r].
// The bounds are swapped if they arrive in the wrong order.
long long rand(long long l, long long r){
    if(r < l) std::swap(l, r);
    // The engine yields 32-bit values, so "rnd() % range" can never reach the
    // upper part of a range wider than 2^32 and is biased otherwise; the
    // distribution draws as many engine outputs as the range needs.
    return std::uniform_int_distribution<long long>(l, r)(rnd);
}
// Prime modulus used by every arithmetic routine below
// (998244353 = 119 * 2^23 + 1).
const ll md = 998244353;
// Number of coefficients; read in main.
int n;
// Declared but not referenced by the code in this file.
int m;
// un[len][0] is the root of unity used for a forward transform of length len,
// un[len][1] is its modular inverse. Only power-of-two indices are filled
// (by main).
ll un[500005][2];
// Coefficients of the input polynomial, lowest degree first.
vcl f;
// Square roots of n modulo the odd prime md (Cipolla-style computation in the
// extension field F_md[w] with w^2 = base^2 - n).
// Returns {-1, -1} when n is a quadratic non-residue, {0, 0} when n is 0,
// otherwise both roots with the smaller one first.
pll SQRT(ll n, ll md){
    const ll half = (md - 1) / 2;
    // Euler's criterion: n^((md-1)/2) == -1 means there is no root.
    if(ksm(n, half, md) == md - 1) return {-1, -1};
    if(n == 0) return {0ll, 0ll};

    // Draw a random base until base is non-zero and base^2 - n is not a
    // quadratic residue.
    ll base = 0;
    for(;;){
        if(base != 0 && ksm((base * base - n + md) % md, half, md) != 1) break;
        base = rand(0, md - 1);
    }
    const ll unit = (base * base + md - n) % md;

    // Product of two elements p.first + p.second*w of the extension field.
    auto mul = [&](const pll &p, const pll &q){
        pll r;
        r.first = (p.first * q.first + unit * p.second % md * q.second % md) % md;
        r.second = (p.second * q.first + q.second * p.first) % md;
        return r;
    };

    // Raise (base + w) to the power (md + 1) / 2 by repeated squaring; the
    // w-free part of the result is a square root of n.
    pll acc = {1, 0};
    pll cur = {base, 1};
    for(int e = (md + 1) / 2; e != 0; e /= 2){
        if(e % 2 != 0) acc = mul(acc, cur);
        cur = mul(cur, cur);
    }

    const ll root = acc.first;
    const ll other = md - root;
    return {min(root, other), max(root, other)};
}
// Truncation "mod x^n": returns the first n coefficients of f, or all of f
// when it has fewer than n. A non-positive n yields an empty polynomial.
vcl operator %(vcl f, int n){
    const int have = (int)f.size();
    const int keep = min(n, have);
    if(keep < have) f.resize(keep > 0 ? keep : 0);
    return f;
}
// Scalar product: every coefficient of f multiplied by x, reduced with % md.
vcl operator *(vcl f, int x){
    for(ll &coef : f) coef = coef * x % md;
    return f;
}
// Scalar product with the scalar on the left: every coefficient of f
// multiplied by x, reduced with % md.
vcl operator *(int x, vcl f){
    for(ll &coef : f) coef = coef * x % md;
    return f;
}
// Coefficient-wise sum of f and g modulo md; the result is as long as the
// longer operand.
vcl operator +(vcl f, vcl g){
    // Zero-extend f so that every coefficient of g has a partner.
    if(f.size() < g.size()) f.resize(g.size(), 0);
    auto out = f.begin();
    for(const ll term : g){
        *out = (*out + term) % md;
        ++out;
    }
    return f;
}
// Coefficient-wise difference f - g modulo md; the result is as long as the
// longer operand.
vcl operator -(vcl f, vcl g){
    // Zero-extend f so that every coefficient of g has a partner.
    if(f.size() < g.size()) f.resize(g.size(), 0);
    auto out = f.begin();
    for(const ll term : g){
        // Add md - term rather than subtracting term.
        *out = (*out + md - term) % md;
        ++out;
    }
    return f;
}
// In-place number-theoretic transform of a[0..n).
// a  - n values; overwritten with the transform. Results are congruent to the
//      true values modulo md but may be negative, so callers must normalise.
// n  - transform length; must be a power of two, and un[len][op] must already
//      be filled for every power of two len <= n.
// op - 0 for the forward transform, 1 for the inverse one (selects which
//      column of un is used). The 1/n scaling is left to the caller.
//
// Iterative formulation: the recursive even/odd split needs two temporary
// arrays per level (variable-length arrays, which are a compiler extension
// and live on the stack). The same butterflies are applied here, in the same
// order of dependency and with the same arithmetic, directly on a.
void ntt(ll *a, int n, int op){
    if(n <= 1) return;
    // Bit-reversal permutation: puts the elements where the recursive
    // even/odd splitting would have delivered them.
    for(int i = 1, j = 0; i < n; i++){
        int bit = n >> 1;
        for(; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if(i < j) std::swap(a[i], a[j]);
    }
    // Combine blocks of length len/2 into blocks of length len.
    for(int len = 2; len <= n; len <<= 1){
        const int half = len >> 1;
        const ll step = un[len][op];
        for(int start = 0; start < n; start += len){
            ll rt = 1;
            for(int k = 0; k < half; k++){
                const ll even = a[start + k];
                const ll odd = a[start + k + half];
                a[start + k] = (even + rt * odd) % md;
                a[start + k + half] = (even - rt * odd) % md;
                rt = rt * step % md;
            }
        }
    }
}
// Polynomial product f * g modulo md, computed with the number-theoretic
// transform. The result has f.size() + g.size() - 1 coefficients, each in
// [0, md). An empty operand is treated as the zero polynomial.
vcl operator *(vcl f, vcl g){
    const int sz = static_cast<int>(f.size() + g.size()) - 1;
    if(sz <= 0) return vcl();
    // Transform length: smallest power of two that holds the whole product.
    int SZ = 1;
    while(SZ < sz) SZ *= 2;
    // The operands are passed by value, so they double as zero-padded work
    // buffers; this keeps the large arrays off the stack.
    f.resize(SZ);
    g.resize(SZ);
    ntt(f.data(), SZ, 0);
    ntt(g.data(), SZ, 0);
    for(int i = 0; i < SZ; i++) f[i] = f[i] * g[i] % md;
    ntt(f.data(), SZ, 1);
    // The inverse transform leaves a factor SZ; remove it and map every
    // coefficient into [0, md) (ntt may leave negative representatives).
    const ll ny = ksm(SZ, md - 2, md);
    f.resize(sz);
    for(ll &coef : f) coef = (coef * ny % md + md) % md;
    return f;
}
// Newton iteration for the multiplicative inverse of the power series
// a[0..n) modulo x^n. n must be a power of two; a[0] must be invertible
// modulo md for the result to be meaningful.
vcl get_inv(ll *a, int n){
    // Base case: inverse of the constant term via Fermat's little theorem.
    if(n == 1) return {ksm(a[0], md - 2, md)};
    // prev is the inverse modulo x^(n/2).
    const vcl prev = get_inv(a, n / 2);
    const vcl f(a, a + n);
    const vcl prev_sq = (prev * prev) % n;
    // Doubling step: inverse = 2*prev - f*prev^2 (mod x^n).
    return ((prev * 2) - (f * prev_sq)) % n;
}
// Inverse of the power series f modulo x^n, with coefficients in [0, md).
// Coefficients of f beyond index n-1 are ignored; missing ones count as 0.
vcl INV(vcl f, int n){
    // get_inv works on power-of-two lengths.
    int SZ = 1;
    while(SZ < n) SZ *= 2;
    // f is passed by value, so it serves as the zero-padded work buffer
    // instead of a variable-length array on the stack.
    const int keep = std::min(n, (int)f.size());
    f.resize(keep > 0 ? keep : 0);
    f.resize(SZ);
    vcl ans = get_inv(f.data(), SZ) % n;
    for(ll &coef : ans) coef = (coef % md + md) % md;
    return ans;
}
// Newton iteration for a square root of the power series a[0..n) modulo x^n,
// starting from the smaller of the two modular square roots of a[0].
// n must be a power of two.
vcl get_sqrt1(ll *a, int n){
    if(n == 1) return {SQRT(a[0], md).first};
    // res is the root modulo x^(n/2).
    const vcl res = get_sqrt1(a, n / 2);
    const vcl f(a, a + n);
    // Doubling step: root = (res^2 + f) / (2*res) (mod x^n).
    // The series inverse is the expensive part, so it is computed exactly once.
    return (res * res + f) * INV(res * 2, n) % n;
}
// Newton iteration for a square root of the power series a[0..n) modulo x^n,
// starting from the larger of the two modular square roots of a[0].
// n must be a power of two.
vcl get_sqrt2(ll *a, int n){
    if(n == 1) return {SQRT(a[0], md).second};
    // res is the root modulo x^(n/2).
    const vcl res = get_sqrt2(a, n / 2);
    const vcl f(a, a + n);
    // Doubling step: root = (res^2 + f) / (2*res) (mod x^n).
    // The series inverse is the expensive part, so it is computed exactly once.
    return (res * res + f) * INV(res * 2, n) % n;
}
// Square root of the power series f modulo x^n, with coefficients in [0, md).
// Of the two possible roots, the one whose constant term is the smaller
// modular square root of f[0] is returned. Returns an empty polynomial when
// n <= 0. Coefficients of f beyond index n-1 are ignored; missing ones count
// as 0.
vcl SQRT(vcl f, int n){
    if(n <= 0) return {};
    // get_sqrt1 works on power-of-two lengths.
    int SZ = 1;
    while(SZ < n) SZ *= 2;
    // f is passed by value, so it serves as the zero-padded work buffer
    // instead of a variable-length array on the stack.
    const std::size_t want = static_cast<std::size_t>(n);
    if(f.size() > want) f.resize(want);
    f.resize(SZ);
    // The scalar SQRT returns its roots smaller-first and the Newton step
    // keeps the constant term, so the series grown from the smaller root is
    // always the one with the smaller constant term. Computing the second
    // root as well would only double the running time.
    vcl ans = get_sqrt1(f.data(), SZ) % n;
    for(ll &coef : ans) coef = (coef % md + md) % md;
    return ans;
}
int main(){
for(int i = 1; i <= 500000; i *= 2) un[i][0] = ksm(3, (md - 1) / i, md), un[i][1] = ksm(un[i][0], md - 2, md);
cin >> n;
for(int i = 0, a; i < n; i++) cin >> a, f.psb(a);
vcl ans = SQRT(f, n);
for(int x : ans) cout << x << ' ';
return 0;
}
回复
共 6 条回复,欢迎继续交流。
正在加载回复...