bzoj3473 字符串

题目描述

给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?

输入格式

第一行两个整数n,k。

接下来n行每行一个字符串。

输出格式

一行n个整数,第i个整数表示第i个字符串的答案。

样例输入

3 1
abc
a
ab

样例输出

6 1 3

说明

对于 100% 的数据,1<=n,k<=10^5,所有字符串总长不超过10^5,字符串只包含小写字母。


分析

首先,若一个串满足题意,那么这个串的所有后缀都满足题意。

那么,我们只要对于每个右端点,求出符合题意的最小左端点,并把答案加上$r-l+1$即可。这类问题可以使用SAM解决。

先建出广义后缀自动机,统计出每个状态$u$在几个字符串里出现过,记为$cnt(u)$。然后将当前询问的串在SAM上跑,每次只要$cnt(u)<k$就把$u$往$fa(u)$上跳,因为$fa(u)$是$u$的后缀,$cnt$会变大。直到$cnt(u)\ge k$,则$mx(u)$即为当前右端点的最长匹配长度。

如何求$cnt(u)$?可以对每个状态维护它所在的字符串编号集合,用set存储。一开始建完SAM时只有叶子节点的set是有元素的,然后向上启发式合并求出每个点的集合大小。

代码

#include <bits/stdc++.h>
using std::string;
typedef long long ll;

const int N = 1e5+50;
const int L = 1e5+50;
const int SIZ = L*2;

struct Graph {
    struct Edge {
        int to, nxt;
    } e[SIZ];
    int top, head[SIZ];

    Graph(): top(0) {
        memset(head, -1, sizeof head);
    }

    void add(int u, int v) {
        e[top] = (Edge){v, head[u]};
        head[u] = top++;
    }

    void add2(int u, int v) {
        add(u, v);
        add(v, u);
    }
};

struct Sam {
    int top, last;
    int ch[SIZ][26], fa[SIZ], mx[SIZ];
    std::set<int> ids[SIZ];
    Graph G;
    int cnt[SIZ];  // 表示这个状态在多少个串中出现过 

    Sam(): top(1) {}

    void extend(int c, int id) {
        int x = c - 'a', p = last, tmp = ch[p][x];
        if (tmp && mx[tmp] == mx[p] + 1) {
            ids[last=tmp].insert(id);
            return;
        }
        int np = last = ++top;
        mx[np] = mx[p] + 1;
        ids[np].insert(id);
        for (; ch[p][x] == 0; p = fa[p])
            ch[p][x] = np;
        if (p == 0) {
            fa[np] = 1;
            return;
        }
        int q = ch[p][x];
        if (mx[q] == mx[p] + 1) {
            fa[np] = q;
            return;
        }
        int nq = (mx[p] + 1 == mx[np] ? np : ++top);
        mx[nq] = mx[p] + 1;
        fa[nq] = fa[q];
        fa[q] = nq;
        if (nq != np) fa[np] = nq;
        memcpy(ch[nq], ch[q], sizeof ch[nq]);
        for (; ch[p][x] == q; p = fa[p])
            ch[p][x] = nq;
    }

    void insert(string s, int id) {
        last = 1;
        for (int len = s.size(), i = 0; i < len; i++)
            extend(s[i], id);
    }

    void dfs(int u) {
        for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
            int v = G.e[i].to;
            dfs(v);
            if (ids[v].size() > ids[u].size())
                std::swap(ids[u], ids[v]);
            for (std::set<int>::iterator it = ids[v].begin(); it != ids[v].end(); it++)
                ids[u].insert(*it);
            ids[v].clear();
        }
        cnt[u] = ids[u].size();
    }

    void pret() {
        for (int i = 2; i <= top; i++)
            G.add(fa[i], i);
        dfs(1);
    }
} sam;

int n, k;
string a[N];

ll solve(string s) {
    int u = 1;
    ll ans = 0;
    for (int len = s.size(), i = 0; i < len; i++) {
        u = sam.ch[u][s[i]-'a'];
        while (u != 0 && sam.cnt[u] < k)
            u = sam.fa[u];
        ans += sam.mx[u];
    }
    return ans;
}

int main() {
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) {
        std::cin >> a[i];
        sam.insert(a[i], i);
    }
    sam.pret();
    for (int i = 1; i <= n; i++)
        printf("%lld ", solve(a[i]));
    puts("");
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论