欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

AC自动机

程序员文章站 2022-06-06 17:41:57
...

概述

AC 自动机是 以 TRIE 的结构为基础 ,结合 KMP 的思想 建立的。

简单来说,建立一个 AC 自动机有两个步骤:

  1. 基础的 TRIE 结构:将所有的模式串构成一棵 TrieTrie
  2. KMP 的思想:对 TrieTrie 树上所有的结点构造失配指针。

然后就可以利用它进行多模式匹配了。

字典树构建

AC 自动机在初始时会将若干个模式串丢到一个 TRIE 里,然后在 TRIE 上建立 AC 自动机。这个 TRIE 就是普通的 TRIE,该怎么建怎么建。

这里需要仔细解释一下 TRIE 的结点的含义,尽管这很小儿科,但在之后的理解中极其重要。TRIE 中的结点表示的是某个模式串的前缀。我们在后文也将其称作状态。一个结点表示一个状态,TRIE 的边就是状态的转移。

形式化地说,对于若干个模式串 s1,s2sns_1,s_2\dots s_n ,将它们构建一棵字典树后的所有状态的集合记作 QQ

失配指针

AC 自动机利用一个 fail 指针来辅助多模式串的匹配。

状态 uu 的 fail 指针指向另一个状态 vv ,其中 vQv\in Q ,且 vvuu 的最长后缀(即在若干个后缀状态中取最长的一个作为 fail 指针)。对于学过 KMP 的朋友,我在这里简单对比一下这里的 fail 指针与 KMP 中的 next 指针:

  1. 共同点:两者同样是在失配的时候用于跳转的指针。
  2. 不同点:next 指针求的是最长 Border(即最长的相同前后缀),而 fail 指针指向所有模式串的前缀中匹配当前状态的最长后缀。

因为 KMP 只对一个模式串做匹配,而 AC 自动机要对多个模式串做匹配。有可能 fail 指针指向的结点对应着另一个模式串,两者前缀不同。

没看懂上面的对比不要急(也许我的脑回路和泥萌不一样是吧),你只需要知道,AC 自动机的失配指针指向当前状态的最长后缀状态即可。

AC 自动机在做匹配时,同一位上可匹配多个模式串。

构建指针

下面介绍构建 fail 指针的 基础思想 :(强调!基础思想!基础!)

构建 fail 指针,可以参考 KMP 中构造 Next 指针的思想。

考虑字典树中当前的结点 uuuu 的父结点是 pppp 通过字符 c 的边指向 uu ,即 trie[p,c]=utrie[p,c]=u 。假设深度小于 uu 的所有结点的 fail 指针都已求得。

  1. 如果 trie[fail[p],c]trie[fail[p],c] 存在:则让 u 的 fail 指针指向 trie[fail[p],c]trie[fail[p],c] 。相当于在 ppfail[p]fail[p] 后面加一个字符 c ,分别对应 uufail[u]fail[u]
  2. 如果 trie[fail[p],c]trie[fail[p],c] 不存在:那么我们继续找到 trie[fail[fail[p]],c]trie[fail[fail[p]],c] 。重复 1 的判断过程,一直跳 fail 指针直到根结点。
  3. 如果真的没有,就让 fail 指针指向根结点。

如此即完成了 fail[u]fail[u] 的构建。

例子

下面放一张 GIF 帮助大家理解。对字符串 i he his she hers 组成的字典树构建 fail 指针:

  1. 黄色结点:当前的结点 uu
  2. 绿色结点:表示已经 BFS 遍历完毕的结点,
  3. 橙色的边:fail 指针。
  4. 红色的边:当前求出的 fail 指针。

AC自动机

我们重点分析结点 6 的 fail 指针构建:

AC自动机

找到 6 的父结点 5, fail[5]=10fail[5]=10 。然而 10 结点没有字母 s 连出的边;继续跳到 10 的 fail 指针, fail[10]=0fail[10]=0 。发现 0 结点有字母 s 连出的边,指向 7 结点;所以 fail[6]=7fail[6]=7 。最后放一张建出来的图

AC自动机


import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;

public class AhoCorasickAutomation {
    /*本示例中的AC自动机只处理英文类型的字符串,所以数组的长度是128*/
    private static final int ASCII = 128;

    /*AC自动机的根结点,根结点不存储任何字符信息*/
    private Node root;

    /*待查找的目标字符串集合*/
    private List<String> target;

    /*内部静态类,用于表示AC自动机的每个结点,在每个结点中我们并没有存储该结点对应的字符*/
    private static class Node {

        /*如果该结点是一个终点,即,从根结点到此结点表示了一个目标字符串,则str != null, 且str就表示该字符串*/
        String str;

