专栏文章

题解:P14312 【模板】K-D Tree

P14312题解参与者 3已保存评论 2

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
2 条
当前快照
1 份
快照标识符
@min1p1xu
此快照首次捕获于
2025/12/01 19:07
3 个月前
此快照最后确认于
2025/12/01 19:07
3 个月前
查看原文
看到 K-D Tree 有模板题了,来水一发题解。
给了二进制分组的复杂度证明,OI-Wiki 上和洛谷上好像都没有给出过具体的证明(一个式子本蒟蒻看不懂,给一个详细一点的证明)。

K-D Tree

K-D Tree 是什么:
K-D Tree(KDT,K-Dimension Tree)是一种可以高效处理 kk 维空间信息的数据结构。
在结点数 nn 远大于 2k2^k 时,应用 K-D Tree 的时间效率很好。
简单来说,K-D Tree 就是一种可以高效处理高维空间中点的数据结构(例如可以解决强制在线的三维偏序),一般比较实用的是 2-D Tree 和 3-D Tree,也就是本题中要实现的。

节点信息

K-D Tree 是一颗二叉搜索树,每个节点是一个点,每棵子树是一个 kk 维空间。
每个节点需要存的信息如下:
CPP
struct node{
    int x[3];
    int val,sum;
    int ls,rs;
    int l[3],r[3];
    int siz,tag;
}t[N],L,R;
数组 tt 是节点,L,RL,R 是查询时矩形的两个顶点。
每个节点中要存这个节点的每维位置,节点的权值,子树的权值和,左右儿子,子树所表示空间的边界,子树大小,修改的懒标记。

建树

下面以二维平面为例,给出 K-D Tree 的建树方法。
首先为了保证平衡,我们应当对每个维度轮流处理,以下面这个图为例:
先对于第一维找到中间的点 DD,将平面分为两个部分。
换一个维度,找到两个部分的中点 C,EC,E,将平面分为四个部分,然后以此类推。
最后的树应该长这样:
主要操作就是找到一个维度的中点,将点分为左右两部分,可以直接用 nth_element 实现,复杂度 O(n)O(n),所以总的建树复杂度为 O(nlogn)O(n\log n),且最后的树高是严格 logn+O(1)\log n+O(1) 的,代码如下:
CPP
int build(int l,int r,int k=0){
    if(l>r) return 0;
    int mid=(l+r)>>1;
    nth_element(a+l,a+mid,a+r+1,[k](int x,int y){
        return t[x].x[k]<t[y].x[k];
    });
    int p=a[mid];
    ls=build(l,mid-1,(k+1)%K);
    rs=build(mid+1,r,(k+1)%K);
    up(p);
    return p;
}

查询

写法非常简单,与查询部分无交直接返回,有部分相交递归子树,全部相交返回处理好的子树权值和,代码如下:
CPP
 int query(int p){
    if(!p) return 0;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return 0;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f) return t[p].sum;
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    down(p);
    return f*t[p].val+query(ls)+query(rs);
}
实现是简单的,主要讲一下复杂度和证明,以二维为例:
在一个节点,如果递归了左右儿子,那么说明矩形与该节点的矩形部分相交,再考虑这个节点的四个孙子有哪些会再次与矩形部分相交,注意到与矩形部分相交的节点的矩形一定会被矩形的一条边穿过,所以我们将查询矩形的四条边分开来考虑,而一条边(与坐标轴平行)最多穿过这个节点四个孙子的其中两个,这是显然的,所以可以得到:
T(n)=2T(n4)+O(1)T(n)=2T(\frac{n}{4})+O(1) T(n)=O(n)T(n)=O(\sqrt n)
容易扩展到 kk 维形式:
T(n)=2k1T(n2k)+O(1)T(n)=2^{k-1}T(\frac{n}{2^k})+O(1) T(n)=O(n11k)T(n)=O(n^{1-\frac{1}{k}})
所以查询的复杂度为 O(n11k)O(n^{1-\frac{1}{k}})

修改

和查询几乎同理,代码也差不多,就不讲了,代码如下:
CPP
void update(int p){
    if(!p) return;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f){
        t[p].tag+=c;
        t[p].sum+=t[p].siz*c;
        t[p].val+=c;
        return;
    }
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    if(f) t[p].val+=c;
    down(p);
    update(ls);
    update(rs);
    up(p);
}
其中 upup 为和并节点信息的函数,downdown 为下传懒标记的函数,都是比较好理解的,就不多讲了,有一个实现的细节,在没有左右儿子的时候为了防止特判,可以对 00 号节点初始化一下:
CPP
void up(int p){
    t[p].sum=t[p].val+t[ls].sum+t[rs].sum;
    t[p].siz=1+t[ls].siz+t[rs].siz;
    for(int k=0;k<K;k++){
        t[p].l[k]=min(t[p].x[k],min(t[ls].l[k],t[rs].l[k]));
        t[p].r[k]=max(t[p].x[k],max(t[ls].r[k],t[rs].r[k]));
    }
}
void down(int p){
    if(!t[p].tag) return;
    int x=t[p].tag;
    if(ls) t[ls].tag+=x,t[ls].sum+=x*t[ls].siz,t[ls].val+=x;
    if(rs) t[rs].tag+=x,t[rs].sum+=x*t[rs].siz,t[rs].val+=x;
    t[p].tag=0;
}

插入/删除

