欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

「模拟赛20180306」回忆树 memory LCA+KMP+AC自动机+树状数组

程序员文章站 2022-03-18 16:43:28
题目描述 回忆树是一棵树,树边上有小写字母。 一次回忆是这样的:~~你想起过往,触及心底……~~唔,不对,我们要说题目。 这题中我们认为回忆是这样的:给定 $2$ 个点 $u,v$ ($u$ 可能等于 $v$)和一个非空字符串 $s$ ,问从 $u$ 到 $v$ 的简单路径上的所有边按照到 $u$ ......

题目描述

回忆树是一棵树,树边上有小写字母。

一次回忆是这样的:你想起过往,触及心底……唔,不对,我们要说题目。

这题中我们认为回忆是这样的:给定 \(2\) 个点 \(u,v\) (\(u\) 可能等于 \(v\))和一个非空字符串 \(s\) ,问从 \(u\)\(v\) 的简单路径上的所有边按照到 \(u\) 的距离从小到大的顺序排列后,询问边上的字符依次拼接形成的字符串中给定的串 \(s\) 出现了多少次。

输入

第一行 \(2\) 个整数,依次为树中点的个数 \(n\) 和回忆的次数 \(m\)
接下来 \(n-1\) 行,每行 \(2\) 个整数 \(u,v\)\(1\) 个小写字母 \(c\) ,表示回忆树的点\(u,v\)之间有一条边,边上的字符为\(c\)
接下来 \(2m\) 行表示 \(m\) 次回忆,每次回忆 \(2\) 行:第 \(1\)\(2\) 个整数 \(u,v\),第 \(2\) 行给出回忆的字符串 \(s\)

输出

对于每次回忆,输出串 \(s\) 出现的次数。

样例

样例输入

12 3
1 2 w
2 3 w
3 4 x
4 5 w
5 6 w
6 7 x
7 8 w
8 9 w
9 10 x
10 11 w
11 12 w
1 7
wwx
1 12
www
1 12
w

样例输出

2
0
8

数据范围

\(1≤n,m≤10^5\)
询问字符串的总长度不超过\(3\times10^5\)

题解

这是一道神题,做法优美而且巧妙(同时也很恶心)。

既然是树链上的询问,就不能不让人想到利用\(LCA\)\(u\xrightarrow{}v\)的路径转化成\(u\xrightarrow{}lca\)\(lca\xrightarrow{}v\)的两条路径了。

那么我们就可以把询问分成三部分。

  1. \(lca\xrightarrow{}u\)\(s\)的反串出现了多少次
  2. \(lca\xrightarrow{}v\)\(s\)出现了多少次
  3. 跨越\(lca\)时,\(s\)出现了多少次

可以发现,第一部分和第二部分其实是类似的问题,我们先放一放。


那么我们考虑第三个问题,好像没有什么很简单的方法,于是我们考虑暴力。
很容易发现这一种情况下涉及的字符串不长,只有\(u\xrightarrow{}lca\)路径上的\(\left|s\right|\)个和\(v\xrightarrow{}lca\)路径上的\(\left|s\right|\)个。我们可以暴力取出这一段字符,然后做一次\(KMP\),这样一次的复杂度是\(O(\left|s\right|)\),总时间复杂度就是\(O(\sum\left|s\right|)\),完全可以过。


现在就剩前两个问题了。我们发现询问串太多,一个个做显然很吃力,这时,\(AC\)自动机的方法就呼之欲出了。我们把所有询问串做成一个\(AC\)自动机,把整棵树带进去匹配即可。

匹配的过程很简单,模拟字符串匹配的时候即可,从根开始,依次访问子树,进栈的时候答案加,出栈的时候答案减即可,然后把询问的区间标记一下,到达合适的区间就计算答案。

但是这样还有一个问题,\(AC\)自动机上的答案是要给\(fail\)链上的所有点增加的,暴力加显然会超时。于是我们修改一下做法,预处理出\(fail\)树的先序遍历序列,然后建立树状数组(一个比较显然的性质,同一颗子树的遍历序列是连续的)。于是修改的时候单点修改,查询的时候查询\(fail\)树上的子树和即可。


