专栏文章

题解:P4493 [HAOI2018] 字串覆盖

P4493题解参与者 1已保存评论 0

文章操作

快速查看文章及其快照的属性,并进行相关操作。

当前评论
0 条
当前快照
1 份
快照标识符
@mio566ly
此快照首次捕获于
2025/12/02 13:32
3 个月前
此快照最后确认于
2025/12/02 13:32
3 个月前
查看原文
rl>2000r-l\gt 2000,暴力最多跳 n2000\frac{n}{2000} 次,每次找给定位置右侧 YY 第一次出现的最小位置。先把 A,BA, B 拼起来跑 SA。每次在 height 二分出和 YYlcp\mathrm{lcp} 长度 Y\ge |Y| 的区间 [l,r][l, r]。考虑在线段树上维护区间 saisa_i 出现的次数。可持久化就可以 O(logn)O(\log{n}) 查询。
51rl200051\le r-l\le 2000,因为很少,也用暴力方法。
rl50r-l\le 50,考虑倍增,对每个长度 len[1,51]len\in[1, 51] 预处理跳 2k2^k 步到的位置、答案。
考虑如何把 50nlogn50n\log{n}log\log 去掉。
查询一个位置跳 11 次的位置,原本是用主席树查的,方法是二分找到 height 数组一段极长区间 [l,r][l, r],使得这段的 lcp\mathrm{lcp} 都大于当前处理的长度 lenlen,然后在 [l,r][l, r] 的线段树上查。
对于一个 lenlen[l,r][l, r] 的集合是 [1,n][1, n] 截断若干 height 小于 lenlen 的位置的段的集合。若我们从前往后初始化倍增数组,对于一个 [l,r][l, r],每次查的位置是单增的。于是我们对每个 [l,r][l, r] 维护一个单调的指针。对于维护 [l,r][l, r],考虑倒着处理每个 lenlen,每次相当于要合并若干有序序列,归并排序可以做到线性。
有一个 O(n)O(n) 预处理倍增的黑科技,详见这篇文章。这样就消去了预处理倍增的 log\log
CPP
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rep(i, l, r) for(int i = (l); i <= (r); ++i)
#define per(i, r, l) for(int i = (r); i >= (l); --i)
const int N = 2e5 + 5;
int n, sz1, K, q;
char s[N];
// sa
int sa[N], rk[N], ork[N], buc[N], id[N];
void SA() {
    int m = 1 << 7, p = 0;
    rep(i, 1, n) buc[rk[i] = s[i]]++;
    rep(i, 1, m) buc[i] += buc[i - 1];
    per(i, n, 1) sa[buc[rk[i]]--] = i;
    for(int w = 1; ; m = p, p = 0, w <<= 1) {
        rep(i, n - w + 1, n) id[++p] = i;
        rep(i, 1, n) if(sa[i] > w) id[++p] = sa[i] - w;
        memset(buc, 0, sizeof(int) * (m + 1));
        memcpy(ork, rk, sizeof(int) * (n + 1));
        p = 0;
        rep(i, 1, n) buc[rk[i]]++;
        rep(i, 1, m) buc[i] += buc[i - 1];
        per(i, n, 1) sa[buc[rk[id[i]]]--] = id[i];
        rep(i, 1, n) {
            if(ork[sa[i - 1]] == ork[sa[i]] && ork[sa[i - 1] + w] == ork[sa[i] + w]) rk[sa[i]] = p;
            else rk[sa[i]] = ++p;
        }
        if(p == n) break;
    }
}
// height
int ht[N], mi[20][N];
void Get_height() {
    for(int i = 1, h = 0; i <= n; ++i) {
        if(h)h--;
        while(s[sa[rk[i] - 1] + h] == s[i + h])h++;
        ht[rk[i]] = h;
    }
    rep(i, 1, n)mi[0][i] = ht[i];
    rep(i, 1, 17)rep(s, 1, n - (1 << i) + 1)
    mi[i][s] = min(mi[i - 1][s], mi[i - 1][s + (1 << i - 1)]);
}
int lcp(int l, int r) {
    l++;
    int k = __lg(r - l + 1);
    return min(mi[k][l], mi[k][r - (1 << k) + 1]);
}

// sgt
const int SGTSZ = N * 20;
int rt[N], sgtcnt;
int lc[SGTSZ], rc[SGTSZ], sum[SGTSZ];
int ask(int u, int v, int l, int r, int x) {
    if(!u) return n + 1;
    if(sum[u] - sum[v] == 0) return n + 1;
    if(r < x) return n + 1;
    if(l == r)return l;
    if(l >= x) {
        int mid = l + r >> 1;
        if(sum[lc[u]] - sum[lc[v]] > 0) return ask(lc[u], lc[v], l, mid, x);
        return ask(rc[u], rc[v], mid + 1, r, x);
    }
    int mid = l + r >> 1;
    int t = ask(lc[u], lc[v], l, mid, x);
    if(t <= n) return t;
    return ask(rc[u], rc[v], mid + 1, r, x);
}
void add(int &u, int p, int l, int r, int x) {
    u = ++sgtcnt;
    lc[u] = lc[p], rc[u] = rc[p], sum[u] = sum[p];
    sum[u]++;
    if(l == r)return;
    int mid = l + r >> 1;
    if(x <= mid)add(lc[u], lc[p], l, mid, x);
    else add(rc[u], rc[p], mid + 1, r, x);
}