删除比较显然,直接标记一下表示删了即可,这题用不到,主要讲一下怎么插入一个点。
首先直接插入显然不对,因为查询操作要保证子树大小严格减半。
比较常用维护方式是替罪羊树维护,根号重构以及二进制分组。
前两者的复杂度可以参考文末的文章,这里不多讲了。
比较推荐的方式是写二进制分组,容易实现复杂度也较优,为 O(nn+nlog2n)O(n\sqrt n+n\log^2 n)
具体方式是这样的,开 logn\log n 棵树,大小分别为 20,21,222^0,2^1,2^2\cdots,当出现两颗大小相同的树时,合并为一棵新的大小为原来两倍的树,就是二进制的原理,复杂度也比较显然,建树次数是 O(logn)O(\log n) 的,所以总复杂度为 O(nlog2n)O(n\log^2 n)
查询和修改时对于每棵树分别查询和修改即可,但这样复杂度为啥是对的?
写出复杂度:
i=1logn2i=i=1logn2i2=i=1logn2i\sum_{i=1}^{\log n}\sqrt{2^i}=\sum_{i=1}^{\log n}2^{\frac{i}{2}}=\sum_{i=1}^{\log n}\sqrt2^i
发现是个等比数列,直接求和:
2×2logn121=O(n)\sqrt2\times\frac{\sqrt2^{\log n}-1}{\sqrt2-1}=O(\sqrt n)
所以对于每棵子树分别修改查询复杂度不会多 O(logn)O(\log n),还是 O(n)O(\sqrt n) 的,写法还是比较简单的,插入节点代码如下:
CPP
a[n=1]=cnt;
for(int i=0;i<LG;i++)
    if(rt[i]) release(rt[i]);
    else{
        rt[i]=build(1,n);
        break;
    }
其中 releaserelease 函数为回收节点,这个是简单的:
CPP
void release(int &p){
    if(!p) return;
    a[++n]=p;
    down(p);
    release(ls);
    release(rs);
    p=0;
}

代码

给出我丑陋的实现,因为不想写两颗 K-D Tree,用了循环枚举维度,跑的挺慢,实际上可以展开写。
除此之外,K-D Tree 还可以写成线段树的形式,常数会大一点,但是好写不少,可以参考 ERoRaIn大佬的实现
CPP
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1.5e5+10,LG=__lg(N)+3,inf=1e18;
int m,K;
struct node{
    int x[3];
    int val,sum;
    int ls,rs;
    int l[3],r[3];
    int siz,tag;
}t[N],L,R;
int rt[LG],cnt=0,n=0,a[N],c;
#define ls t[p].ls
#define rs t[p].rs
void up(int p){
    t[p].sum=t[p].val+t[ls].sum+t[rs].sum;
    t[p].siz=1+t[ls].siz+t[rs].siz;
    for(int k=0;k<K;k++){
        t[p].l[k]=min(t[p].x[k],min(t[ls].l[k],t[rs].l[k]));
        t[p].r[k]=max(t[p].x[k],max(t[ls].r[k],t[rs].r[k]));
    }
}
void down(int p){
    if(!t[p].tag) return;
    int x=t[p].tag;
    if(ls) t[ls].tag+=x,t[ls].sum+=x*t[ls].siz,t[ls].val+=x;
    if(rs) t[rs].tag+=x,t[rs].sum+=x*t[rs].siz,t[rs].val+=x;
    t[p].tag=0;
}
void release(int &p){
    if(!p) return;
    a[++n]=p;
    down(p);
    release(ls);
    release(rs);
    p=0;
}
int build(int l,int r,int k=0){
    if(l>r) return 0;
    int mid=(l+r)>>1;
    nth_element(a+l,a+mid,a+r+1,[k](int x,int y){
        return t[x].x[k]<t[y].x[k];
    });
    int p=a[mid];
    ls=build(l,mid-1,(k+1)%K);
    rs=build(mid+1,r,(k+1)%K);
    up(p);
    return p;
}
int query(int p){
    if(!p) return 0;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return 0;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f) return t[p].sum;
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    down(p);
    return f*t[p].val+query(ls)+query(rs);
}
void update(int p){
    if(!p) return;
    for(int k=0;k<K;k++) if(L.x[k]>t[p].r[k] || t[p].l[k]>R.x[k]) return;
    bool f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].l[k] && t[p].r[k]<=R.x[k]);
    if(f){
        t[p].tag+=c;
        t[p].sum+=t[p].siz*c;
        t[p].val+=c;
        return;
    }
    f=1;
    for(int k=0;k<K;k++) f&=(L.x[k]<=t[p].x[k] && t[p].x[k]<=R.x[k]);
    if(f) t[p].val+=c;
    down(p);
    update(ls);
    update(rs);
    up(p);
}
#undef ls
#undef rs
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin>>K>>m;
    t[0]={0,0,0,0,0,0,0,inf,inf,inf,0,0,0,0,0};
    for(int op,ans=0;m--;){
        cin>>op;
        if(op==1){
            cnt++;
            for(int i=0;i<K;i++) cin>>t[cnt].x[i],t[cnt].x[i]^=ans;
            cin>>t[cnt].val;t[cnt].val^=ans;
            a[n=1]=cnt;
            for(int i=0;i<LG;i++)
                if(rt[i]) release(rt[i]);
                else{
                    rt[i]=build(1,n);
                    break;
                }
        }
        if(op==2){
            for(int i=0;i<K;i++) cin>>L.x[i],L.x[i]^=ans;
            for(int i=0;i<K;i++) cin>>R.x[i],R.x[i]^=ans;
            cin>>c;c^=ans;
            for(int i=0;i<LG;i++) update(rt[i]);
        }
        if(op==3){
            for(int i=0;i<K;i++) cin>>L.x[i],L.x[i]^=ans;
            for(int i=0;i<K;i++) cin>>R.x[i],R.x[i]^=ans;
            ans=0;
            for(int i=0;i<LG;i++) ans+=query(rt[i]);
            cout<<ans<<"\n";
        }
    }
    return 0;
}

参考资料

线段树式写法(From ERoRaIn),

评论

2 条评论,欢迎与作者交流。

正在加载评论...