专栏文章

如何速通卷积?

算法·理论参与者 11已保存评论 10

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
10 条
当前快照
1 份
快照标识符
@mip2vz8a
此快照首次捕获于
2025/12/03 05:16
3 个月前
此快照最后确认于
2025/12/03 05:16
3 个月前
查看原文
有的题里面卷积是必要的,不会卷积就可能被暴打。
本文旨在帮助和我一样没怎么学多项式的人速通卷积。
其中可能有一些定义和结论,你不需要关心其证明也可以学会卷积,因此本文中不会证明结论。

点值表示法

通过系数表示法给出两个多项式(即给出各项系数) f(x)=a0x0++anxn,g(x)=b0x0++bmxmf(x)=a_0x^0+\dots+a_nx^n,g(x)=b_0x^0+\dots+b_mx^m,求 h(x)=f(x)g(x)=c0x0++cn+mxn+mh(x)=f(x)g(x)=c_0x^0+\dots+c_{n+m}x^{n+m} 即其乘积的各项系数。
结论 1:根据 nn 次多项式 f(x)f(x)n+1n+1 个不同 xx 处的取值 (x1,y1),(x2,y2),,(xn+1,yn+1)(x_1,y_1),(x_2,y_2),\dots,(x_{n+1},y_{n+1}) 可以唯一确定 f(x)f(x)
定义 1:根据结论 1 可以用 n+1n+1 个不同 xx 处的取值表示一个 nn 次多项式,将这种表示方法称为点值表示法。
因此可以先求出 f(x),g(x)f(x),g(x)n+m+1n+m+1 个不同 xx 处的取值,然后相乘即可得到 h(x)h(x)n+m+1n+m+1 个不同 xx 处的取值,再根据这些值求出 h(x)h(x) 的各项系数。
于是现在问题变为了在系数表示法和点值表示法之间快速转化。

系数表示法 -> 点值表示法

直接暴力算即可做到 O((n+m)2)O((n+m)^2),但是显然不够快。
f(x)=f0(x2)+xf1(x2)f(x)=f_0(x^2)+xf_1(x^2),即将其偶数次系数和奇数次系数分别拿出来组成新的多项式 f0(x),f1(x)f_0(x),f_1(x)
那么只要快速合并即可分治,为了分治可以将项数补到最小且 >n+m>n+m22 的整数次幂 2p2^p,但是合并好像很难。

单位根

不过注意到选的数是没有任何限制的,所以不妨找一些有特殊性质的数使其能够快速合并。
定义 2:令平面直角坐标系上的点 (x,y)(x,y) 表示 x+iyx+iy,其中 ii 是虚数单位满足 i2=1i^2=-1,将这个平面直角坐标系称为复平面。
复数运算:
CPP
typedef double db;
struct cpx{
	db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
	return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
	return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
	return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
	return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}
定义 3:将平面直角坐标系上以原点为圆心单位长度为半径的圆称为单位圆。
定义 4:将复平面上的单位圆平均分为 n(n2)n(n\ge2) 段且 (1,0)(1,0) 为其中一个分段点,将从 (1,0)(1,0) 开始逆时针走到的第 22 个分段点表示的数称为 ωn\omega_n。根据三角函数基础知识,可知 ωn=cos2π2p+isin2π2p\omega_n=\cos\frac{2\pi}{2^p}+i\sin\frac{2\pi}{2^p}
结论 2:ωnk\omega_n^k 对应从 (1,0)(1,0) 开始逆时针走到的第 k+1k+1 个分段点。
结论 3:当 2n2\mid n 时,ωnk+n2=ωnk-\omega_n^{k+\frac{n}{2}}=\omega_n^k

快速合并

不难发现 ωnk\omega_n^k 有一些良好性质,因此考虑令 xi=ω2pi1x_i=\omega_{2^p}^{i-1}
于是可以注意到当 j>2p1j>2^{p-1} 时,f(xj)=f0(xj2)+xjf1(xj2)=f0(xj2p12)xj2p1f1(xj2p12)f(x_j)=f_0(x_j^2)+x_jf_1(x_j^2)=f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)
因此只需求出 f0(x1),,f0(x2p1),f1(x1),,f1(x2p1)f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}}) 即可,直接分治即可,时间复杂度 O(2pp)=O((n+m)log(n+m))O(2^pp)=O((n+m)\log(n+m))

卡常