int GetL(int p, int len) {
    int al = 1, ar = p;
    while(al < ar) {
        int am = al + ar >> 1;
        if(lcp(am, p) >= len) ar = am;
        else al = am + 1;
    }
    return al;
}
int GetR(int p, int len) {
    int al = p, ar = n;
    while(al < ar) {
        int am = al + ar + 1 >> 1;
        if(lcp(p, am) >= len) al = am;
        else ar = am - 1;
    }
    return ar;
}

int bl[N];
struct BZ {
    int len[N], anc[N], fa[N];
    ll sum[N];
} jp[51];

int fa[N];
int find_set(int v) {
    return fa[v] == v ? v : fa[v] = find_set(fa[v]);
}

bool cmpSA(int x, int y) {
    return sa[x] < sa[y];
}

int ed[N], eds, ks[N], js[N], pt[N];
vector<int> advec[51];
int main() {
    scanf("%d%d", &sz1, &K);
    scanf("%s", s + n + 1);
    n += sz1;
    s[++n] = '&';
    scanf("%s", s + n + 1);
    n += sz1;
    SA();
    Get_height();

    rep(i, 1, n) {
        rt[i] = rt[i - 1];
        add(rt[i], rt[i], 1, n, sa[i]);
    }


    rep(i, 1, n) id[i] = i;
    ed[eds = 1] = 1;
    rep(i, 2, n + 1) if(i == n + 1 || ht[i] < 51) {
        if(i <= n) {
            advec[ht[i]].push_back(i);
        }
        sort(id + ed[eds], id + i, cmpSA);
        ed[++eds] = i;
    }
    rep(i, 1, eds) fa[i] = i;
    rep(i, 1, eds - 1) {
        ks[i] = ed[i], js[i] = ed[i + 1] - 1;
        rep(j, ed[i], ed[i + 1] - 1) bl[j] = i;
    }
    per(l, 51, 1) {
        BZ &now = jp[l - 1];

        rep(i, 1, eds) if(fa[i] == i) pt[i] = ks[i];

        rep(i, 1, sz1) {
            int bld = find_set(bl[rk[i]]);
            int &ptd = pt[bld];
            while(ptd <= js[bld] && sa[id[ptd]] < i + l) ptd++;
            if(ptd <= js[bld] && sa[id[ptd]] + l - 1 <= sz1) now.fa[i] = sa[id[ptd]];
            else now.fa[i] = 0;
        }
//        rep(i, 1, sz1){
//            int _l = ks[find_set(bl[rk[i]])];
//            int _r = js[find_set(bl[rk[i]])];
//            int np = ask(rt[_r], rt[_l-1], 1, n, i + l);
//            if(np+l-1 > sz1) continue;
//            now.fa[i] = np;
//        }
        per(i, sz1, 1) {
            int f = now.fa[i], ff = now.anc[f], fff = now.anc[ff];
            if(f && ff && fff &&
                    now.len[f] == now.len[ff]) {
                now.len[i] = now.len[f] * 2 + 1;
                now.anc[i] = fff;
                now.sum[i] = i + now.sum[f] + now.sum[ff];
            } else {
                now.len[i] = 1;
                now.anc[i] = f;
                now.sum[i] = i;
            }
        }

        for(int i : advec[l - 1]) {
            int x = bl[i - 1], y = bl[i];
            x = find_set(x), y = find_set(y);
            inplace_merge(id + ks[x], id + ks[y], id + js[y] + 1, cmpSA);
            fa[x] = y;
            ks[y] = ks[x];
        }
    }

    scanf("%d", &q);
    while(q--) {
        int s, t, l, r;
        scanf("%d%d%d%d", &s, &t, &l, &r);
        int _l, _r;
        _l = GetL(rk[sz1 + 1 + l], r - l + 1);
        _r = GetR(rk[sz1 + 1 + l], r - l + 1);

        ll ans = 0, tot = 0;
        if(r - l >= 51) {
            int p = s, np;
            while(p <= t) {
                np = ask(rt[_r], rt[_l - 1], 1, n, p);
                if(np + r - l > t) break;
                ans += np, tot++;
                p = np + r - l + 1;
            }
            printf("%lld\n", tot * K - ans);
            continue;
        }
        int p = ask(rt[_r], rt[_l - 1], 1, n, s);
        BZ &now = jp[r - l];
        while(p + r - l <= t) {
            if(now.anc[p] && now.anc[p] + r - l <= t) {
                ans += now.sum[p];
                tot += now.len[p];
                p = now.anc[p];
            } else if(now.fa[p] && now.fa[p] + r - l <= t) {
                ans += p;
                tot++;
                p = now.fa[p];
            } else break;
        }
        if(p + r - l <= t) {
            ans += p;
            tot++;
        }
        printf("%lld\n", tot * K - ans);
    }
    return 0;
}

评论

0 条评论,欢迎与作者交流。

正在加载评论...