欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Java中实现双数组Trie树实例

程序员文章站 2024-03-02 19:48:28
传统的trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。 双数组trie就是优化了空间的trie树,...

传统的trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。

双数组trie就是优化了空间的trie树,原理本文就不讲了,请参考an efficient implementation of trie structures,本程序的编写也是参考这篇论文的。

关于几点论文没有提及的细节和与论文不一一致的实现:

1.对于插入字符串,如果有一个字符串是另一个字符串的子串的话,我是将结束符也作为一条边,产生一个新的结点,这个结点新节点的base我置为0

所以一个字符串结束也有2中情况:一个是base值为负,存储剩余字符(可能只有一个结束符)到tail数组;另一个是base为0。

所以在查询的时候要考虑一下这两种情况

2.对于第一种冲突(论文中的case 3),可能要将tail中的字符串取出一部分,作为边放到索引中。论文是使用将尾串左移的方式,我的方式直接修改base值,而不是移动尾串。


下面是java实现的代码,可以处理相同字符串插入,子串的插入等情况

复制代码 代码如下:

/*
 * name:   double array trie
 * author: yaguang ding
 * mail: dingyaguang117@gmail.com
 * blog: blog.csdn.net/dingyaguang117
 * date:   2012/5/21
 * note: a word ends may be either of these two case:
 * 1. base[cur_p] == pos  ( pos<0 and tail[-pos] == 'end_char' )
 * 2. check[base[cur_p] + code('end_char')] ==  cur_p
 */


import java.util.arraylist;
import java.util.hashmap;
import java.util.map;
import java.util.arrays;


public class doublearraytrie {
 final char end_char = '\0';
 final int default_len = 1024;
 int base[]  = new int [default_len];
 int check[] = new int [default_len];
 char tail[] = new char [default_len];
 int pos = 1;
 map<character ,integer> charmap = new hashmap<character,integer>();
 arraylist<character> charlist = new arraylist<character>();
 
 public doublearraytrie()
 {
  base[1] = 1;
  
  charmap.put(end_char,1);
  charlist.add(end_char);
  charlist.add(end_char);
  for(int i=0;i<26;++i)
  {
   charmap.put((char)('a'+i),charmap.size()+1);
   charlist.add((char)('a'+i));
  }
  
 }
 private void extend_array()
 {
  base = arrays.copyof(base, base.length*2);
  check = arrays.copyof(check, check.length*2);
 }
 
 private void extend_tail()
 {
  tail = arrays.copyof(tail, tail.length*2);
 }
 
 private int getcharcode(char c)
 {
  if (!charmap.containskey(c))
  {
   charmap.put(c,charmap.size()+1);
   charlist.add(c);
  }
  return charmap.get(c);
 }
 private int copytotailarray(string s,int p)
 {
  int _pos = pos;
  while(s.length()-p+1 > tail.length-pos)
  {
   extend_tail();
  }
  for(int i=p; i<s.length();++i)
  {
   tail[_pos] = s.charat(i);
   _pos++;
  }
  return _pos;
 }
 
 private int x_check(integer []set)
 {
  for(int i=1; ; ++i)
  {
   boolean flag = true;
   for(int j=0;j<set.length;++j)
   {
    int cur_p = i+set[j];
    if(cur_p>= base.length) extend_array();
    if(base[cur_p]!= 0 || check[cur_p]!= 0)
    {
     flag = false;
     break;
    }
   }
   if (flag) return i;
  }
 }
 
 private arraylist<integer> getchildlist(int p)
 {
  arraylist<integer> ret = new arraylist<integer>();
  for(int i=1; i<=charmap.size();++i)
  {
   if(base[p]+i >= check.length) break;
   if(check[base[p]+i] == p)
   {
    ret.add(i);
   }
  }
  return ret;
 }
 