首先要把递归写成循环形式。
考虑将往下分的过程优化。(此过程中需要将偶数次系数和奇数次系数分到两边)
定义 5:将 ii 在这个过程结束后移到的位置称为 toito_i
结论 4:toito_i 即为 ii 的二进制表示将前 pp 位 reverse 得到的数。
因此有递推式:toi=toi22+[2i]2p1to_i=\lfloor\frac{to_{\lfloor\frac{i}{2}\rfloor}}{2}\rfloor+[2\nmid i]2^{p-1}
于是可以将该过程优化到线性。
CPP
const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
				a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
				x=x*w;
			}
		}
	}
}

点值表示法 -> 系数表示法

直接根据上面代码倒推即可。
CPP
void ifft(int len,cpx *a){
	for(int k=len;k>=2;k>>=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
				x=x*w;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
于是我们已经可以写出卷积代码了:
CPP
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)a[i]={(db)A[i],0.0};
	rep(i,0,len-1)b[i]={(db)B[i],0.0};
	fft(len,a);
	fft(len,b);
	rep(i,0,len-1)c[i]=a[i]*b[i];
	ifft(len,c);
	rep(i,0,len-1)C[i]=(ll)round(c[i].x);
}

三次变两次优化

原理:(a+bi)2=(a2b2)+(2ab)i(a+bi)^2=(a^2-b^2)+(2ab)i
于是可以将 f(x),g(x)f(x),g(x) 的系数分别放在实部和虚部,求平方后虚部除以 22 便是 h(x)h(x)
CPP
cpx a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)a[i]={(db)A[i],(db)B[i]};
	fft(len,a);
	rep(i,0,len-1)a[i]=a[i]*a[i];
	ifft(len,a);
	rep(i,0,len-1)C[i]=(ll)round(a[i].y/2.0);
}

考虑模意义

显然三角函数与浮点数运算会产生精度误差,同时大多数情况下都是在特定模意义下使用卷积,因此考虑使用整数代替这些浮点数运算,只需要在特定模意义中找到和单位根有类似性质的数即可。
可以将 modmod 分解,使用 CRT 合并即可。
一般 p1p-122 的较高整数次幂因子时可以使用。

原根

定义 6:对于奇质数 pp,将满足 g1,,gp1g^1,\dots,g^{p-1} 互不相同的 gg 称为其原根。
结论 5:若 nn 存在原根,则其最小原根是 O(n14)O(n^\frac{1}{4}) 的。
结论 6:若 xx 不为原根,则 y,xp1y1(modp)\exists y,x^{\frac{p-1}{y}}\equiv 1 \pmod p
于是可以暴力枚举找最小原根。
998244353998244353 的原根是 33

代替单位根

