社区讨论

求教如何O(n)递推

AT_arc061_d[ARC061F] 3人でカードゲーム参与者 1已保存回复 1

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
1 条
当前快照
1 份
快照标识符
@mi6x6cgx
此快照首次捕获于
2025/11/20 12:16
4 个月前
此快照最后确认于
2025/11/20 12:16
4 个月前
查看原帖
蒟蒻只会O(mk) 推完式子还是和暴力一样的复杂度(雾)
暴力是k=0K 3KkCn1+kk m=0MCn1+m+km3Mm\sum_{k = 0}^{K}\ 3^{K-k}C_{n-1+k}^{k}\ \sum_{m = 0}^{M}C_{n - 1 + m + k}^{m}3^{M-m} 如果把后面那一坨m=0MCn1+m+km3Mm\sum_{m = 0}^{M}C_{n - 1 + m + k}^{m}3^{M-m}
记为f(i)f(i)的话,f(i)f(i1)f(i)与f(i-1)的关系可以推出来,
但求这个sigma还是要O(m)O(m)的时间,总时间复杂度不变
正解可以一次线性推完 求教接下来怎么推
附部分分代码
CPP
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
const int maxn = 9e5 + 7, md = 1e9 + 7;
ll inv[maxn], finv[maxn], fac[maxn], f[maxn], p2[maxn], p3[maxn];
ll N, M, K, m, k;
ll read(){
	int s = 0; char c = getchar();
	while (c > '9' || c < '0') c = getchar();
	while (c >= '0' && c <= '9') s = s * 10 + c - '0', c = getchar();
	return s;
}
ll ksm(ll a, int b){
	ll res = 1;
	while (b){
		if (b & 1) res = res * a % md;
		a = a * a % md;
		b >>=1 ;
	}
	return res;
}
void init(){
	fac[0] = fac[1] = 1; inv[1] = finv[1] = 1;
	/* p = ki + r;
	ki+r /eqiv  0 mod p
	i^{-1} /eqiv -kr^{-1} mod p
	i^{-1} /eqiv p -(p/i)*(p-(p/i)*i)^{-1} mod p
	*/
	p3[1] = 3;p3[0] = 1;
	for (int i = 2; i < maxn; i++) p3[i] = p3[i - 1] * 3 % md;
	for (int i = 2; i < maxn; i++){
		inv[i] = ( md - (md / i) * inv[md%i] % md) % md;
		//finv[i] = finv[i - 1] * inv[i] % md;
		//printf("%lld\n", inv[i] * i % md );
		//inv2[i] = (md - md / i) * inv2[md % i] % md;
	}
	for (ll i = 1; i < maxn; i++) {
		fac[i] = fac[i - 1] * i % md;
		//finv[i] = finv[i - 1] * inv[i] % md;
		//printf("%lld\n", fac[i] * finv[i] % md );
	}
	finv[maxn - 1] = ksm(fac[maxn - 1], md - 2);
	for (ll i = maxn - 2; i >= 0; i--) finv[i] = finv[i + 1] * (i + 1) % md;
	//for (int i = 1; i <= N; i++) printf("inv = %lld inv2 = %lld\n",inv[i],inv2[i]);
}
ll ans = 0;
ll getf(){
	ll res = 0;
	f[0] = p3[M];
	res = f[0];
	for (int i = 1; i <= M; i++)
		f[i] = (((((f[i - 1] * (N - 1 + i + k) % md + md) % md ) * inv[3]) % md) * inv[i] + md) % md, res = ((res + f[i]) % md + md) % md;
	return res;
}
ll C(int x, int y){
	return fac[x] * finv[y] % md * finv[x - y] % md;
}
int main(){
	N = read(), M = read(), K = read();
	init();
	for (k = 0; k <= K; k++) {
		ans = (ans + p3[K-k] * C(N - 1 + k, k) % md * getf() % md);
	}
	printf("%lld\n", ans % md);
	return 0;
}

回复

1 条回复,欢迎继续交流。

正在加载回复...