欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Mediocre String Problem Gym - 101981M (拓展KMP + PAM回文自动机)

程序员文章站 2022-06-08 08:24:28
...

Mediocre String Problem Gym - 101981M (拓展KMP + PAM回文自动机)

题目大意

给定两个字符串,s,t然后询问有多少个三元组满足 s[i~j] + t[1, k]并且i到j的长度大于k使得拼接的字符串是个回文字符串。

思路 & 代码

将s逆序,得到ss,对其和t求扩展KMP
得到ss[i…n-1] 和 t[0…m-1]的最长公共前缀。
然后其每个前缀的长度 * 以i结尾后缀回文的数量再求个和就ok

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int MAXN = 1e6 + 100;
const int N = 26;
char s[MAXN], t[MAXN];
int extend[MAXN], nxt[MAXN];
int n,m;
void pre_EKMP(char x[], int m, int nxt[]) {
    nxt[0] = m;
    int j = 0;
    while(j + 1 < m && x[j] == x[j+1]) j++;
    nxt[1] = j;
    int k = 1;
    for(int i = 2; i < m; i++) {
        int p = nxt[k] + k - 1;
        int L = nxt[i - k];
        if(i + L < p + 1) nxt[i] = L;
        else {
            j = max(0, p - i + 1);
            while(i + j < m && x[i + j] == x[j]) j++;
            nxt[i] = j;
            k = i;
        }
    }
}
void EKMP(char x[], int m, char y[], int n, int nxt[], int extend[]) {
    pre_EKMP(x, m, nxt);
    int j = 0;
    while(j < n && j < m && x[j] == y[j]) j++;
    extend[0] = j;
    int k = 0;
    for(int i = 1; i < n; i++) {
        int p = extend[k] + k - 1;
        int L = nxt[i - k];
        if(i + L < p + 1) extend[i] = L;
        else {
            j = max(0, p - i + 1);
            while(i + j < n && j < m && y[i + j] == x[j]) j++;
            extend[i] = j;
            k = i;
        }
    }
}
struct PAM {
    int ch[MAXN][N];// 往一个字符串左右添加一个字符对应的结点
    int fail[MAXN]; // 对于后缀回文串来说,失配后能到达的下一个后缀回文
    int cnt[MAXN];  // 本质不同的结点数量,最后需要count函数得到正确结果
    int num[MAXN];  // num[i]表示以i结尾的后缀回文串的数量
    int len[MAXN];  // 每个节点代表的回文串长度
    int S[MAXN];    // 存字符串
    int last;       // 对应最长后缀回文节点
    int n;          // 字符集数量
    int p;          // 回文树结点数量
    int newnode(int l) {
        for(int i = 0; i < N; i++) ch[p][i] = 0;
        cnt[p] = num[p] = 0;
        len[p] = l;
        return p++;
    }
    void init() {
        p = 0;
        newnode(0);  // 偶结点
        newnode(-1); // 奇结点
        last = 0;
        n = 0;
        S[n] = -1;
        fail[0] = 1;
    }
    int get_fail(int x) {
        while(S[n - len[x] - 1] != S[n]) x = fail[x];
        return x;
    }
    int add(int c) {
        c-='a';
        S[++n] = c;
        int cur = get_fail(last);
        if(!ch[cur][c]) {
            int now = newnode(len[cur] + 2);// 新添加一个结点
            fail[now] = ch[get_fail(fail[cur])][c];
            ch[cur][c] = now;
            num[now] = num[fail[now]] + 1;
        }
        last = ch[cur][c];
        cnt[last]++;
        return num[last];
    }
    void count() {
        for(int i = p - 1; i >= 0; --i) {
            cnt[fail[i]] += cnt[i];
        }
    }
} pam;
int num[MAXN];
int main() {
    //freopen("/Users/maoxiangsun/MyRepertory/input.txt", "r", stdin);
    scanf("%s%s",s,t);
    n = (int)strlen(s);
    m = (int)strlen(t);
    reverse(s, s + n);
    EKMP(t, m, s, n, nxt, extend);
    // OK
    pam.init();
    for(int i = 0; i < n; i++) {
        num[i] = pam.add(s[i]);
    }
    ll res = 0;
    for(int i = 1; i < n ; i++) {
        res += 1LL * num[i-1] * extend[i];
    }
    cout << res << endl;
    return 0;
}