bzoj2555 SubString

题目描述

懒得写背景了,给你一个字符串init,要求你支持两个操作

(1):在当前字符串的后面插入一个字符串

(2):询问字符串s在当前字符串中出现了几次?(作为连续子串)

你必须在线支持这些操作。

输入格式

第一行一个数Q表示操作个数

第二行一个字符串表示初始字符串init

接下来Q行,每行2个字符串Type,Str

Type是ADD的话表示在后面插入字符串。

Type是QUERY的话表示询问某字符串在当前字符串中出现了几次。

为了体现在线操作,你需要维护一个变量mask,初始值为0

void decode(char s[], int mask) {
    int len = strlen(s);
    for (int i = 0; i < len; i++) {
        mask = (mask * 131 + i) % len;
        std::swap(s[i], s[mask]);
    }
}

读入串Str之后,使用这个过程将之解码成真正询问的串TrueStr。

输出格式

询问的时候,对TrueStr询问后输出一行答案Result

然后mask=maskxorResult

插入的时候,将TrueStr插到当前字符串后面即可。

样例输入

2
A
QUERY B
ADD BBABBBBAAB

样例输出

0

说明

HINT:ADD和QUERY操作的字符串都需要解压

长度 <= 600000,询问次数<= 10000,询问总长度<= 3000000

新加数据一组--2015.05.20


分析

看似很简单实际上没那么简单的题。

题目大意:多组询问求输入的串在s中出现几次,并且需要支持在s末尾插入字符。

如果没有插入操作,可以建出SAM,找到输入的串在哪个状态,输出这个状态的right集合大小即可。

加上了插入操作,相当于动态维护right集合大小。建SAM时需要加边断边,增删边时只会影响parent树到根路径上的结点,也就是说要做树上的加边删边和路径修改,可以用LCT解决。

看起来很麻烦,但因为parent树是有向的,所以不用写LCT的换根操作,代码会短一些。

复杂度$O(n\log^2 n)$。

代码

#include <bits/stdc++.h>

const int L = 3e6+50;
const int SIZ = 2e6+50;

namespace Lct {
    struct Node {
        int ch[2], fa;
        int w;
        int add;
    } tr[SIZ];

    int getlr(int x) {
        return tr[tr[x].fa].ch[1] == x;
    }

    bool isRoot(int x) {
        return tr[tr[x].fa].ch[getlr(x)] != x;
    }

    void connect(int f, int lr, int x) {
        if (f) tr[f].ch[lr] = x;
        if (x) tr[x].fa = f;
    }

    void putAdd(int x, int val) {
        tr[x].w += val;
        tr[x].add += val;
    }

    void pushDown(int x) {
        for (int i = 0; i < 2; i++)
            putAdd(tr[x].ch[i], tr[x].add);
        tr[x].add = 0;
    }

    void pushAll(int x) {
        if (!isRoot(x))
            pushAll(tr[x].fa);
        pushDown(x);
    }

    void rotate(int x) {
        int y = tr[x].fa, z = tr[y].fa;
        bool lr = getlr(x);
        if (!isRoot(y))
            tr[z].ch[getlr(y)] = x;
        tr[x].fa = z;
        connect(y, lr, tr[x].ch[!lr]);
        connect(x, !lr, y);
    }

    void splay(int x) {
        for (pushAll(x); !isRoot(x); rotate(x)) {
            int y = tr[x].fa;
            if (!isRoot(y))
                rotate(getlr(x) == getlr(y) ? y : x);
        }
    }

    void access(int x) {
        int t = x;
        for (int y = 0; x; y = x, x = tr[x].fa) {
            splay(x);
            tr[x].ch[1] = y;
        }
        splay(t);
    }

    void cut(int x) {
        access(x);
        int &ls = tr[x].ch[0];
        putAdd(ls, -tr[x].w);
        tr[ls].fa = 0;
        ls = 0;
    }

    void link(int x, int y) {
        tr[x].fa = y;
        access(y);
        putAdd(y, tr[x].w);
    }

    void setw(int x, int val) {
        tr[x].w = val;
    }

    int getw(int x) {
        pushAll(x);
        return tr[x].w;
    }
}

struct Sam {
    int top, last;
    int ch[SIZ][26], fa[SIZ], mx[SIZ];

    Sam(): top(1), last(1) {}

    void extend(char c) {
        int x = c - 'A', p = last;
        int np = last = ++top;
        Lct::setw(np, 1);
        mx[np] = mx[p] + 1;
        for (; p && !ch[p][x]; p = fa[p])
            ch[p][x] = np;
        if (!p) {
            fa[np] = 1;
            Lct::link(np, 1);
            return;
        }
        int q = ch[p][x];
        if (mx[q] == mx[p] + 1) {
            fa[np] = q;
            Lct::link(np, q);
            return;
        }
        int nq = ++top;
        mx[nq] = mx[p] + 1;
        Lct::cut(q);
        Lct::link(nq, fa[q]);
        Lct::link(q, nq);
        Lct::link(np, nq);
        fa[nq] = fa[q];
        fa[q] = fa[np] = nq;
        memcpy(ch[nq], ch[q], sizeof ch[nq]);
        for (; ch[p][x] == q; p = fa[p])
            ch[p][x] = nq;
    }

    void insert(char s[]) {
        for (int len = strlen(s), i = 0; i < len; i++)
            extend(s[i]);
    }

    int query(char s[]) {
        int u = 1;
        for (int len = strlen(s), i = 0; i < len; i++) {
            u = ch[u][s[i]-'A'];
            if (!u) return 0;
        }
        return Lct::getw(u);
    }
} sam;

char s[L];

void decode(int mask) {
    int len = strlen(s);
    for (int i = 0; i < len; i++) {
        mask = (mask * 131 + i) % len;
        std::swap(s[i], s[mask]);
    }
}

int main() {
    int m; scanf("%d", &m);
    scanf("%s", s);
    sam.insert(s);
    for (int mask = 0; m--; ) {
        char opt[10];
        scanf("%s%s", opt, s);
        decode(mask);
        if (opt[0] == 'A') sam.insert(s);
        else {
            int ans = sam.query(s);
            printf("%d\n", sam.query(s));
            mask ^= ans;
        }
    }
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论