专栏文章
如何速通卷积?
算法·理论参与者 11已保存评论 10
文章操作
快速查看文章及其快照的属性,并进行相关操作。
- 当前评论
- 10 条
- 当前快照
- 1 份
- 快照标识符
- @mip2vz8a
- 此快照首次捕获于
- 2025/12/03 05:16 3 个月前
- 此快照最后确认于
- 2025/12/03 05:16 3 个月前
有的题里面卷积是必要的,不会卷积就可能被暴打。
本文旨在帮助和我一样没怎么学多项式的人速通卷积。
其中可能有一些定义和结论,你不需要关心其证明也可以学会卷积,因此本文中不会证明结论。
点值表示法
通过系数表示法给出两个多项式(即给出各项系数) ,求 即其乘积的各项系数。
结论 1:根据 次多项式 在 个不同 处的取值 可以唯一确定 。
定义 1:根据结论 1 可以用 个不同 处的取值表示一个 次多项式,将这种表示方法称为点值表示法。
因此可以先求出 在 个不同 处的取值,然后相乘即可得到 在 个不同 处的取值,再根据这些值求出 的各项系数。
于是现在问题变为了在系数表示法和点值表示法之间快速转化。
系数表示法 -> 点值表示法
直接暴力算即可做到 ,但是显然不够快。
设 ,即将其偶数次系数和奇数次系数分别拿出来组成新的多项式 。
那么只要快速合并即可分治,为了分治可以将项数补到最小且 的 的整数次幂 ,但是合并好像很难。
单位根
不过注意到选的数是没有任何限制的,所以不妨找一些有特殊性质的数使其能够快速合并。
定义 2:令平面直角坐标系上的点 表示 ,其中 是虚数单位满足 ,将这个平面直角坐标系称为复平面。
复数运算:
CPPtypedef double db;
struct cpx{
db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}
定义 3:将平面直角坐标系上以原点为圆心单位长度为半径的圆称为单位圆。
定义 4:将复平面上的单位圆平均分为 段且 为其中一个分段点,将从 开始逆时针走到的第 个分段点表示的数称为 。根据三角函数基础知识,可知
结论 2: 对应从 开始逆时针走到的第 个分段点。
结论 3:当 时,。
快速合并
不难发现 有一些良好性质,因此考虑令 。
于是可以注意到当 时,。
因此只需求出 即可,直接分治即可,时间复杂度 。
卡常
首先要把递归写成循环形式。
考虑将往下分的过程优化。(此过程中需要将偶数次系数和奇数次系数分到两边)
定义 5:将 在这个过程结束后移到的位置称为 。
结论 4: 即为 的二进制表示将前 位 reverse 得到的数。
因此有递推式:。
于是可以将该过程优化到线性。
CPPconst db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
for(int k=2;k<=len;k<<=1){
cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
for(int i=0;i<len;i+=k){
cpx x={1,0};
for(int j=0;j<(k>>1);j++){
cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
x=x*w;
}
}
}
}
点值表示法 -> 系数表示法
直接根据上面代码倒推即可。
CPPvoid ifft(int len,cpx *a){
for(int k=len;k>=2;k>>=1){
cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
for(int i=0;i<len;i+=k){
cpx x={1,0};
for(int j=0;j<(k>>1);j++){
cpx p=a[i+j],q=a[i+j+(k>>1)];
a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
x=x*w;
}
}
}
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
于是我们已经可以写出卷积代码了:
CPPvoid convolution_fft(int n,ll *A,int m,ll *B,ll *C){
int len=1;
while(len<=(n+m))len<<=1;
rep(i,0,len-1)a[i]={(db)A[i],0.0};
rep(i,0,len-1)b[i]={(db)B[i],0.0};
fft(len,a);
fft(len,b);
rep(i,0,len-1)c[i]=a[i]*b[i];
ifft(len,c);
rep(i,0,len-1)C[i]=(ll)round(c[i].x);
}
三次变两次优化
原理:
于是可以将 的系数分别放在实部和虚部,求平方后虚部除以 便是 。
CPPcpx a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
int len=1;
while(len<=(n+m))len<<=1;
rep(i,0,len-1)a[i]={(db)A[i],(db)B[i]};
fft(len,a);
rep(i,0,len-1)a[i]=a[i]*a[i];
ifft(len,a);
rep(i,0,len-1)C[i]=(ll)round(a[i].y/2.0);
}
考虑模意义
显然三角函数与浮点数运算会产生精度误差,同时大多数情况下都是在特定模意义下使用卷积,因此考虑使用整数代替这些浮点数运算,只需要在特定模意义中找到和单位根有类似性质的数即可。
可以将 分解,使用 CRT 合并即可。
一般 有 的较高整数次幂因子时可以使用。
原根
定义 6:对于奇质数 ,将满足 互不相同的 称为其原根。
结论 5:若 存在原根,则其最小原根是 的。
结论 6:若 不为原根,则 。
于是可以暴力枚举找最小原根。
的原根是 。
代替单位根
结论 7:。
因此考虑令 。
于是可以注意到当 时,。
因此只需求出 即可,直接分治即可,时间复杂度 。
CPPconst ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
a=a%p;
ll r=1;
while(b){
if(b&1)r=r*a%p;
a=a*a%p;
b>>=1;
}
return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
for(int k=2;k<=len;k<<=1){
ll w=ksm(G,(mod-1)/k,mod);
for(int i=0;i<len;i+=k){
ll x=1;
for(int j=0;j<(k>>1);j++){
ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
x=x*w%mod;
}
}
}
}
void intt(int len,ll *a){
for(int k=len;k>=2;k>>=1){
ll w=ksm(IG,(mod-1)/k,mod);
for(int i=0;i<len;i+=k){
ll x=1;
for(int j=0;j<(k>>1);j++){
ll p=a[i+j],q=a[i+j+(k>>1)];
a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
x=x*w%mod;
}
}
}
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
int len=1;
while(len<=(n+m))len<<=1;
rep(i,0,len-1)ntt_a[i]=A[i];
rep(i,0,len-1)ntt_b[i]=B[i];
ntt(len,ntt_a);
ntt(len,ntt_b);
rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
intt(len,ntt_c);
rep(i,0,len-1)C[i]=ntt_c[i];
}
模板题代码
CPP#include<bits/stdc++.h>
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define per(i,r,l) for(int i=(r);i>=(l);i--)
#define repll(i,l,r) for(ll i=(l);i<=(r);i++)
#define perll(i,r,l) for(ll i=(r);i>=(l);i--)
#define pb push_back
#define ins insert
#define clr clear
using namespace std;
namespace ax_by_c{
typedef long long ll;
const int N=4e6+5;
namespace Bpoly{
typedef double db;
struct cpx{
db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}
const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
for(int k=2;k<=len;k<<=1){
cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
for(int i=0;i<len;i+=k){
cpx x={1,0};
for(int j=0;j<(k>>1);j++){
cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
x=x*w;
}
}
}
}
void ifft(int len,cpx *a){
for(int k=len;k>=2;k>>=1){
cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
for(int i=0;i<len;i+=k){
cpx x={1,0};
for(int j=0;j<(k>>1);j++){
cpx p=a[i+j],q=a[i+j+(k>>1)];
a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
x=x*w;
}
}
}
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
cpx fft_a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
int len=1;
while(len<=(n+m))len<<=1;
rep(i,0,len-1)fft_a[i]={(db)A[i],(db)B[i]};
fft(len,fft_a);
rep(i,0,len-1)fft_a[i]=fft_a[i]*fft_a[i];
ifft(len,fft_a);
rep(i,0,len-1)C[i]=(ll)round(fft_a[i].y/2.0);
}
const ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
a=a%p;
ll r=1;
while(b){
if(b&1)r=r*a%p;
a=a*a%p;
b>>=1;
}
return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
for(int k=2;k<=len;k<<=1){
ll w=ksm(G,(mod-1)/k,mod);
for(int i=0;i<len;i+=k){
ll x=1;
for(int j=0;j<(k>>1);j++){
ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
x=x*w%mod;
}
}
}
}
void intt(int len,ll *a){
for(int k=len;k>=2;k>>=1){
ll w=ksm(IG,(mod-1)/k,mod);
for(int i=0;i<len;i+=k){
ll x=1;
for(int j=0;j<(k>>1);j++){
ll p=a[i+j],q=a[i+j+(k>>1)];
a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
x=x*w%mod;
}
}
}
rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
int len=1;
while(len<=(n+m))len<<=1;
rep(i,0,len-1)ntt_a[i]=A[i];
rep(i,0,len-1)ntt_b[i]=B[i];
ntt(len,ntt_a);
ntt(len,ntt_b);
rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
intt(len,ntt_c);
rep(i,0,len-1)C[i]=ntt_c[i];
}
};
int n,m;
ll a[N],b[N],c[N];
void slv(int _csid,int _csi){
scanf("%d %d",&n,&m);
rep(i,0,n)scanf("%lld",&a[i]);
rep(i,0,m)scanf("%lld",&b[i]);
// Bpoly::convolution_fft(n,a,m,b,c);
Bpoly::convolution_ntt(n,a,m,b,c);
rep(i,0,n+m)printf("%lld ",c[i]);
}
void main(){
// ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
int T=1,csid=0;
// scanf("%d",&csid);
// scanf("%d",&T);
rep(i,1,T)slv(csid,i);
}
}
int main(){
string __name="";
if(__name!=""){
freopen((__name+".in").c_str(),"r",stdin);
freopen((__name+".out").c_str(),"w",stdout);
}
ax_by_c::main();
return 0;
}
相关推荐
评论
共 10 条评论,欢迎与作者交流。
正在加载评论...