专栏文章

在线决策单调性地皮还能单老哥分治做?

算法·理论参与者 21已保存评论 24

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
24 条
当前快照
1 份
快照标识符
@miojunwj
此快照首次捕获于
2025/12/02 20:23
3 个月前
此快照最后确认于
2025/12/02 20:23
3 个月前
查看原文
原博客: 簡易版 LARSCH Algorithm noshi91。

众所周知决策单调性 dp 有很多做法,大家应该都会。但是主播主播,二分队列遇上要莫队算的贡献还是太菜了,整体二分还是太吃离线了,SMAWK 算法和 Wilber 算法又太吃常数和码量了。那么有没有什么常数小、能处理莫队、还好写的决策单调性优化 dp 呢?
有的兄弟有的,这就是我们要介绍的:

简易版 LARSCH 算法

设 dp 数组为 ff。考虑分治。尝试设计一个 solve(l,r)\operatorname{solve}(l,r) 函数。我们希望运行完这个函数后,fl+1rf_{l+1\sim r} 的值都已经算对了。但是考虑分治中心 midmid,我们如果要求 fmidf_{mid} 一定要算对,那么 0mid10\sim mid-1 的东西都要算对,这显然是不现实的。所以我们需要分成若干段考虑贡献。
具体地,我们在算 midmid 时,不要求所有可能成为决策点的地方都算对了,而是要求 flf_l 的决策点到 frf_r大致决策点内的东西都算对了。我们需要:
  • 0l0\sim lff 值和决策点都算对了。
  • 只考虑 0l0\sim l 的话rrff 值和决策点都算对了。
mid=(l+r)/2mid=(l+r)/2,那么我们需要做的事情是:
  1. ll 的决策点到 rr当前决策点之间的点转移到 midmid,更新 fmidf_{mid} 的值与决策点。
  2. 递归 solve(l,mid)\operatorname{solve}(l,mid)
  3. l+1l+1midmid 之间的点转移到 rr,更新 frf_r 的值与决策点。
  4. 递归 solve(mid,r)\operatorname{solve}(mid,r)
然后我们就把所有 ff 值算对了。
为什么?首先看第一步。由决策单调性,fmidf_{mid} 在只考虑 0l0\sim l 时的决策点,必然在 flf_l 的决策点与 frf_r 的决策点之间,所以第二步的递归前提是满足的。第三步中,frf_r 多考虑了 l+1midl+1\sim mid 的部分,因此第四步的递归前提仍是满足的。而对于最后一层的递归 solve(k1,k)\operatorname{solve}(k-1,k)fkf_k 已经考虑了 0k10\sim k-1 的所有位置,因此运行结束后所有位置的 ff 都算对了。
时间复杂度是多少呢?可以发现,对于分治的每一层,我们都相当于把所有点遍历了一遍,所以复杂度是 T(n)=2T(n/2)+O(n)=O(nlogn)T(n)=2T(n/2)+\mathcal{O}(n)=\mathcal{O}(n\log n) 的。
当然该算法也有局限性。具体的,由于本算法对决策点进行了部分估计,在限制不够强时可能无法包含最优决策。因此 dp 的转移应当满足四边形不等式

例题

wqs 二分使用例:[ABC355G] Baseball

dpi,jdp_{i,j} 表示前 ii 个位置放了 jj 个点的方案数,则有
dpi,j=minkdpk,j1+w(k,i)dp_{i,j}=\min_kdp_{k,j-1}+w(k,i)
ww 为中间点的距离乘以权值之和。显然 ww 满足四边形不等式,从而 dpn,jdp_{n,j} 关于 jj 是凸的,可以使用 wqs 二分转化为这个 dp:
fi=minjfj+w(j,i)f_i=\min_j f_j+w(j,i)
用上述算法解决,时间复杂度 O(nlognlogV)\mathcal{O}(n\log n\log V)
CodeCPP
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 5e4 + 10;

int n, m; ll s[MAXN], si[MAXN];

inline 
ll w(int l, int r) {
	if (l > r) return 0;
	if (l == 1 && r == n) return 1e18;
	if (l == 1) return (r + 1) * s[r] - si[r];
	if (r == n) return (si[n] - si[l - 1]) - (l - 1) * (s[n] - s[l - 1]);
	int mid = l + r >> 1;
	return (si[mid] - si[l - 1]) - (l - 1) * (s[mid] - s[l - 1])
		+ (r + 1) * (s[r] - s[mid]) - (si[r] - si[mid]);
}

ll dp[MAXN], X; int cnt[MAXN], p[MAXN];

inline 
void check(int i, int j) {
	ll x = dp[j] + w(j + 1, i - 1) + X;
	if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
	else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}

void solve(int l, int r) {
	if (r - l == 1) return ; int mid = l + r >> 1;
	for (int i = p[l]; i <= p[r]; i++) check(mid, i);
	solve(l, mid);
	for (int i = l + 1; i <= mid; i++) check(r, i);
	solve(mid, r);
}

inline 
bool check() {
	for (int i = 1; i <= n + 1; i++) dp[i] = 1e18, cnt[i] = p[i] = 0;
	check(n + 1, 0), solve(0, n + 1);
	return cnt[n + 1] <= m;
}

ll l, r, ans;

int main() {
	scanf("%d%d", &n, &m), m++;
	for (int i = 1; i <= n; i++) scanf("%lld", &s[i]);
	for (int i = 1; i <= n; i++) si[i] = si[i - 1] + i * s[i];
	for (int i = 1; i <= n; i++) s[i] += s[i - 1];
	for (l = 0, r = 1e10; l <= r; ) {
		X = l + r >> 1;
		if (check()) r = X - 1, ans = X;
		else l = X + 1;
	}
	X = ans, check(), printf("%lld", dp[n + 1] - X * m);
}

