社区讨论

刚学3天萌新小白求卡常

P5205【模板】多项式开根 · 参与者 3 · 已保存回复 6

讨论操作

快速查看讨论及其快照的属性，并进行相关操作。

当前回复
6 条
当前快照
1 份
快照标识符
@mkjf7kwp
此快照首次捕获于
2026/01/18 15:34
上个月
此快照最后确认于
2026/01/21 21:40
4 周前
查看原帖
只有35分
CPP
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int, int>
#define pll pair<ll, ll>
#define x first
#define y second
#define vc vector
#define vci vc<int>
#define vcl vc<ll>
#define psb push_back
using namespace std;
// Reads one decimal integer from stdin via getchar().
// Every character before the first digit is skipped; if any skipped character
// is '-', the result is negated. Parsing stops at the first non-digit that
// follows the number (that character is consumed).
inline int read(){
	int sign = 1;
	char ch;
	for(ch = getchar(); !(ch >= '0' && ch <= '9'); ch = getchar())
		if(ch == '-') sign = -1;
	int value = 0;
	for(; ch >= '0' && ch <= '9'; ch = getchar())
		value = value * 10 + (ch - '0');
	return sign * value;
}
// Writes x in decimal to stdout with putchar() (no separator, no newline).
// The digits are produced from an unsigned magnitude: negating INT_MIN as an
// int overflows (undefined behaviour), while unsigned negation is well defined.
inline void write(int x){
	unsigned int mag = static_cast<unsigned int>(x);
	if(x < 0){
		putchar('-');
		mag = 0u - mag;  // |x| as unsigned; correct even for INT_MIN
	}
	// mag / 10 <= 214748364, so it always fits back into int for the recursion
	if(mag > 9) write(static_cast<int>(mag / 10));
	putchar(static_cast<int>(mag % 10) + '0');
}
// Binary (square-and-multiply) exponentiation: a^b mod md.
// The base is reduced once up front; intermediate products base*base and
// result*base must fit in ll, so md should stay below ~3e9.
// Note: when b == 0 the result is 1 even if md == 1.
ll ksm(ll a, ll b, ll md){
	ll base = a % md, result = 1;
	for(; b; b /= 2){
		if(b % 2) result = result * base % md;
		base = base * base % md;
	}
	return result;
}
// Mersenne-Twister engine seeded from the wall clock; drives the random
// search in the modular square root below.
mt19937 rnd(time(0));
// Returns a pseudo-random integer in [l, r] (simple modulo reduction, so the
// distribution has a slight bias).
ll rand(ll l, ll r){
	ll width = r - l + 1;
	return rnd() % width + l;
}
// Prime modulus used by all coefficient arithmetic, including the NTT.
const ll md = 998244353;
// n: number of input coefficients read in main(); m is not used in the visible code.
int n, m;
// Root table filled in main(): for each power of two len,
// un[len][0] = 3^((md-1)/len) mod md (the root ntt() uses at length len;
// assumes 3 is a primitive root of md) and un[len][1] is its modular inverse
// (used for the inverse transform, op = 1).
ll un[500005][2];
// Input polynomial coefficients, filled in main().
vcl f;
// Cipolla's algorithm: square roots of n modulo an odd prime md.
// Returns {r, md - r} ordered so that .x <= .y, {0, 0} when n == 0 (mod md),
// and {-1, -1} when n is a quadratic non-residue (no root exists).
pll SQRT(ll n, ll md){
	n = (n % md + md) % md;  // the arithmetic below assumes 0 <= n < md
	if(ksm(n, (md - 1) / 2, md) == md - 1) return {-1, -1};
	if(!n) return {0ll, 0ll};
	// Find a nonzero x such that w = x^2 - n is not a nonzero quadratic residue.
	// If w == 0 the computation below still yields +-x, which is a root.
	ll x = 0;
	while(!x || ksm((x * x - n + md) % md, (md - 1) / 2, md) == 1) x = rand(0, md - 1);
	ll unit = (x * x + md - n) % md;  // w: we compute in F_md[sqrt(w)]
	// Elements p + q*sqrt(w) are stored as {p, q}; ans accumulates (x + sqrt(w))^((md+1)/2),
	// whose sqrt(w)-component is zero and whose rational part squares to n.
	pll ans = {1, 0}, a = {x, 1}, t;
	for(ll b = (md + 1) / 2; b; b /= 2){  // ll: (md+1)/2 need not fit in int
		if(b % 2){
			t.x = (ans.x * a.x + unit * ans.y % md * a.y % md) % md;
			t.y = (ans.y * a.x + a.y * ans.x) % md;
			ans = t;
		}
		t.x = (a.x * a.x + unit * a.y % md * a.y % md) % md;
		t.y = (a.y * a.x * 2) % md;
		a = t;
	}
	ans.y = md - ans.x;  // the other root
	if(ans.x > ans.y) swap(ans.x, ans.y);
	return ans;
}
// Truncation modulo x^n: keeps the first min(n, |f|) coefficients
// (an empty vector when n <= 0).
vcl operator %(vcl f, int n){
	if(n < (int)f.size()) f.resize(max(n, 0));
	return f;
}
// Scalar multiple: every coefficient of f times x, reduced with % md.
vcl operator *(vcl f, int x){
	for(ll &coef : f) coef = coef * x % md;
	return f;
}
// Scalar multiple with the scalar on the left; identical to f * x.
vcl operator *(int x, vcl f){
	return f * x;
}
// Coefficient-wise sum reduced with % md; f is zero-extended to g's length if
// shorter, so the result has the length of the longer operand.
vcl operator +(vcl f, vcl g){
	if(f.size() < g.size()) f.resize(g.size(), 0);
	for(size_t k = 0; k < g.size(); ++k) f[k] = (f[k] + g[k]) % md;
	return f;
}
// Coefficient-wise difference f - g, computed as (f + md - g) % md so that
// in-range inputs give in-range outputs; f is zero-extended to g's length if shorter.
vcl operator -(vcl f, vcl g){
	if(f.size() < g.size()) f.resize(g.size(), 0);
	for(size_t k = 0; k < g.size(); ++k) f[k] = (f[k] + (md - g[k])) % md;
	return f;
}
// In-place iterative number-theoretic transform of length n (a power of two).
// op = 0 uses the roots un[len][0], op = 1 their inverses un[len][1]; the
// inverse transform is NOT scaled by 1/n (the caller does that).
// Outputs may be negative representatives in (-md, md); the caller normalizes.
// Being iterative, it allocates no temporaries, unlike a recursive even/odd
// split that needs O(n) scratch arrays at every level.
void ntt(ll *a, int n, int op){
	// Bit-reversal permutation: the butterflies below then perform the
	// even/odd recursive split bottom-up.
	for(int i = 1, j = 0; i < n; i++){
		int bit = n >> 1;
		for(; j & bit; bit >>= 1) j ^= bit;
		j ^= bit;
		if(i < j) swap(a[i], a[j]);
	}
	for(int len = 2; len <= n; len <<= 1){
		const ll wn = un[len][op];  // len-th root of unity from the table
		const int half = len >> 1;
		for(int s = 0; s < n; s += len){
			ll w = 1;
			for(int k = 0; k < half; k++){
				ll u = a[s + k], v = a[s + k + half] * w % md;
				a[s + k] = (u + v) % md;
				a[s + k + half] = (u - v + md) % md;
				w = w * wn % md;
			}
		}
	}
}
// Polynomial product f * g via NTT. Returns |f| + |g| - 1 coefficients, each
// normalized into [0, md); an empty operand yields an empty result.
// Buffers live on the heap: SZ reaches 2^18 in this program, which is too large
// (and non-standard) for variable-length arrays on the stack.
vcl operator *(const vcl &f, const vcl &g){
	if(f.empty() || g.empty()) return vcl();  // avoids size_t underflow in sz
	int sz = f.size() + g.size() - 1, SZ = 1;
	while(SZ < sz) SZ *= 2;
	vcl a(SZ, 0), b(SZ, 0);
	copy(f.begin(), f.end(), a.begin());
	copy(g.begin(), g.end(), b.begin());
	ntt(a.data(), SZ, 0), ntt(b.data(), SZ, 0);
	for(int i = 0; i < SZ; i++) a[i] = a[i] * b[i] % md;
	ntt(a.data(), SZ, 1);
	ll ny = ksm(SZ, md - 2, md);  // 1 / SZ, scaling for the inverse transform
	vcl ans(sz);
	for(int i = 0; i < sz; i++) ans[i] = (a[i] * ny % md + md) % md;
	return ans;
}
// Inverse of the series a[0..n) modulo x^n by Newton iteration (n a power of two):
// from B0 = a^{-1} mod x^{n/2}, compute B = 2*B0 - a*B0^2 (mod x^n).
// Base case: the modular inverse of a[0].
vcl get_inv(ll *a, int n){
	if(n == 1) return vcl{ksm(a[0], md - 2, md)};
	vcl half = get_inv(a, n / 2);
	vcl f(a, a + n);
	return ((half * 2) - (f * ((half * half) % n))) % n;
}
// Inverse of the polynomial f modulo x^n, coefficients normalized into [0, md).
// f is zero-padded (or truncated) to the next power of two >= n before the
// Newton iteration. The scratch buffer is a vector rather than a stack VLA
// (non-standard C++, and SZ can be large).
vcl INV(vcl f, int n){
	int SZ = 1;
	while(SZ < n) SZ *= 2;
	vcl a(SZ, 0);
	for(int i = 0; i < min(n, (int)f.size()); i++) a[i] = f[i];
	vcl ans = get_inv(a.data(), SZ) % n;
	for(ll &c : ans) c = (c % md + md) % md;
	return ans;
}
// Square root of the series a[0..n) modulo x^n by Newton iteration (n a power
// of two), seeded with the smaller modular root of a[0] (SQRT(a[0], md).x).
// Step: B = (B0^2 + A) / (2*B0) (mod x^n). The inverse is computed once per
// step; previously a second, discarded INV call doubled that cost.
vcl get_sqrt1(ll *a, int n){
	if(n == 1) return vcl{SQRT(a[0], md).x};
	vcl res = get_sqrt1(a, n / 2);
	vcl f(a, a + n);
	return (res * res + f) * INV(res * 2, n) % n;
}
// Square root of the series a[0..n) modulo x^n by Newton iteration (n a power
// of two), seeded with the larger modular root of a[0] (SQRT(a[0], md).y).
// Step: B = (B0^2 + A) / (2*B0) (mod x^n). The inverse is computed once per
// step; previously a second, discarded INV call doubled that cost.
vcl get_sqrt2(ll *a, int n){
	if(n == 1) return vcl{SQRT(a[0], md).y};
	vcl res = get_sqrt2(a, n / 2);
	vcl f(a, a + n);
	return (res * res + f) * INV(res * 2, n) % n;
}
// Square root of the polynomial f modulo x^n, coefficients normalized into
// [0, md). Of the two roots (B and -B), the one with the smaller constant term
// is returned.
//
// get_sqrt1 is seeded with SQRT(a[0], md).x, the smaller modular root, and the
// Newton step (B0^2 + A) / (2*B0) preserves the constant term. So its result
// already has the smaller constant term, and the mirrored root from get_sqrt2
// can never be chosen. Skipping get_sqrt2 halves the running time.
vcl SQRT(vcl f, int n){
	int SZ = 1;
	while(SZ < n) SZ *= 2;
	vcl a(SZ, 0);  // heap buffer instead of a stack VLA
	for(int i = 0; i < min(n, (int)f.size()); i++) a[i] = f[i];  // never read past f
	vcl ans = get_sqrt1(a.data(), SZ) % n;
	for(ll &c : ans) c = (c % md + md) % md;
	return ans;  // empty when n <= 0 (no out-of-range ans[0] access)
}
// Reads n and the n coefficients of A, then prints the coefficients of
// sqrt(A) mod x^n separated by spaces.
int main(){
	// Root table for ntt(): un[len][0] = 3^((md-1)/len), un[len][1] its inverse,
	// for every power of two len <= 500000.
	for(int i = 1; i <= 500000; i *= 2) un[i][0] = ksm(3, (md - 1) / i, md), un[i][1] = ksm(un[i][0], md - 2, md);
	// Only iostreams are used for I/O here, so untie them from C stdio for speed.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin >> n;
	for(int i = 0, a; i < n; i++) cin >> a, f.psb(a);
	vcl ans = SQRT(f, n);
	for(int x : ans) cout << x << ' ';
	return 0;
}

回复

6 条回复，欢迎继续交流。

正在加载回复...