专栏文章

题解:P14224 [ICPC 2024 Kunming I] 子数组

P14224题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@minlb579
此快照首次捕获于
2025/12/02 04:16
3 个月前
此快照最后确认于
2025/12/02 04:16
3 个月前
查看原文
考虑 O(n2)O(n^2) 怎么做。先找出这个序列中所有最大值的位置 p1,p2,,pkp_1, p_2, \dots, p_k,这些点将序列分成了 k+1k+1子段,记这些段的长度为 a1,a2,,ak+1a_1, a_2, \dots, a_{k+1},那么该最大值对答案的贡献为 ansi=j(aj+1)(ai+j+1)ans_i = \sum_j (a_j+1) (a_{i+j}+1)。然后这些子段分治下去即可。
优化的话,令 bi=anib_i = a_{n-i},则贡献变为 ansi=j(aj+1)(bnij+1)ans_i = \sum_j (a_j+1) (b_{n-i-j}+1),可以卷积。直接套入 NTT 就可以做了。
实现的时候注意边界,式子有些地方可能需要 +1+1 或者 1-1
CPP
#include<bits/stdc++.h>
using namespace std;
const int M=998244353;
int T,n,a[400005],b[400005];
vector<int> G[400005];
int val[2000005],val2[2000005],val3[2000005],ans[2000005],len;
namespace NTT {
	int pow(int x, int y) {
		int res=1;
		while(y) {
			if(y&1) res=1ll*res*x%M;
			x=1ll*x*x%M;
			y>>=1;
		}
		return res;
	}
	int N,K,p[2000005];
	void ntt(int* x, int inv) {
		for(int i=0; i<N; i++) if(p[i]<i) swap(x[p[i]],x[i]);
		for(int h=2; h<=N; h<<=1) {
			int gn=pow(3,(M-1)/h);
			for(int i=0; i<N; i+=h) {
				int g=1;
				for(int j=i; j<i+h/2; j++, g=1ll*g*gn%M) {
					int u=x[j], v=1ll*x[j+h/2]*g%M;
					x[j]=(u+v)%M, x[j+h/2]=(u-v+M)%M;
				}
			}
		}
		if(inv) {
			reverse(x+1,x+N);
			int invn=pow(N,M-2);
			for(int i=0; i<N; i++) x[i]=1ll*x[i]*invn%M;
		}
	}
	void calc() {
		if(len<=10) {
			for(int i=0; i<len*2; i++) val3[i]=0;
			for(int i=0; i<len; i++)
				for(int j=0; j<len; j++)
					val3[i+j]=(val3[i+j]+1ll*val[i]*val2[j])%M;
			return;
		}
		N=1, K=0;
		while(N<len*2) N<<=1, K++;
		for(int i=1; i<=N; i++) p[i]=(p[i>>1]>>1)+((i&1)<<(K-1));
		for(int i=len; i<N; i++) val[i]=val2[i]=0;
		ntt(val,0), ntt(val2,0);
		for(int i=0; i<N; i++) val3[i]=1ll*val[i]*val2[i]%M;
		ntt(val3,1);
	}
}

int st[1000005][19];
void initst() {
	for(int i=1; i<=n; i++) st[i][0]=a[i];
	for(int i=1; i<=18; i++)
		for(int j=1; j<=n; j++)
			st[j][i]=max(st[j][i-1], st[j+(1<<i-1)][i-1]);
}
int getst(int x, int y) {
	int len=y-x+1, lg2=31-__builtin_clz(len);
	return max(st[x][lg2], st[y-(1<<lg2)+1][lg2]);
}
void solve(int l, int r) {
	if(l>r) return;
	int mx=getst(l,r);
	len=0;
	int posl=lower_bound(G[mx].begin(), G[mx].end(), l)-G[mx].begin();
	int posr=upper_bound(G[mx].begin(), G[mx].end(), r)-G[mx].begin()-1;
	val[len++]=G[mx][posl]-l+1;
	for(int i=posl; i<posr; i++) val[len++]=G[mx][i+1]-G[mx][i];
	val[len++]=r-G[mx][posr]+1;
//	printf("l=%d r=%d\n",l,r);
//	for(int i=0; i<len; i++) cout<<val[i]<<' ';
//	cout<<'\n';
	for(int i=0; i<len; i++) val2[i]=val[len-1-i];
	NTT::calc();
	for(int i=1; i<len; i++) ans[i]=(ans[i]+val3[len-1-i])%M;

	solve(l,G[mx][posl]-1);
	for(int i=posl; i<posr; i++) solve(G[mx][i]+1,G[mx][i+1]-1);
	solve(G[mx][posr]+1,r);
}
int main() {
	ios::sync_with_stdio(0), cin.tie(0);
//	freopen("in","r",stdin);
	cin>>T;
	while(T--) {
		cin>>n;
		for(int i=1; i<=n; i++) ans[i]=0;
		for(int i=1; i<=n; i++) cin>>a[i], b[i]=a[i];
		sort(b+1,b+1+n);
		int m=unique(b+1,b+1+n)-b-1;
		for(int i=1; i<=n; i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
		for(int i=1; i<=m; i++) G[i].clear();
		for(int i=1; i<=n; i++) G[a[i]].push_back(i);
		initst();
		solve(1,n);
		int anss=0;
		for(int i=1; i<=n; i++) anss+=1ll*i*ans[i]%M*ans[i]%M, anss%=M;
		cout<<anss<<'\n';
	}
	return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...