欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【模板】回文自动机(PAM)

程序员文章站 2024-03-09 08:15:41
...

题目

题目背景
模板题,无背景(其实是我想不出背景)。

题目描述
给定一个字符串 ss。保证每个字符为小写字母。对于 ss 的每个位置,请求出以该位置结尾的回文子串个数。

这个字符串被进行了加密,除了第一个字符,其他字符都需要通过上一个位置的答案来解密。

具体地,若第 i(i\geq 1)i(i≥1) 个位置的答案是 kk,第 i+1i+1 个字符读入时的 \rm ASCIIASCII 码为 cc,则第 i+1i+1 个字符实际的 \rm ASCIIASCII 码为 (c-97+k)\bmod 26+97(c−97+k)mod26+97。所有字符在加密前后都为小写字母。

输入格式
一行一个字符串 ss 表示被加密后的串。

输出格式
一行, |s|∣s∣ 个整数。第 ii 个整数表示原串以第 ii 个字符结尾的回文子串个数。

输入输出样例
输入 #1 复制
debber
输出 #1 复制
1 1 1 2 1 1
输入 #2 复制
lwk
输出 #2 复制
1 1 2
输入 #3 复制
lxl
输出 #3 复制
1 1 1
说明/提示
对于 100%100% 的数据, 1\leq |s|\leq 5\times 10^51≤∣s∣≤5×10
5

思路

回文自动机

1.回文树
回文树有两个根, 分别表示偶回文和奇回文的根, 和Trie 很
类似.
假设某个节点代表的串为S, 那么其字符为c 的儿子节点代
表的串为cSc.
偶根代表的时候长度为0 的空串, 而奇根代表的是长度为
-1 的串.

2.回文自动机
字符串s 的回文自动机是这样的一个东西
把s 所有本质不同的回文子串建成一棵回文树.
然后对每个节点求fail 函数, 和AC 自动机的fail 函数含义
相同.

3.构造
之前已经介绍过,初始状态是有两个节点(0号和1号),它们的状态之前也已经说明。

我们从左往右加入字符串的每个字符

设lastlast为上一次插入字符的节点编号,初始时last = 0last=0
对于每个字符我们需要在回文自动机上找到以它结尾的最长回文子串。

对于字符ii,

while(s[i - b[last].len - 1] != s[i])last = b[last].fail
这样求出的lastlast即为新节点的父亲。

为什么呢?网上大多数dalao都说的很详细,还有高清大图。我比较懒我就大概讲讲自己的理解。

首先我们肯定要先考虑上一个位置在两边直接各加一个字符是否是回文串。

如果不是怎么办?那么我们就不停跳到它的最长回文后缀直到是回文为止。每次跳都能保证它是一开始lastlast的一个回文后缀也就是指以i - 1i−1结尾的回文子串,这样在它两边各加一个字符后,对应的一定是以当前字符为结尾的一个回文串。又由于每次都是跳最长的,所以第一次合法时取到的以ii结尾的回文子串也一定是最长的。

这个过程什么时候结束呢?由于len_1 = -1len
1
​ =−1,所以在跳到1的时候必然是自己等于自己,必然可以回文。

然后我们就像trietrie一样,在父亲下面生成孩子。

新孩子的failfail怎么计算呢?它就是从它父亲的failfail开始跳,跳到的第一个回文的位置。原因和上面类似,大家可以自己思考。

于是我们就愉快的建完了。

代码

#include<bits/stdc++.h>
using namespace std; 

const int N = 2e6 + 77; 
struct PAM_Trie
{
    int ch[26]; 
    int fail, len, num; 
}; 
struct PAM
{
    PAM_Trie b[N]; 
    int n, length, last, cnt, s[N]; 
    char c[N]; 

    PAM()
    {
        b[0].len = 0; b[1].len = -1; 
        b[0].fail = 1; b[1].fail = 0; 
        last = 0;
        cnt = 1; 
    }
    void read()
    {
        scanf("%s", c + 1); 
        length = strlen(c + 1); 
    }
    int get_fail(int x)
    {
        while(s[n - b[x].len - 1] != s[n])
        {
            x = b[x].fail; 
        }
        return x; 
    }
    void insert()
    {
        int p = get_fail(last); 
        if(!b[p].ch[s[n]])
        {
            b[++cnt].len = b[p].len + 2; 
            int tmp = get_fail(b[p].fail); 
            b[cnt].fail = b[tmp].ch[s[n]]; 
            b[cnt].num = b[b[cnt].fail].num + 1; 
            b[p].ch[s[n]] = cnt; 
        }
        last = b[p].ch[s[n]]; 
    }
    void solve()
    {
        int k = 0; 
        s[0] = 26; 
        for(n = 1; n <= length; n++)
        {
            c[n] = (c[n] - 97 + k) % 26 + 97; 
            s[n] = c[n] - 'a'; 
            insert(); 
            printf("%d ", b[last].num); 
            k = b[last].num; 
        }
    }
}P; 
int main()
{
    P.read();
    P.solve();
}