专栏文章

题解:AT_arc209_e [ARC209E] I hate ABC

AT_arc209_e题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@min96lyc
此快照首次捕获于
2025/12/01 22:37
3 个月前
此快照最后确认于
2025/12/01 22:37
3 个月前
查看原文

Solution

考虑到 NK100N-K\le 100,我们尝试将 KK 转化为 NKN-K。将价值改写为删除尽量少的项使得原串不含 ABC\tt ABC 子序列,则相当于问价值等于 NKN-K 的串有多少个,这就完成了 KNKK\gets N-K 的转化(下文认为 KNKK\gets N-K)。
考虑如何判断一个串是否不含任何 ABC\tt ABC 子序列。进行一些感受后可以发现这相当于能够将原串划分为三段,第一段不含 A\tt A,第二段不含 B\tt B,且第三段不含 C\tt C。据此我们可以写出计算一个序列的价值的代码:
CPP
int A=0,B=0,C=0;
for(int i=1;i<=N;i++){
    if(s[i]=='A')A=A+1;
    if(s[i]=='B')B=min(B+1,A);
    if(s[i]=='C')C=min(C+1,B);
}
其中变量 A,B,CA,B,C 分别表示目前划分了 1,2,31,2,3 段的代价。
尝试计算价值 K\ge K 的序列数量,一步差分即可求得原答案。将原代码改写为
CPP
int A=0,B=0,C=0;
for(int i=1;i<=N;i++){
    if(s[i]=='A')A=min(A+1,K);
    if(s[i]=='B')B=min(B+1,A);
    if(s[i]=='C')C=min(C+1,B);
}
若程序结束后有 A=B=C=KA=B=C=K 则说明序列 ss 的价值 K\ge K
考察三元组 (A,B,C)(A,B,C) 的变化,不难发现它会改变恰好 3K3K 次。考察在相邻两次变化中间出现的字符,不难发现这些字符均不会使 (A,B,C)(A,B,C) 发生变化(废话)。设判定过程中有 1,2,31,2,3 种字符不使三元组发生变化的三元组分别出现了 p,q,rp,q,r 次,则一个三元组序列对答案的贡献就是
[xN3K]1(1x)p(12x)q(13x)r[x^{N-3K}]\frac{1}{(1-x)^p(1-2x)^q(1-3x)^r}
不难发现我们有 p,q3Kp,q\le 3Kr=1r=1(这是因为当且仅当 A=B=C=KA=B=C=K 时三种字符均不会使三元组发生变化)。我们将原序列通分为
[xN3K]f(x)(1x)3K(12x)3K(13x)[x^{N-3K}]\frac{f(x)}{(1-x)^{3K}(1-2x)^{3K}(1-3x)}
则只要求出了 f(x)f(x) 的和我们就能获得答案的表达式,这个显然可以 O(K4)O(K^4) 计算。
但是这样之后求答案还是有些困难。考虑分式分解,将原式转化为
f(x)(1x)3K(12x)3K(13x)=f1(x)(1x)3K+f2(x)(12x)3K+f3(x)13x\frac{f(x)}{(1-x)^{3K}(1-2x)^{3K}(1-3x)}=\frac{f_1(x)}{(1-x)^{3K}}+\frac{f_2(x)}{(1-2x)^{3K}}+\frac{f_3(x)}{1-3x}
则只要我们求出了 f1(x),f2(x),f3(x)f_1(x),f_2(x),f_3(x) 就容易在 O(K)O(K) 时间内回答单组询问。接下来有两种求 f1(x),f2(x),f3(x)f_1(x),f_2(x),f_3(x) 的方法:
  1. 高斯消元。将上述等式乘以 (1x)3K(12x)3K(13x)(1-x)^{3K}(1-2x)^{3K}(1-3x) 后对照系数可以获得 O(K)O(K) 个方程,高消即可求解。这样做是 O(K4)O(K^4) 的,但因为常数原因可能过不了。
  2. 考虑 CRT。根据 CRT,我们容易将限制转化为如下三个方程:
    {f1(x)(12x)3K(13x)f(x)(mod(1x)3K)f2(x)(1x)3K(13x)f(x)(mod(12x)3K)f3(x)(1x)3K(12x)3Kf(x)(mod(13x))\begin{cases}f_1(x)(1-2x)^{3K}(1-3x)\equiv f(x)\pmod{(1-x)^{3K}}\\f_2(x)(1-x)^{3K}(1-3x)\equiv f(x)\pmod{(1-2x)^{3K}}\\f_3(x)(1-x)^{3K}(1-2x)^{3K}\equiv f(x)\pmod{(1-3x)}\end{cases}
    则关键在于求出形如 (1wx)n(1-wx)^n 的多项式在 mod(1sx)m\bmod (1-sx)^m 意义下的逆元。一种方法是 exgcd 算,简单分析一下可以发现一次求逆的复杂度是 O(K2)O(K^2) 的;另一种方法是换元:设 y=1sxy=1-sx,则相当于计算 (ky+b)nmodym(ky+b)^{-n}\bmod y^m,这是容易计算的,算完后将 yy 换回去即可,一次求逆的复杂度仍然是 O(K2)O(K^2) 的。这部分的复杂度即为 O(K3)O(K^3)
综上,我们可以在 O(K4+TK)O(K^4+TK) 的复杂度内解决原问题。

Code

CPP
bool Mst;
#include<bits/stdc++.h>
using namespace std;
using ui=unsigned int;
using ll=long long;
using ull=unsigned long long;
using i128=__int128;
using u128=__uint128_t;
using pii=pair<int,int>;
#define fi first
#define se second
constexpr int N=1e6+105,K=105,mod=998244353;
inline ll add(ll x,ll y){return (x+=y)>=mod&&(x-=mod),x;}
inline ll Add(ll &x,ll y){return x=add(x,y);}
inline ll sub(ll x,ll y){return (x-=y)<0&&(x+=mod),x;}
inline ll Sub(ll &x,ll y){return x=sub(x,y);}
inline ll qpow(ll a,ll b){
	ll res=1;
	for(;b;b>>=1,a=a*a%mod)
		if(b&1)res=res*a%mod;
	return res;
}
using poly=vector<ll>;
const poly w1=poly{1,mod-1},w2=poly{1,mod-2},w3=poly{1,mod-3};
inline ostream& operator <<(ostream &ouf,const poly &f){
	ouf<<'{';
	if(f.size()){
		ouf<<f[0];
		for(int i=1;i<(int)f.size();i++)ouf<<", "<<f[i];
	}
	ouf<<'}';
	return ouf;
}
inline void shrink(poly &f){
	while(f.size()&&!f.back())
		f.pop_back();
}
inline poly operator +(const poly &f,const poly &g){
	poly h(max(f.size(),g.size()));
	for(int i=0;i<(int)f.size();i++)Add(h[i],f[i]);
	for(int i=0;i<(int)g.size();i++)Add(h[i],g[i]);
	return shrink(h),h;
}
inline poly operator -(const poly &f,const poly &g){
	poly h(max(f.size(),g.size()));
	for(int i=0;i<(int)f.size();i++)Add(h[i],f[i]);
	for(int i=0;i<(int)g.size();i++)Sub(h[i],g[i]);
	return shrink(h),h;
}
inline poly operator *(const poly &f,const poly &g){
	if(!f.size()||!g.size())return poly{};
	poly h(f.size()+g.size()-1);
	for(int i=0;i<(int)f.size();i++)
		for(int j=0;j<(int)g.size();j++)
			Add(h[i+j],f[i]*g[j]%mod);
	return shrink(h),h;
}
inline pair<poly,poly> div(const poly &f,const poly &g){
	if(f.size()<g.size())return make_pair(poly{},f);
	int lf=f.size(),lg=g.size();
	poly p(f.size()-g.size()+1),q(f);
	ll inv=qpow(g.back(),mod-2);
	for(int i=lf-1;i>=lg-1;i--){
		if(!q[i])continue;
		ll coef=q[i]*inv%mod;p[i-lg+1]=coef;
		for(int j=0;j<lg;j++)
			Sub(q[i-lg+j+1],coef*g[j]%mod);
	}
	return shrink(p),shrink(q),make_pair(p,q);
}
inline poly operator /(const poly &f,const poly &g){
	return div(f,g).fi;
}
inline poly operator %(const poly &f,const poly &g){
	return div(f,g).se;
}
inline void operator +=(poly &f,const poly &g){f=f+g;}
inline void operator -=(poly &f,const poly &g){f=f-g;}
inline void operator *=(poly &f,const poly &g){f=f*g;}
inline void operator /=(poly &f,const poly &g){f=f/g;}
inline void operator %=(poly &f,const poly &g){f=f%g;}
inline void exgcd(const poly &a,const poly &b,poly &x,poly &y){
	if(!b.size())
		x=poly{qpow(a[0],mod-2)},y=poly{};
	else{
		pair<poly,poly> o=div(a,b);
		exgcd(b,o.se,y,x),y-=o.fi*x;
	}
}
inline poly inv(const poly &p,const poly &q){
	poly x,y;
	exgcd(p,q,x,y);
	return x;
}
inline void decomp(const poly &f,const poly &p,const poly &q,const poly &r,poly &fp,poly &fq,poly &fr){
	fp=f%p*inv(q,p)%p*inv(r,p)%p;
	fq=f%q*inv(p,q)%q*inv(r,q)%q;
	fr=f%r*inv(p,r)%r*inv(q,r)%r;
}
ll fac[N],ifac[N],pw2[N],pw3[N],ipw2[N],ipw3[N];
inline void init(int n){
	fac[0]=1;
	for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
	ifac[n]=qpow(fac[n],mod-2);
	for(int i=n;i>=1;i--)ifac[i-1]=ifac[i]*i%mod;
	pw2[0]=1;
	for(int i=1;i<=n;i++)pw2[i]=pw2[i-1]*2%mod;
	ipw2[n]=qpow(pw2[n],mod-2);
	for(int i=n;i>=1;i--)ipw2[i-1]=ipw2[i]*2%mod;
	pw3[0]=1;
	for(int i=1;i<=n;i++)pw3[i]=pw3[i-1]*3%mod;
	ipw3[n]=qpow(pw3[n],mod-2);
	for(int i=n;i>=1;i--)ipw3[i-1]=ipw3[i]*3%mod;
}
inline ll binom(int n,int m){
	if(m<0||m>n)return 0;
	return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
poly f[K][K],g[K][K];
poly f1[K],f2[K],f3[K];
inline void Init(int n){
	f[0][0]=poly{1};
	poly p1{1},p2{1},p3=w3,q1=w1*w1*w1,q2=w2*w2*w2,q=q1*q2,qw1=q/w1,qw2=q/w2;
	for(int i=0;i<=n;i++,p1*=q1,p2*=q2){
		for(int j=0;j<=i;j++)
			for(int k=0;k<=j;k++)
				g[j][k]=f[j][k];
		for(int j=0;j<=i;j++)
			for(int k=0;k<=i;k++){
				int c=3-(j<i)-(k<j);
				if(c==1)g[j][k]/=w1;
				if(c==2)g[j][k]/=w2;
				if(j<i)g[j+1][k]+=g[j][k];
				if(k<j)g[j][k+1]+=g[j][k];
			}
		decomp(g[i][i],p1,p2,p3,f1[i],f2[i],f3[i]);
		for(int j=0;j<=i;j++)
			for(int k=0;k<=j;k++)
				g[j][k]=f[j][k]*q;
		for(int j=0;j<=i;j++)
			for(int k=0;k<=i;k++){
				int c=3-1-(j<i)-(k<j);
				if(c==1)g[j][k]/=w1;
				if(c==2)g[j][k]/=w2;
				f[j][k]=g[j][k];
				if(j<i)g[j+1][k]+=g[j][k];
				if(k<j)g[j][k+1]+=g[j][k];
			}
	}
}
inline ll qry(int n,int k){
	n-=k*3;
	if(n<0)return 0;
	ll ans=0;int k3=k*3;
	for(int i=0;i<=n&&i<(int)f1[k].size();i++)Add(ans,f1[k][i]*binom(n-i+k3-1,n-i)%mod);
	for(int i=0;i<=n&&i<(int)f2[k].size();i++)Add(ans,f2[k][i]*binom(n-i+k3-1,n-i)%mod*pw2[n-i]%mod);
	for(int i=0;i<=n&&i<(int)f3[k].size();i++)Add(ans,f3[k][i]*pw3[n-i]%mod);
	return ans;
}
bool Med;
int main(){
	cerr<<abs(&Mst-&Med)/1048576.0<<endl;
	ios::sync_with_stdio(false);
	cin.tie(0),cout.tie(0);
	init(1000101);
	Init(101);
	int _Test;cin>>_Test;
	while(_Test--){
		int n,k;cin>>n>>k,k=n-k;
		cout<<sub(qry(n,k),qry(n,k+1))<<'\n';
	}
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...