 private boolean tailcontainstring(int start,string s2)
 {
  for(int i=0;i<s2.length();++i)
  {
   if(s2.charat(i) != tail[i+start]) return false;
  }
  
  return true;
 }
 private boolean tailmatchstring(int start,string s2)
 {
  s2 += end_char;
  for(int i=0;i<s2.length();++i)
  {
   if(s2.charat(i) != tail[i+start]) return false;
  }
  return true;
 }
 
 
 public void insert(string s) throws exception
 {
  s += end_char;
  
  int pre_p = 1;
  int cur_p;
  for(int i=0; i<s.length(); ++i)
  {
   //获取状态位置
   cur_p = base[pre_p]+getcharcode(s.charat(i));
   //如果长度超过现有,拓展数组
   if (cur_p >= base.length) extend_array();
   
   //空闲状态
   if(base[cur_p] == 0 && check[cur_p] == 0)
   {
    base[cur_p] = -pos;
    check[cur_p] = pre_p;
    pos = copytotailarray(s,i+1);
    break;
   }else
   //已存在状态
   if(base[cur_p] > 0 && check[cur_p] == pre_p)
   {
    pre_p = cur_p;
    continue;
   }else
   //冲突 1:遇到 base[cur_p]小于0的,即遇到一个被压缩存到tail中的字符串
   if(base[cur_p] < 0 && check[cur_p] == pre_p)
   {
    int head = -base[cur_p];
    
    if(s.charat(i+1)== end_char && tail[head]==end_char) //插入重复字符串
    {
     break;
    }
    
    //公共字母的情况,因为上一个判断已经排除了结束符,所以一定是2个都不是结束符
    if (tail[head] == s.charat(i+1))
    {
     int avail_base = x_check(new integer[]{getcharcode(s.charat(i+1))});
     base[cur_p] = avail_base;
     
     check[avail_base+getcharcode(s.charat(i+1))] = cur_p;
     base[avail_base+getcharcode(s.charat(i+1))] = -(head+1);
     pre_p = cur_p;
     continue;
    }
    else
    {
     //2个字母不相同的情况,可能有一个为结束符
     int avail_base ;
     avail_base = x_check(new integer[]{getcharcode(s.charat(i+1)),getcharcode(tail[head])});
     
     base[cur_p] = avail_base;
     
     check[avail_base+getcharcode(tail[head])] = cur_p;
     check[avail_base+getcharcode(s.charat(i+1))] = cur_p;
     
     //tail 为end_flag 的情况
     if(tail[head] == end_char)
      base[avail_base+getcharcode(tail[head])] = 0;
     else
      base[avail_base+getcharcode(tail[head])] = -(head+1);
     if(s.charat(i+1) == end_char)
      base[avail_base+getcharcode(s.charat(i+1))] = 0;
     else
      base[avail_base+getcharcode(s.charat(i+1))] = -pos;
     
     pos = copytotailarray(s,i+2);
     break;
    }
   }else
   //冲突2:当前结点已经被占用,需要调整pre的base
   if(check[cur_p] != pre_p)
   {
    arraylist<integer> list1 = getchildlist(pre_p);
    int tobeadjust;
    arraylist<integer> list = null;
    if(true)
    {
     tobeadjust = pre_p;
     list = list1;
    }
    
    int origin_base = base[tobeadjust];
    list.add(getcharcode(s.charat(i)));
    int avail_base = x_check((integer[])list.toarray(new integer[list.size()]));
    list.remove(list.size()-1);
    
    base[tobeadjust] = avail_base;
    for(int j=0; j<list.size(); ++j)
    {
     //bug
     int tmp1 = origin_base + list.get(j);
     int tmp2 = avail_base + list.get(j);
     
     base[tmp2] = base[tmp1];
     check[tmp2] = check[tmp1];
     
     //有后续
     if(base[tmp1] > 0)
     {
      arraylist<integer> subsequence = getchildlist(tmp1);
      for(int k=0; k<subsequence.size(); ++k)
      {
       check[base[tmp1]+subsequence.get(k)] = tmp2;
      }
     }
     
     base[tmp1] = 0;
     check[tmp1] = 0;
    }
    
    //更新新的cur_p
    cur_p = base[pre_p]+getcharcode(s.charat(i));
    
    if(s.charat(i) == end_char)
     base[cur_p] = 0;
    else
     base[cur_p] = -pos;
    check[cur_p] = pre_p;
    pos = copytotailarray(s,i+1);
    break;
   }
  }
 }
 