        /*ASCII == 128, 所以这里相当于128叉树*/
        Node[] table = new Node[ASCII];

        /*当前结点的孩子结点不能匹配文本串中的某个字符时,下一个应该查找的结点*/
        Node fail;

        boolean isWord() {
            return str != null;
        }

    }

    /*target表示待查找的目标字符串集合*/
    public AhoCorasickAutomation(List<String> target) {
        root = new Node();
        this.target = target;
        buildTrieTree();
        build_AC_FromTrie();
    }

    /*由目标字符串构建Trie树*/
    private void buildTrieTree() {
        for (String targetStr : target) {
            Node curr = root;
            for (int i = 0; i < targetStr.length(); i++) {
                char ch = targetStr.charAt(i);
                if (curr.table[ch] == null) {
                    curr.table[ch] = new Node();
                }
                curr = curr.table[ch];
            }
            /*将每个目标字符串的最后一个字符对应的结点变成终点*/
            curr.str = targetStr;
        }
    }

    /*由Trie树构建AC自动机,本质是一个自动机,相当于构建KMP算法的next数组*/
    private void build_AC_FromTrie() {
        /*广度优先遍历所使用的队列*/
        LinkedList<Node> queue = new LinkedList<>();

        /*单独处理根结点的所有孩子结点*/
        for (Node x : root.table) {
            if (x != null) {
                /*根结点的所有孩子结点的fail都指向根结点*/
                x.fail = root;
                queue.addLast(x);/*所有根结点的孩子结点入列*/
            }
        }

        while (!queue.isEmpty()) {
            /*确定出列结点的所有孩子结点的fail的指向*/
            Node p = queue.removeFirst();
            for (int i = 0; i < p.table.length; i++) {
                if (p.table[i] != null) {
                    /*孩子结点入列*/
                    queue.addLast(p.table[i]);
                    /*从p.fail开始找起*/
                    Node failTo = p.fail;
                    while (true) {
                        /*说明找到了根结点还没有找到*/
                        if (failTo == null) {
                            p.table[i].fail = root;
                            break;
                        }

                        /*说明有公共前缀*/
                        if (failTo.table[i] != null) {
                            p.table[i].fail = failTo.table[i];
                            break;
                        } else {/*继续向上寻找*/
                            failTo = failTo.fail;
                        }
                    }
                }
            }
        }
    }

    /*在文本串中查找所有的目标字符串*/
    public HashMap<String, List<Integer>> find(String text) {
        /*创建一个表示存储结果的对象*/
        /*表示在文本字符串中查找的结果,key表示目标字符串, value表示目标字符串在文本串出现的位置*/
        HashMap<String, List<Integer>> result = new HashMap<>();
        for (String s : target) {
            result.put(s, new LinkedList<>());
        }

        Node curr = root;
        int i = 0;
        while (i < text.length()) {
            /*文本串中的字符*/
            char ch = text.charAt(i);

            /*文本串中的字符和AC自动机中的字符进行比较*/
            if (curr.table[ch] != null) {
                /*若相等,自动机进入下一状态*/
                curr = curr.table[ch];

                if (curr.isWord()) {
                    result.get(curr.str).add(i - curr.str.length() + 1);
                }

                /*这里很容易被忽视,因为一个目标串的中间某部分字符串可能正好包含另一个目标字符串,
                 * 即使当前结点不表示一个目标字符串的终点,但到当前结点为止可能恰好包含了一个字符串*/
                if (curr.fail != null && curr.fail.isWord()) {
                    result.get(curr.fail.str).add(i - curr.fail.str.length() + 1);
                }

                /*索引自增,指向下一个文本串中的字符*/
                i++;
            } else {
                /*若不等,找到下一个应该比较的状态*/
                curr = curr.fail;

                /*到根结点还未找到,说明文本串中以ch作为结束的字符片段不是任何目标字符串的前缀,
                 * 状态机重置,比较下一个字符*/
                if (curr == null) {
                    curr = root;
                    i++;
                }
            }
        }
        return result;
    }


    public static void main(String[] args) {
        List<String> target = new ArrayList<>();
        target.add("abcdef");
        target.add("abhab");
        target.add("bcd");
        target.add("cde");
        target.add("cdfkcdf");

        String text = "bcabcdebcedfabcdefababkabhabk";

        AhoCorasickAutomation aca = new AhoCorasickAutomation(target);
        HashMap<String, List<Integer>> result = aca.find(text);

        System.out.println(text);
        for (Entry<String, List<Integer>> entry : result.entrySet()) {
            System.out.println(entry.getKey() + " : " + entry.getValue());
        }

    }
}