莫队使用例:CF868F Yet Another Minimization Problem

非常经典的题了。由于转移是离线的,可以用常规的整体二分方法解决。这里尝试使用本文介绍的算法把它做掉。
容易发现,第一步和第三步的莫队都满足移动次数 O(nlogn)\mathcal{O}(n\log n) 的性质,分开跑即可。
注意不要用同一个莫队跑。否则指针会在决策点与区间内反复横跳,导致复杂度退化为平方级别。
CodeCPP
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 1e5 + 10;

int a[MAXN];

struct node {
	
	int cnt[MAXN], l, r; ll ans;
	
	node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); }
	
	inline 
	ll w(int ql, int qr) {
		if (ql > qr) return 0;
		for (; l > ql; ans += cnt[a[--l]]++);
		for (; r < qr; ans += cnt[a[++r]]++);
		for (; l < ql; ans -= --cnt[a[l++]]);
		for (; r > qr; ans -= --cnt[a[r--]]);
		return ans;
	}
	
} A, B;

ll dp[MAXN][30]; int p[MAXN][30];

inline 
void checkA(int i, int j, int k) {
	ll x = dp[j][k - 1] + A.w(j + 1, i);
	if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j;
}

inline 
void checkB(int i, int j, int k) {
	ll x = dp[j][k - 1] + B.w(j + 1, i);
	if (x < dp[i][k]) dp[i][k] = x, p[i][k] = j;
}

void solve(int l, int r, int k) {
	if (r - l == 1) return ; int mid = l + r >> 1;
	for (int i = p[l][k]; i <= p[r][k]; i++) checkA(mid, i, k);
	solve(l, mid, k);
	for (int i = l + 1; i <= mid; i++) checkB(r, i, k);
	solve(mid, r, k);
}

int n, k;

int main() {
	scanf("%d%d", &n, &k);
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
	memset(dp, 0x3f, sizeof dp), **dp = 0;
	for (int i = 1; i <= k; i++) checkA(n, 0, i), solve(0, n, i);
	printf("%lld", dp[n][k]);
}

莫队使用例:P9266 [PA 2022] Nawiasowe podziały

设一个区间的代价为 w(l,r)w(l,r),即区间 [l,r][l,r] 内的合法括号子串个数。我们有 w(l,r)+w(l+1,r1)w(l+1,r)+w(l,r1)w(l,r)+w(l+1,r-1)\ge w(l+1,r)+w(l,r-1)。因此转移满足四边形不等式,同时可以用 wqs 二分去掉段数的限制。
但这时你发现个问题,这个代价很难用莫队以外的办法算出来。这导致二分队列直接倒闭了。同时,转移是在线的,整体二分还要套 cdq 才能处理。
本讲方法的优越性就体现出来了,可以直接做到 O(nlognlogV)\mathcal{O}(n\log n\log V),常数较小。
CodeCPP
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 1e5 + 10;

int n, m, pos[MAXN], col[MAXN]; char s[MAXN];

struct node {
	
	int cnt[MAXN], l, r; ll ans;
	
	node() : l(1), r(0), ans(0) { memset(cnt, 0, sizeof cnt); }
	
	inline 
	ll w(int ql, int qr) {
		if (ql > qr) return 0;
		for (; l > ql; ) {
			l--;
			if (s[l] == '(' && pos[l] <= r) ans += ++cnt[col[l]];
		}
		for (; r < qr; ) {
			r++;
			if (s[r] == ')' && pos[r] >= l) ans += ++cnt[col[r]];
		}
		for (; l < ql; ) {
			if (s[l] == '(' && pos[l] <= r) ans -= cnt[col[l]]--;
			l++;
		}
		for (; r > qr; ) {
			if (s[r] == ')' && pos[r] >= l) ans -= cnt[col[r]]--;
			r--;
		}
		return ans;
	}
	
} A, B;

ll dp[MAXN], X; int cnt[MAXN], p[MAXN];

inline 
void checkA(int i, int j) {
	ll x = dp[j] + A.w(j + 1, i) + X;
	if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
	else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}

inline 
void checkB(int i, int j) {
	ll x = dp[j] + B.w(j + 1, i) + X;
	if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
	else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}

void solve(int l, int r) {
	if (r - l == 1) return ; int mid = l + r >> 1;
	for (int i = p[l]; i <= p[r]; i++) checkA(mid, i);
	solve(l, mid);
	for (int i = l + 1; i <= mid; i++) checkB(r, i);
	solve(mid, r);
}

inline 
bool check() {
	for (int i = 1; i <= n; i++) dp[i] = 1e18, cnt[i] = p[i] = 0;
	checkA(n, 0), solve(0, n);
	return cnt[n] <= m;
}

int st[MAXN], tp, id; 

ll l, r, ans; 

int main() {
	scanf("%d%d%s", &n, &m, s + 1);
	for (int i = 1; i <= n; i++) {
		if (s[i] == '(') { st[++tp] = i, pos[i] = n + 1; continue; }
		if (!tp) continue; int j = st[tp--]; pos[i] = j, pos[j] = i;
		col[i] = col[j] = (col[j - 1] ? col[j - 1] : ++id);
	}
	for (l = 0, r = 1e10; l <= r; ) {
		X = l + r >> 1;
		if (check()) r = X - 1, ans = X;
		else l = X + 1;
	}
	X = ans, check();
	printf("%lld", dp[n] - X * m);
}

评论

24 条评论,欢迎与作者交流。

正在加载评论...