 public boolean exists(string word)
 {
  int pre_p = 1;
  int cur_p = 0;
  
  for(int i=0;i<word.length();++i)
  {
   cur_p = base[pre_p]+getcharcode(word.charat(i));
   if(check[cur_p] != pre_p) return false;
   if(base[cur_p] < 0)
   {
    if(tailmatchstring(-base[cur_p],word.substring(i+1)))
     return true;
    return false;
   }
   pre_p = cur_p;
  }
  if(check[base[cur_p]+getcharcode(end_char)] == cur_p)
   return true;
  return false;
 }
 
 //内部函数,返回匹配单词的最靠后的base index,
 class findstruct
 {
  int p;
  string prefix="";
 }
 private findstruct find(string word)
 {
  int pre_p = 1;
  int cur_p = 0;
  findstruct fs = new findstruct();
  for(int i=0;i<word.length();++i)
  {
   // bug
   fs.prefix += word.charat(i);
   cur_p = base[pre_p]+getcharcode(word.charat(i));
   if(check[cur_p] != pre_p)
   {
    fs.p = -1;
    return fs;
   }
   if(base[cur_p] < 0)
   {
    if(tailcontainstring(-base[cur_p],word.substring(i+1)))
    {
     fs.p = cur_p;
     return fs;
    }
    fs.p = -1;
    return fs;
   }
   pre_p = cur_p;
  }
  fs.p =  cur_p;
  return fs;
 }
 
 public arraylist<string> getallchildword(int index)
 {
  arraylist<string> result = new arraylist<string>();
  if(base[index] == 0)
  {
   result.add("");
   return result;
  }
  if(base[index] < 0)
  {
   string r="";
   for(int i=-base[index];tail[i]!=end_char;++i)
   {
    r+= tail[i];
   }
   result.add(r);
   return result;
  }
  for(int i=1;i<=charmap.size();++i)
  {
   if(check[base[index]+i] == index)
   {
    for(string s:getallchildword(base[index]+i))
    {
     result.add(charlist.get(i)+s);
    }
    //result.addall(getallchildword(base[index]+i));
   }
  }
  return result;
 }
 
 public arraylist<string> findallwords(string word)
 {
  arraylist<string> result = new arraylist<string>();
  string prefix = "";
  findstruct fs = find(word);
  int p = fs.p;
  if (p == -1) return result;
  if(base[p]<0)
  {
   string r="";
   for(int i=-base[p];tail[i]!=end_char;++i)
   {
    r+= tail[i];
   }
   result.add(fs.prefix+r);
   return result;
  }
  
  if(base[p] > 0)
  {
   arraylist<string> r =  getallchildword(p);
   for(int i=0;i<r.size();++i)
   {
    r.set(i, fs.prefix+r.get(i));
   }
   return r;
  }
  
  return result;
 }
 
}

测试代码:

复制代码 代码如下:

import java.io.bufferedreader;
import java.io.fileinputstream;
import java.io.ioexception;
import java.io.inputstream;
import java.io.inputstreamreader;
import java.util.arraylist;
import java.util.scanner;

import javax.xml.crypto.data;


public class main {

 public static void main(string[] args) throws exception {
  arraylist<string> words = new arraylist<string>();
  bufferedreader reader = new bufferedreader(new inputstreamreader(new fileinputstream("e:/兔子的试验学习中心[课内]/acm大赛/acm第四届校赛/e命令提示/words3.dic")));
  string s;
  int num = 0;
  while((s=reader.readline()) != null)
  {
   words.add(s);
   num ++;
  }
  doublearraytrie dat = new doublearraytrie();
  
  for(string word: words)
  {
   dat.insert(word);
  }
  
  system.out.println(dat.base.length);
  system.out.println(dat.tail.length);
  
  scanner sc = new scanner(system.in);
  while(sc.hasnext())
  {
   string word = sc.next();
   system.out.println(dat.exists(word));
   system.out.println(dat.findallwords(word));
  }
  
 }

}

下面是测试结果,构造6w英文单词的dat,大概需要20秒

我增长数组的时候是每次长度增加到2倍,初始1024

base和check数组的长度为131072

tail的长度为262144

ttt1