社区讨论

爆零求调

P4238【模板】多项式乘法逆参与者 1已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@lo81605m
此快照首次捕获于
2023/10/27 11:05
2 年前
此快照最后确认于
2023/10/27 11:05
2 年前
查看原帖
rt/kel
犯了什么 sb 错误希望大佬轻喷
CPP
#include"iostream"
#include"cstring"
#include"cstdio"
#include"cmath"
using namespace std;

#define MAXN 100005
#define MOD 998244353
#define ll long long 

int N,M,n=1,l=0;
struct poly
{
	int a[MAXN*6];
	poly(){memset(a,0,sizeof(a));}
}f;
int res[MAXN*6],w[MAXN*6];

void p(poly g,int s)
{
	for(int j=0;j<s;j++) cout<<g.a[j]<<" ";
	cout<<endl;
}

int mul(int a,int b){return (ll)a*b%MOD;}

int quickpow(int a,int b)
{
	int ans=1,base=a;
	while(b)
	{
		if(b&1) ans=mul(ans,base);
		base=mul(base,base);
		b>>=1;
	}
	return ans;
}

int inv(int a){return quickpow(a,MOD-2);}

void NTT(int *a,int unit,int n,int l)
{
    w[0]=1,w[1]=quickpow(unit,(MOD-1)/n);
    for(int i=2;i<n;i++) w[i]=w[i]=mul(w[1],w[i-1]);
    for(int i=0;i<n;i++) if(i>res[i]) swap(a[i],a[res[i]]);
	for(int i=2;i<=l+1;i++)
    {
        int t=n>>(i-1);
        for(int j=1;j<=t;j++) 
        {
            int s=n/t;
            for(int k=0;k<(s>>1);k++)
            {
                int op=s*(j-1)+k,G=mul(a[op+(s>>1)],w[k*t]);
                a[op+(s>>1)]=(a[op]-G+MOD)%MOD;
                a[op]=(a[op]+G)%MOD; 
            }   
        }
    } 
}

poly poly_mul(poly F,poly G,int n,int l)
{
	poly h;
	for(int i=0;i<n;i++) res[i]=(res[i>>1]>>1)|((i&1)<<(l-1));
	NTT(F.a,3,n,l),NTT(G.a,3,n,l);
	for(int i=0;i<n;i++) F.a[i]=mul(F.a[i],G.a[i]);
	NTT(F.a,332748118,n,l);
	int m=inv(n);
	for(int i=0;i<n;i++) h.a[i]=mul(F.a[i],m);
	return h;
}

poly poly_times(poly F,int k,int n)
{
	poly h;
	for(int i=0;i<n;i++) h.a[i]=mul(F.a[i],k);
	return h;
}

poly poly_min(poly F,poly G,int n)
{
	poly h;
	for(int i=0;i<n;i++) h.a[i]=(F.a[i]-G.a[i]+MOD)%MOD;
	return h;
}

poly poly_inv(poly F)
{
	poly h,g;
	for(int i=0;i<n;i++) h.a[i]=0;
	h.a[0]=inv(F.a[0]);
	for(int i=1;i<l;i++)
	{
		int s=1<<i;
		g=poly_mul(h,h,s,i);
		g=poly_mul(g,F,s<<1,i+1);
		for(int j=n-1;j>=s;j--) g.a[j]=0;
		h=poly_min(poly_times(h,2,s>>1),g,s);
		for(int j=n-1;j>=s;j--) h.a[j]=0;
	}
	return h;
}

int main()
{
	scanf("%d",&N);
	while(n<=(N-1)*2) n<<=1,l++;
	for(int i=0;i<N;i++) scanf("%d",&f.a[i]);
	poly H=poly_inv(f);
	for(int i=0;i<N;i++) printf("%d ",H.a[i]);
	return puts(""),0;
}

回复

1 条回复,欢迎继续交流。

正在加载回复...