社区讨论

为什么这个NTT这么慢啊qwq

P3803【模板】多项式乘法(FFT)参与者 2已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@mi6tvfx8
此快照首次捕获于
2025/11/20 10:44
4 个月前
此快照最后确认于
2025/11/20 10:44
4 个月前
查看原帖
这有一份O2之后1.8s巨慢无比的NTT代码

有没有dalao能帮忙改一下啊qwq难道是结构体这种问题吗orz

CPP
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#define neko 3000010
#define f(i,a,b) for(register int i=(a);i<=(b);++i)
int a[neko],b[neko];
int rev[neko],n1,n2,mod=998244353,gen=3;
    int cnt=0,inv=0;
    struct Num
    {
        int n,m;
    }F;
    void init()
    {
        F.m=n1+n2;
        for(F.n=1;F.n<=F.m;F.n<<=1)++cnt;--cnt;
        f(i,1,F.n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<cnt);
    }
    inline void swap(int &x,int &y)
    {int t=x;x=y,y=t;}
    inline int slowpow(int m,int n)
    {
        static int bb;
        for(bb=1;n;n>>=1,m=1ll*m*m%mod)if(n&1)bb=1ll*bb*m%mod;
        return bb;
    }
    inline void NTT(int *p,int opt)
    {
        f(i,1,F.n-1)if(i<rev[i])swap(p[i],p[rev[i]]);
        static int gi,times,x;
        for(register int i=2;i<=F.n;i<<=1)
        {
            gi=slowpow(gen,(mod-1)/i),times=i>>1;
            for(register int j=0;j<F.n;j+=i)
            {
                x=1;
                for(register int k=0;k<times;++k,x=1ll*x*gi%mod)
                {
                    int u=p[j+k],v=1ll*x*p[j+k+times]%mod;
                    p[j+k]=(u+v)%mod,p[j+k+times]=(u-v+mod)%mod;
                }
            }
        }
        if(opt==-1)
        {
            std::reverse(p+1,p+F.n);
            if(!inv)inv=slowpow(F.n,mod-2);
            f(i,0,F.n-1)p[i]=1ll*p[i]*inv%mod;
        }
    }
    void work()
    {
        init();
        NTT(a,1),NTT(b,1);
        f(i,0,F.n-1)a[i]=1ll*a[i]*b[i]%mod;
        NTT(a,-1);
        f(i,0,F.m)printf("%d ",a[i]);
        putchar('\n');
    }
char gc()
{
    static char buf[262144],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,262144,stdin),p1==p2)?EOF:*p1++;
}
void read(int &x)
{
    char c=gc();x=0;
    while(!isdigit(c))c=gc();
    while(isdigit(c))x=(x<<1)+(x<<3)+(c^'0'),c=gc();
}
int main()
{
    read(n1),read(n2);
    f(i,0,n1)read(a[i]);
    f(i,0,n2)read(b[i]);
    work();
}

回复

1 条回复,欢迎继续交流。

正在加载回复...