然而,这道题说起来很轻巧,却是一道码农题……并且还卡常数……卡常数!!!
所以,我还是把我\(250\)行的代码拿出来吧……
\(Code:\)

#include <queue> 
#include <vector> 
#include <cstdio> 
#include <cstring> 
#include <algorithm> 
using namespace std; 
#define M 600005 
queue<int>q; 
int n, m; 
int f[25][M], dep[M], fa[M]; 
int L[M], R[M], ans[M], ens[M], plc[M]; 
vector<int>B[M], E[M]; 
char len[M], top[M], S[M]; 
struct node 
{ 
    int fir[M], tar[M], nex[M], cnt; 
}T1, T2; 
void add(int a, int b, char c) 
{ 
    ++T1.cnt; 
    T1.tar[T1.cnt] = b; 
    len[T1.cnt] = c; 
    T1.nex[T1.cnt] = T1.fir[a]; 
    T1.fir[a] = T1.cnt; 
} 
void add(int a, int b) 
{ 
    ++T2.cnt; 
    T2.tar[T2.cnt] = b; 
    T2.nex[T2.cnt] = T2.fir[a]; 
    T2.fir[a] = T2.cnt; 
} 
//dfs-begin 
void dfs(int r) 
{ 
    for (int i = T1.fir[r]; i; i = T1.nex[i]) 
    { 
        int v = T1.tar[i]; 
        if (v != fa[r]) 
        { 
            fa[v] = r; 
            dep[v] = dep[r] + 1; 
            top[v] = len[i]; 
            dfs(v); 
        } 
    } 
} 
//dfs-end 
//LCA-begin 
int LCA(int u, int v) 
{ 
    if (dep[u] < dep[v]) 
        swap(u, v); 
    int k = dep[u] - dep[v]; 
    for (int i = 20; i >= 0; i--) 
        if (k & 1 << i) 
            u = f[i][u]; 
    if (u == v) 
        return u; 
    for (int i = 20; i >= 0; i--) 
        if (f[i][u] != f[i][v]) 
            u = f[i][u], v = f[i][v]; 
    return f[0][u]; 
} 
int getk(int u, int k) 
{ 
    for (int i = 0; i <= 20; i++) 
        if (k & 1 << i) 
            u = f[i][u]; 
    return u; 
} 
//LCA-end 
//KMP-begin 
char K[M]; 
int nex[M]; 
void KMP(int a, int b, int c, int ls, int w) 
{ 
    int len = 0; 
    while (a != c) 
        K[len++] = top[a], a = fa[a]; 
    int z = dep[b] - dep[c]; 
    len += dep[b] - dep[c]; 
    while (b != c) 
        K[--len] = top[b], b = fa[b]; 
    len += z; 
    K[len] = 0; 
    nex[0] = -1; 
    int i = 0, j = -1, ans = 0; 
    while(i < ls) 
    { 
        if (j == -1 || S[i] == S[j]) 
            nex[++i] = ++j; 
        else
            j = nex[j]; 
    } 
    i = 0, j = 0; 
    while(i < len) 
    { 
        if (j == ls) 
        { 
            ans++; 
            j = nex[j]; 
            continue; 
        } 
        if(j == -1 || K[i] == S[j]) 
            i++, j++; 
        else
            j = nex[j]; 
    } 
    if (j == ls) 
        ans++; 
    ens[w] += ans; 
} 
//KMP-end 
//ACTrie-begin 
struct ACTrie 
{ 
    int nex[M][30], fail[M], in[M], out[M]; 
    int root, cnt, tim, dfn[M], id[M]; 
    int tree[M]; 
    ACTrie(){root = cnt = 1;} 
    void Insert(char *S, int w) 
    { 
        int r = root, len = strlen(S); 
        for (int i = 0; i < len; i++) 
        { 
            int val = S[i] - 'a'; 
            if (!nex[r][val]) 
                nex[r][val] = ++cnt; 
            r = nex[r][val]; 
        } 
        plc[w] = r; 
    } 
    void Build() 
    { 
        int r = root; 
        fail[r] = r; 
        q.push(root); 
        while (!q.empty()) 
        { 
            r = q.front(); 
            q.pop(); 
            for (int i = 0; i < 26; i++) 
            { 
                if (nex[r][i]) 
                { 
                    int tmp = nex[fail[r]][i]; 
                    if (tmp && tmp != nex[r][i]) 
                        fail[nex[r][i]] = tmp; 
                    else
                        fail[nex[r][i]] = root; 
                    q.push(nex[r][i]); 
                } 
                else
                { 
                    int tmp = nex[fail[r]][i]; 
                    if (tmp) 
                        nex[r][i] = tmp; 
                    else
                        nex[r][i] = root; 
                } 
            } 
            if (r != root) 
                add(fail[r], r); 
        } 
    } 
    void DFS(int r) 
    { 
        dfn[r] = ++tim; 
        in[r] = tim; 
        id[tim] = r; 
        for (int i = T2.fir[r]; i; i = T2.nex[i]) 
        { 
            int v = T2.tar[i]; 
            DFS(v); 
        } 
        out[r] = tim; 
    } 
    void Update(int x, int v) 
    { 
        for (int i = x; i <= cnt; i += i & -i) 
            tree[i] += v; 
    } 
    int Getsum(int x) 
    { 
        int ans = 0; 
        for (int i = x; i; i -= i & -i) 
            ans += tree[i]; 
        return ans; 
    } 
}AC; 
//ACTrie-end 
void dfs2(int r, int now) 
{ 
    AC.Update(AC.dfn[now], 1); 
    int s = B[r].size(); 
    for (int i = 0; i < s; i++) 
        ens[(B[r][i] + 1)/ 2] -= AC.Getsum(AC.out[plc[B[r][i]]]) - AC.Getsum(AC.in[plc[B[r][i]]] - 1); 
    s = E[r].size(); 
    for (int i = 0; i < s; i++) 
        ens[(E[r][i] + 1)/ 2] += AC.Getsum(AC.out[plc[E[r][i]]]) - AC.Getsum(AC.in[plc[E[r][i]]] - 1); 
    for (int i = T1.fir[r]; i; i = T1.nex[i]) 
    { 
        int v = T1.tar[i]; 
        if (v != fa[r]) 
            dfs2(v, AC.nex[now][len[i] - 'a']); 
    } 
    AC.Update(AC.dfn[now], -1); 
} 
int main() 
{ 
    //freopen("memory.in", "r", stdin); 
    //freopen("memory.out", "w", stdout); 
    scanf("%d%d", &n, &m); 
    for (int i = 1; i < n; i++) 
    { 
        int a, b; 
        char c[5]; 
        scanf("%d%d%s", &a, &b, c); 
        add(a, b, c[0]); 
        add(b, a, c[0]); 
    } 
    dfs(1); 
    for (int i = 1; i <= n; i++) 
        f[0][i] = fa[i]; 
    for (int i = 1; i <= 20; i++) 
        for (int j = 1; j <= n; j++) 
            f[i][j] = f[i - 1][f[i - 1][j]]; 
    int w = 0; 
    for (int i = 1; i <= m; i++) 
    { 
        int u, v, c; 
        scanf("%d%d%s", &u, &v, S); 
        c = LCA(u, v); 
        int l1 = dep[u] - dep[c], l2 = dep[v] - dep[c], ls = strlen(S); 
        int a = getk(u, max(0, l1 - ls + 1)); 
        int b = getk(v, max(0, l2 - ls + 1)); 
        KMP(a, b, c, ls, i); 
        w++; 
        AC.Insert(S, w); 
        B[b].push_back(w); 
        E[v].push_back(w); 
        w++; 
        for (int i = 0; i < ls / 2; i++) 
            swap(S[i], S[ls - i - 1]); 
        AC.Insert(S, w); 
        B[a].push_back(w); 
        E[u].push_back(w); 
    } 
    AC.Build(); 
    AC.DFS(AC.root); 
    dfs2(1, AC.root); 
    for (int i = 1; i <= m; i++) 
        printf("%d\n", ens[i]); 
}