专栏文章

题解:AT_abc422_g Balls and Boxes

AT_abc422_g题解参与者 2已保存评论 1

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
1 条
当前快照
1 份
快照标识符
@minx8vkf
此快照首次捕获于
2025/12/02 09:50
3 个月前
此快照最后确认于
2025/12/02 09:50
3 个月前
查看原文

思路

给出一个根号做法。
第一问是搞笑完全背包,做 33 次就行。
考虑第二问在求什么,假设三个盒子分别放了 x,y,zx,y,z 个,求的就是 (nx)×(nxy)×(nxyz)\sum \binom{n}{x}\times \binom{n-x}{y}\times \binom{n-x-y}{z}
我们化简一下这个式子,其实就是 n!x!×y!×z!\sum \dfrac{n!}{x!\times y!\times z!}
首先有一个显然的做法,枚举 x,yx,y 硬算这个式子,复杂度是 O(n2AB)O(\dfrac{n^2}{AB}) 的。
显然我们可以在 ABnAB\ge \sqrt{n} 的时候做这个。
考虑 ABnAB\le \sqrt{n} 怎么办,发现这时可以接受一个 O(nAB)O(nAB) 的东西。
于是我们设 fi,x,yf_{i,x,y} 表示当前放了 ii 个球,第一个盒子放的球的个数模 AAxx,第二个盒子放的球的个数模 BByy,且只往前两个盒子放球的方案数。
答案就是枚举一个 CC 的倍数 zz,求 fnz,0,0×(nz)\sum f_{n-z,0,0}\times \binom{n}{z}
于是做到了 O(nn)O(n\sqrt{n})

代码

CPP
#include<bits/stdc++.h>
#define int long long
#define N 300005
#define mod 998244353
#define pii pair<int,int>
#define x first
#define y second
#define pct __builtin_popcount
#define mpi make_pair
#define inf 2e18
using namespace std;
int T=1,n,a[5],f[N],fac[N],inv[N];
vector<vector<int>>g[N];
int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)(res*=x)%=mod;
		(x*=x)%=mod;
		y>>=1;
	}
	return res;
}
void init(){
	int n=3e5;
	fac[0]=inv[0]=1;
	for(int i=1;i<=n;i++){
		fac[i]=fac[i-1]*i%mod;
	}
	inv[n]=ksm(fac[n],mod-2);
	for(int i=n-1;i;i--){
		inv[i]=inv[i+1]*(i+1)%mod;
	}
}
int C(int n,int m){
	if(n<0||m<0||n<m)return 0;
	return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void solve(int cs){
    if(!cs)return;
	cin>>n>>a[1]>>a[2]>>a[3];
	f[0]=1;
	for(int i=1;i<=3;i++){
		for(int j=a[i];j<=n;j++){
			(f[j]+=f[j-a[i]])%=mod;
		}
	}
	cout<<f[n]<<'\n';
	if(a[1]*a[2]>=sqrt(n)){
		int res=0;
		for(int i=0;i<=n/a[1];i++){
			for(int j=0;j<=(n-a[1]*i)/a[2];j++){
				int cur=n-i*a[1]-j*a[2];
				if(cur%a[3])continue;
				(res+=fac[n]*inv[i*a[1]]%mod*inv[j*a[2]]%mod*inv[cur]%mod)%=mod;
			}
		}
		cout<<res<<'\n';
	}
	else{
		for(int i=0;i<=n;i++){
			g[i].resize(a[1]);
			for(int j=0;j<a[1];j++){
				g[i][j].resize(a[2]);
			}
		}
		g[0][0][0]=1;
		for(int i=1;i<=n;i++){
			for(int j=0;j<a[1];j++){
				for(int k=0;k<a[2];k++){
					int x=j-1,y=k-1;
					if(x<0)x=a[1]-1;
					if(y<0)y=a[2]-1;
					(g[i][j][k]+=g[i-1][x][k]+g[i-1][j][y])%=mod;
				}
			}
		}
		int res=0;
		for(int c=0;c<=n/a[3];c++){
			(res+=C(n,c*a[3])*g[n-c*a[3]][0][0]%mod)%=mod;
		}
		cout<<res<<'\n';
	}
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);cout.tie(0);
	// cin>>T;
	init();
	for(int cs=1;cs<=T;cs++){
		solve(cs);
	}
	return 0;
}

评论

1 条评论,欢迎与作者交流。

正在加载评论...