结论 7:gp12p1(modp)g^{\frac{p-1}{2}}\equiv p-1\pmod p
因此考虑令 xi=(gmod12p)i1x_i=(g^{\frac{mod-1}{2^p}})^{i-1}
于是可以注意到当 j>2p1j>2^{p-1} 时,f(xj)f0(xj2)+xjf1(xj2)f0(xj2p12)xj2p1f1(xj2p12)f(x_j)\equiv f_0(x_j^2)+x_jf_1(x_j^2)\equiv f_0(x_{j-2^{p-1}}^2)-x_{j-2^{p-1}}f_1(x_{j-2^{p-1}}^2)
因此只需求出 f0(x1),,f0(x2p1),f1(x1),,f1(x2p1)f_0(x_1),\dots,f_0(x_{2^{p-1}}),f_1(x_1),\dots,f_1(x_{2^{p-1}}) 即可,直接分治即可,时间复杂度 O(2pp)=O((n+m)log(n+m))O(2^pp)=O((n+m)\log(n+m))
CPP
const ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
	a=a%p;
	ll r=1;
	while(b){
		if(b&1)r=r*a%p;
		a=a*a%p;
		b>>=1;
	}
	return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		ll w=ksm(G,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
				a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
				x=x*w%mod;
			}
		}
	}
}
void intt(int len,ll *a){
	for(int k=len;k>=2;k>>=1){
		ll w=ksm(IG,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
				x=x*w%mod;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)ntt_a[i]=A[i];
	rep(i,0,len-1)ntt_b[i]=B[i];
	ntt(len,ntt_a);
	ntt(len,ntt_b);
	rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
	intt(len,ntt_c);
	rep(i,0,len-1)C[i]=ntt_c[i];
}

模板题代码

CPP
#include<bits/stdc++.h>
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define per(i,r,l) for(int i=(r);i>=(l);i--)
#define repll(i,l,r) for(ll i=(l);i<=(r);i++)
#define perll(i,r,l) for(ll i=(r);i>=(l);i--)
#define pb push_back
#define ins insert
#define clr clear
using namespace std;
namespace ax_by_c{
typedef long long ll;
const int N=4e6+5;
namespace Bpoly{
typedef double db;
struct cpx{
	db x,y;
};
cpx operator + (const cpx &a,const cpx &b){
	return {a.x+b.x,a.y+b.y};
}
cpx operator - (const cpx &a,const cpx &b){
	return {a.x-b.x,a.y-b.y};
}
cpx operator * (const cpx &a,const cpx &b){
	return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
cpx operator / (const cpx &a,const cpx &b){
	return {(a.x*b.x+a.y*b.y)/(b.x*b.x+b.y*b.y),(a.y*b.x-a.x*b.y)/(b.x*b.x+b.y*b.y)};
}
const db PI=acos(-1.0);
int to[N];
void fft(int len,cpx *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)]*x;
				a[i+j]=p+q,a[i+j+(k>>1)]=p-q;
				x=x*w;
			}
		}
	}
}
void ifft(int len,cpx *a){
	for(int k=len;k>=2;k>>=1){
		cpx w={cos(PI*2.0/k),sin(PI*2.0/k)};
		for(int i=0;i<len;i+=k){
			cpx x={1,0};
			for(int j=0;j<(k>>1);j++){
				cpx p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)/(cpx){2,0},a[i+j+(k>>1)]=(p-q)/(cpx){2,0}/x;
				x=x*w;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
cpx fft_a[N];
void convolution_fft(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)fft_a[i]={(db)A[i],(db)B[i]};
	fft(len,fft_a);
	rep(i,0,len-1)fft_a[i]=fft_a[i]*fft_a[i];
	ifft(len,fft_a);
	rep(i,0,len-1)C[i]=(ll)round(fft_a[i].y/2.0);
}
const ll mod=998244353;
const ll I2=(mod+1)/2;
const ll G=3;
ll ksm(ll a,ll b,ll p){
	a=a%p;
	ll r=1;
	while(b){
		if(b&1)r=r*a%p;
		a=a*a%p;
		b>>=1;
	}
	return r%p;
}
const ll IG=ksm(G,mod-2,mod);
void ntt(int len,ll *a){
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
	for(int k=2;k<=len;k<<=1){
		ll w=ksm(G,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)]*x%mod;
				a[i+j]=(p+q)%mod,a[i+j+(k>>1)]=(p-q+mod)%mod;
				x=x*w%mod;
			}
		}
	}
}
void intt(int len,ll *a){
	for(int k=len;k>=2;k>>=1){
		ll w=ksm(IG,(mod-1)/k,mod);
		for(int i=0;i<len;i+=k){
			ll x=1;
			for(int j=0;j<(k>>1);j++){
				ll p=a[i+j],q=a[i+j+(k>>1)];
				a[i+j]=(p+q)*I2%mod,a[i+j+(k>>1)]=(p-q+mod)*I2%mod*x%mod;
				x=x*w%mod;
			}
		}
	}
	rep(i,0,len-1)to[i]=(to[i>>1]>>1)|((i&1)*(len>>1));
	rep(i,0,len-1)if(i<to[i])swap(a[i],a[to[i]]);
}
ll ntt_a[N],ntt_b[N],ntt_c[N];
void convolution_ntt(int n,ll *A,int m,ll *B,ll *C){
	int len=1;
	while(len<=(n+m))len<<=1;
	rep(i,0,len-1)ntt_a[i]=A[i];
	rep(i,0,len-1)ntt_b[i]=B[i];
	ntt(len,ntt_a);
	ntt(len,ntt_b);
	rep(i,0,len-1)ntt_c[i]=ntt_a[i]*ntt_b[i]%mod;
	intt(len,ntt_c);
	rep(i,0,len-1)C[i]=ntt_c[i];
}
};
int n,m;
ll a[N],b[N],c[N];
void slv(int _csid,int _csi){
	scanf("%d %d",&n,&m);
	rep(i,0,n)scanf("%lld",&a[i]);
	rep(i,0,m)scanf("%lld",&b[i]);
//	Bpoly::convolution_fft(n,a,m,b,c);
	Bpoly::convolution_ntt(n,a,m,b,c);
	rep(i,0,n+m)printf("%lld ",c[i]);
}
void main(){
//	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	int T=1,csid=0;
//	scanf("%d",&csid);
//	scanf("%d",&T);
	rep(i,1,T)slv(csid,i);
}
}
int main(){
	string __name="";
	if(__name!=""){
		freopen((__name+".in").c_str(),"r",stdin);
		freopen((__name+".out").c_str(),"w",stdout);
	}
	ax_by_c::main();
	return 0;
}

评论

10 条评论,欢迎与作者交流。

正在加载评论...