社区讨论

求解KDT为什么我的不平衡重构跑不过暴力

P4148简单题参与者 2已保存回复 3

讨论操作

快速查看讨论及其快照的属性,并进行相关操作。

当前回复
3 条
当前快照
1 份
快照标识符
@m5p2lanj
此快照首次捕获于
2025/01/09 16:32
去年
此快照最后确认于
2025/11/04 11:50
4 个月前
查看原帖
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 10, INF = 2e9;
const double bs = 0.9;
int n;
struct node {
	int nums[2], w;
};
struct trnode {
	int X[2], Y[2];
	int d, ls, rs, sum, val;
	void pr () {
		printf ("X[0]:%d Y[0]:%d X[1]:%d Y[1]:%d\n", X[0], Y[0], X[1], Y[1]);
	}
	node p;
	inline void clear (node x, int pd) {
		ls = rs = 0; d = pd;
		sum = 1, val = x.w;
		X[0] = X[1] = x.nums[0];
		Y[0] = Y[1] = x.nums[1];
		p = x;
	}
}tr[N];
inline int max (int a, int b) {return a < b ? b : a;}
inline int max (int a, int b, int c) {return max(a, max(b, c));}
inline int min (int a, int b) {return a > b ? b : a;}
inline int min (int a, int b, int c) {return min(a, min(b, c));}
inline bool cmp0 (node a, node b) {return a.nums[0] < b.nums[0];}
inline bool cmp1 (node a, node b) {return a.nums[1] < b.nums[1];}
int fw[N], cnt, id, top;
node stc[N];
int newnode () {
	if (cnt) return fw[cnt--];
	return ++id;
}
inline void update (int k) {
	tr[k].X[0] = min (tr[tr[k].ls].X[0], tr[tr[k].rs].X[0], tr[k].p.nums[0]);		
	tr[k].X[1] = max (tr[tr[k].ls].X[1], tr[tr[k].rs].X[1], tr[k].p.nums[0]);		
	tr[k].Y[0] = min (tr[tr[k].ls].Y[0], tr[tr[k].rs].Y[0], tr[k].p.nums[1]);		
	tr[k].Y[1] = max (tr[tr[k].ls].Y[1], tr[tr[k].rs].Y[1], tr[k].p.nums[1]);		
	tr[k].sum = tr[tr[k].ls].sum + tr[tr[k].rs].sum + 1;
	tr[k].val = tr[tr[k].ls].val + tr[tr[k].rs].val + tr[k].p.w;
}
void dfs (int k) {
	if (!k) return;
	stc[++top] = tr[k].p;
	fw[++cnt] = k;
	dfs (tr[k].ls); dfs (tr[k].rs);
}
int build (int l, int r, int d) {
	if (l > r) return 0; 
	int k = newnode(), mid = (l + r) >> 1;
	nth_element(stc + l, stc + mid, stc + r + 1, d ? cmp1 : cmp0);	
	tr[k].clear(stc[mid], d);
	tr[k].ls = build (l, mid - 1, d ^ 1);
	tr[k].rs = build (mid + 1, r, d ^ 1);
	update(k);
	return k;
}
inline void check (int &k, int d) {
	if (bs * tr[k].sum < max (tr[tr[k].ls].sum, tr[tr[k].rs].sum)) {
		top = 0; dfs (k);
		k = build(1, top, d);
	}
}
void insert (int &k, node p, int d) {
	if (!k) {
		k = newnode();
		tr[k].clear(p, d);
		return;
	}
	if (tr[k].p.nums[d] <= p.nums[d]) insert(tr[k].ls, p, d ^ 1);
	else insert(tr[k].rs, p, d ^ 1);
	update(k);
	check(k, d);
}
inline bool in (int a, int b, int c, int d, int A, int B, int C, int D) {
	return (A <= a && B <= b && C >= c && D >= d);
}
inline bool df (int a, int b, int c, int d, int A, int B, int C, int D) {
	return (c < A || d < B || C < a || D < b);
}
int query (int k, int a, int b, int c, int d) {
	if (!k) return 0;
	int res = 0;
	if (in(tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1], a, b, c, d)) return tr[k].val;
	if (df(a, b, c, d, tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1])) return 0;
	if (in(tr[k].p.nums[0], tr[k].p.nums[1], tr[k].p.nums[0], tr[k].p.nums[1], a, b, c, d)) res = tr[k].p.w;
	return res + query(tr[k].ls, a, b, c, d) + query(tr[k].rs, a, b, c, d);
}
signed main () {
	tr[0] = {INF, -INF, INF, -INF, 0, 0, 0, 0, 0};
	scanf("%d", &n);
	int rot = 0, lsan = 0;
	while (1) {
		int opt; scanf("%d", &opt);
		if (opt == 3) break;
		if (opt == 1) {
			int x, y, w;
			scanf("%d%d%d", &x, &y, &w);
			x ^= lsan; y ^= lsan; w ^= lsan; 
			insert(rot, {x, y, w}, 1);
		} else {
			int a, b, c, d;
			scanf("%d%d%d%d", &a, &b, &c, &d);
			a ^= lsan; b ^= lsan; c ^= lsan; d ^= lsan;
			cout << (lsan = query(rot, a, b, c, d)) << '\n';
		}
	}
	return 0;
}
CPP
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, INF = 2e9;
const double bs = 2;
int n;
struct node {
	int nums[2], w;
};
struct trnode {
	int X[2], Y[2];
	int d, ls, rs, sum, val;
	void pr () {
		printf ("X[0]:%d Y[0]:%d X[1]:%d Y[1]:%d\n", X[0], Y[0], X[1], Y[1]);
	}
	node p;
	inline void clear (node x, int pd) {
		ls = rs = 0; d = pd;
		sum = 1, val = x.w;
		X[0] = X[1] = x.nums[0];
		Y[0] = Y[1] = x.nums[1];
		p = x;
	}
}tr[N];
inline int max (int a, int b) {return a < b ? b : a;}
inline int max (int a, int b, int c) {return max(a, max(b, c));}
inline int min (int a, int b) {return a > b ? b : a;}
inline int min (int a, int b, int c) {return min(a, min(b, c));}
inline bool cmp0 (node a, node b) {return a.nums[0] < b.nums[0];}
inline bool cmp1 (node a, node b) {return a.nums[1] < b.nums[1];}
int fw[N], cnt, id, top;
node stc[N];
int newnode () {
	if (cnt) return fw[cnt--];
	return ++id;
}
inline void update (int k) {
	tr[k].X[0] = min (tr[tr[k].ls].X[0], tr[tr[k].rs].X[0], tr[k].p.nums[0]);		
	tr[k].X[1] = max (tr[tr[k].ls].X[1], tr[tr[k].rs].X[1], tr[k].p.nums[0]);		
	tr[k].Y[0] = min (tr[tr[k].ls].Y[0], tr[tr[k].rs].Y[0], tr[k].p.nums[1]);		
	tr[k].Y[1] = max (tr[tr[k].ls].Y[1], tr[tr[k].rs].Y[1], tr[k].p.nums[1]);		
	tr[k].sum = tr[tr[k].ls].sum + tr[tr[k].rs].sum + 1;
	tr[k].val = tr[tr[k].ls].val + tr[tr[k].rs].val + tr[k].p.w;
}
void dfs (int k) {
	if (!k) return;
	stc[++top] = tr[k].p;
	fw[++cnt] = k;
	dfs (tr[k].ls); dfs (tr[k].rs);
}
int build (int l, int r, int d) {
	if (l > r) return 0; 
	int k = newnode(), mid = (l + r) >> 1;
	nth_element(stc + l, stc + mid, stc + r + 1, d ? cmp1 : cmp0);	
	tr[k].clear(stc[mid], d);
	tr[k].ls = build (l, mid - 1, d ^ 1);
	tr[k].rs = build (mid + 1, r, d ^ 1);
	update(k);
	return k;
}
void check (int &k, int d) {
	if (bs * tr[k].sum < max (tr[tr[k].ls].sum, tr[tr[k].rs].sum)) {
		top = 0; dfs (k);
		k = build(1, top, d);
	}
}
void insert (int &k, node p, int d) {
	if (!k) {
		k = newnode();
		tr[k].clear(p, d);
		return;
	}
	if (tr[k].p.nums[d] <= p.nums[d]) insert(tr[k].ls, p, d ^ 1);
	else insert(tr[k].rs, p, d ^ 1);
	update(k);
	check(k, d);
}
inline bool in (int a, int b, int c, int d, int A, int B, int C, int D) {
	return (A <= a && B <= b && C >= c && D >= d);
}
inline bool df (int a, int b, int c, int d, int A, int B, int C, int D) {
	return (c < A || d < B || C < a || D < b);
}
int query (int k, int a, int b, int c, int d) {
	if (!k) return 0;
	int res = 0;
	if (in(tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1], a, b, c, d)) return tr[k].val;
	if (df(a, b, c, d, tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1])) return 0;
	if (in(tr[k].p.nums[0], tr[k].p.nums[1], tr[k].p.nums[0], tr[k].p.nums[1], a, b, c, d)) res = tr[k].p.w;
	return res + query(tr[k].ls, a, b, c, d) + query(tr[k].rs, a, b, c, d);
}
signed main () {
	tr[0] = {INF, -INF, INF, -INF, 0, 0, 0, 0, 0};
	scanf("%d", &n);
	int rot = 0, lsan = 0;
	while (1) {
		int opt; scanf("%d", &opt);
		if (opt == 3) break;
		if (opt == 1) {
			int x, y, w;
			scanf("%d%d%d", &x, &y, &w);
			x ^= lsan; y ^= lsan; w ^= lsan; 
			insert(rot, {x, y, w}, 1);
		} else {
			int a, b, c, d;
			scanf("%d%d%d%d", &a, &b, &c, &d);
			a ^= lsan; b ^= lsan; c ^= lsan; d ^= lsan;
			cout << (lsan = query(rot, a, b, c, d)) << '\n';
		}
	}
	return 0;
}

回复

3 条回复,欢迎继续交流。

正在加载回复...