Katz平滑的实现
程序员文章站
2024-03-24 11:16:34
...
Katz平滑的讲解可参见：https://zhuanlan.zhihu.com/p/100256789 。本文实现了该文章中的示例。关于Katz平滑的具体原理，之后有时间会再写一下自己的理解；阅读代码也可能有助于理解那篇文章的内容。
import tensorflow as tf
import collections
class Katz():
    """Bigram language model with Katz back-off smoothing.

    The methods are meant to be called in this order: example(), unigram(),
    unigram_ml(), bigram(), bigram_ml(), bigram_nr(), bigram_A(),
    bigram_discount(), bow(), and finally test() to score self.test_data.
    Every table-building step rebuilds its tables from scratch, so a step can
    be re-run without double counting.
    """

    def __init__(self):
        self.words_uni = collections.defaultdict(int)   # c(w): unigram counts
        self.words_bi = collections.defaultdict(int)    # c(w1 w2), keyed "w1 w2"
        self.ml_uni = collections.defaultdict(float)    # P_ML(w) = c(w) / (N - c(<s>))
        self.nr_bi = collections.defaultdict(int)       # n_r: bigram types seen exactly r times
        self.rhat_bi = collections.defaultdict(float)   # Good-Turing count r* = (r+1) n_{r+1} / n_r
        self.r_bi = collections.defaultdict(float)      # ml_bi: P_ML(w2|w1) = c(w1 w2) / c(w1)
        self.dr_bi = collections.defaultdict(float)     # Katz discount ratio d_r
        self.ml_bi_discount = collections.defaultdict(float)  # discounted P(w2|w1) of seen bigrams
        self.head_bi = collections.defaultdict(float)   # per w1: sum of discounted P(w2|w1), seen w2
        self.tail_bi = collections.defaultdict(float)   # per w1: sum of P_ML(w2) over the same w2
        self.bow_dict = collections.defaultdict(float)  # back-off weight alpha(w1)
        self.T = 3  # bigrams seen fewer than T times are discounted
        self.A = 0  # (T+1) * n_{T+1} / n_1, set by bigram_A()
        self.data_sep = []          # training sentences wrapped in <s> ... </s>
        self.test_data = []         # sentences scored by test()
        self.total_bi_sentence = 0  # number of training sentences, set by bigram()

    def replace_end(self, input, pattern, rewrite):
        """Return ``input`` with every regex match of ``pattern`` replaced by ``rewrite``."""
        return tf.strings.regex_replace(input, pattern, rewrite)

    def split_ngram(self, input, width):
        """Return the space-joined ``width``-grams of the whitespace tokens of ``input`` (tensor)."""
        words = tf.strings.split(input)
        return tf.strings.ngrams(words, width)

    @staticmethod
    def _bigrams(sentence):
        """Return the adjacent word pairs of ``sentence`` as "w1 w2" strings.

        Pure-Python counterpart of split_ngram(sentence, 2): whitespace
        tokenisation and a single-space separator.
        """
        words = sentence.split()
        return [" ".join(pair) for pair in zip(words, words[1:])]

    def sen_bi(self, sen):
        """Return the unsmoothed (maximum-likelihood) bigram probability of ``sen``.

        Multiplies c(w1 w2) / c(w1) over every bigram of ``sen``. Raises
        ZeroDivisionError when a history word never occurred in training.
        """
        p = 1.
        for i in self.split_ngram(sen, 2):
            bigram = i.numpy().decode()
            head = bigram.split()[0]
            # .get() instead of [] so that scoring a sentence does not insert
            # zero-count entries into the count tables (which bigram_nr,
            # bigram_discount and bow would later pick up).
            p = p * self.words_bi.get(bigram, 0) / self.words_uni.get(head, 0)
        return p

    def add_head_tail(self, s):
        """Return a copy of the sentences in ``s`` wrapped as "<s> sentence </s>"."""
        return ["<s> " + i + " </s>" for i in s]

    def example(self):
        """Load the sample training corpus and test sentences, wrapped in <s> ... </s>."""
        data = ["dogs chase cats",
                "dogs bark",
                "cats meow",
                "dogs chase birds",
                "cats chase birds",
                "dogs chase the cats",
                "the birds chirp"
                ]
        self.test_data = [
            "cats meow",
            "dogs chase the birds",
            "birds chirp",
            "Wang dogs Zhi Guo"
        ]
        self.data_sep = self.add_head_tail(data)
        self.test_data = self.add_head_tail(self.test_data)

    def unigram(self):
        """Count every whitespace-separated token of self.data_sep into words_uni.

        The sentence markers <s> and </s> are counted like ordinary tokens.
        """
        self.words_uni = collections.defaultdict(int)
        # str.split() matches tf.strings.split's default whitespace splitting
        # without converting every token to a tensor and back.
        for sentence in self.data_sep:
            for word in sentence.split():
                self.words_uni[word] += 1

    def bigram(self):
        """Count every adjacent word pair of self.data_sep into words_bi."""
        self.total_bi_sentence = len(self.data_sep)
        self.words_bi = collections.defaultdict(int)
        for item in self.data_sep:
            for pair in self._bigrams(item):
                self.words_bi[pair] += 1

    def log10(self, data):
        """Return log10(data) as a TensorFlow tensor."""
        return tf.math.log(data) / tf.math.log(10.)

    def unigram_ml(self):
        """Set ml_uni[w] = c(w) / (N - c(<s>)), the maximum-likelihood unigram.

        <s> is left out of the denominator because add_head_tail() only puts
        it at the start of a sentence, so it is never a predicted word.
        """
        self.ml_uni = collections.defaultdict(float)
        total_unigram = sum(self.words_uni.values()) - self.words_uni['<s>']
        for key, count in self.words_uni.items():
            self.ml_uni[key] = count / total_unigram

    def bigram_ml(self):
        """Set r_bi["w1 w2"] = c(w1 w2) / c(w1), the maximum-likelihood bigram."""
        self.r_bi = collections.defaultdict(float)
        for key, count in self.words_bi.items():
            history = key.split()[0]
            self.r_bi[key] = count / self.words_uni[history]

    def bigram_nr(self):
        """Set nr_bi[r] = number of distinct bigrams seen exactly r times."""
        self.nr_bi = collections.defaultdict(int)
        for count in self.words_bi.values():
            self.nr_bi[count] += 1

    def bigram_A(self):
        """Set self.A = (T+1) * n_{T+1} / n_1, the Katz discount constant.

        Raises ZeroDivisionError when no bigram was seen exactly once.
        """
        self.A = (self.T + 1.) * self.nr_bi[self.T + 1] / self.nr_bi[1]

    def bigram_discount(self):
        """Compute discounted bigram probabilities and per-history mass sums.

        A bigram seen r < T times gets P = P_ML * d_r, where
        r* = (r+1) n_{r+1} / n_r and d_r = (r*/r - A) / (1 - A).
        More frequent bigrams keep their ML estimate. For every history w1,
        head_bi[w1] sums the resulting P(w2|w1) and tail_bi[w1] sums P_ML(w2)
        over the seen followers w2; bow() combines the two.
        """
        # NOTE(review): Katz's formulation discounts 1 <= r <= T (A is built
        # from n_{T+1}). With T=3 on the sample corpus, d_3 would be 1.5 > 1,
        # so the strict "<" is kept as in the original.
        self.rhat_bi = collections.defaultdict(float)
        self.dr_bi = collections.defaultdict(float)
        self.ml_bi_discount = collections.defaultdict(float)
        self.head_bi = collections.defaultdict(float)
        self.tail_bi = collections.defaultdict(float)
        for k, n in self.words_bi.items():
            k0, k1 = k.split()
            if n < self.T:
                self.rhat_bi[k] = (n + 1) * self.nr_bi[n + 1] / self.nr_bi[n]
                self.dr_bi[k] = (self.rhat_bi[k] / n - self.A) / (1 - self.A)
                self.ml_bi_discount[k] = self.r_bi[k] * self.dr_bi[k]
            else:
                self.ml_bi_discount[k] = self.r_bi[k]
            self.head_bi[k0] += self.ml_bi_discount[k]
            self.tail_bi[k0] += self.ml_uni[k1]

    def bow(self):
        """Compute the back-off weight alpha(w) for every history word w.

        alpha(w) = (1 - sum of discounted P(v|w) over seen v)
                   / (1 - sum of P_ML(v) over the same v).
        </s> is skipped because it is always the last token of a sentence and
        therefore never a history; "<unk>" gets weight 0.0.
        """
        self.bow_dict = collections.defaultdict(float)
        for key in self.words_uni:
            if key != "</s>":
                self.bow_dict[key] = (1 - self.head_bi[key]) / (1 - self.tail_bi[key])
        self.bow_dict["<unk>"] = 0.0

    def replace_nuk(self, input):
        """Return ``input`` with every token unseen in training replaced by "<unk>"."""
        return " ".join(
            word if self.words_uni.get(word) else "<unk>" for word in input.split()
        )

    def _score_factors(self, bigrams):
        """Return, in order, the probability factors whose log10 sum scores ``bigrams``.

        ``bigrams`` are "w1 w2" strings in which OOV words have already been
        replaced by "<unk>".
        """
        factors = []
        for bigram in bigrams:
            k0, k1 = bigram.split()
            if k1 == "<unk>":
                # The OOV word itself is not scored. This covers "<s> <unk>",
                # "<unk> <unk>" and "w <unk>". The last case used to add
                # P_ML(w), counting the history word w a second time.
                continue
            value = self.ml_bi_discount.get(bigram)
            if value:
                factors.append(value)
                continue
            if k0 == "<unk>":
                # Unknown history: no bigram statistics, so use the unigram
                # probability of the predicted word (as already done for
                # "<unk> </s>").
                factors.append(self.ml_uni[k1])
                continue
            # Unseen bigram with a known history: alpha(k0) * P_ML(k1).
            # A zero or missing weight is skipped (i.e. treated as 1) rather
            # than producing log10(0).
            bow_value = self.bow_dict.get(k0)
            if bow_value:
                factors.append(bow_value)
            uni_value = self.ml_uni[k1]
            if uni_value:
                factors.append(uni_value)
        return factors

    def test(self):
        """Print the Katz log10 probability of every sentence in self.test_data."""
        for item in self.test_data:
            words = self.replace_nuk(item)
            # Start from a tensor so p.numpy() works even if nothing is added.
            p = tf.constant(0.)
            for factor in self._score_factors(self._bigrams(words)):
                p = p + self.log10(factor)
            print(item, ":", p.numpy())
# Driver: build the Katz bigram model on the sample corpus, then print the
# log10 score of each test sentence. Each step below is run in order on `k`.
k = Katz()
for step in (
    k.example,          # load sample sentences and wrap them in <s> ... </s>
    k.unigram,          # unigram counts
    k.unigram_ml,       # maximum-likelihood unigram probabilities
    k.bigram,           # bigram counts
    k.bigram_ml,        # maximum-likelihood bigram probabilities
    k.bigram_nr,        # count-of-counts n_r
    k.bigram_A,         # Katz discount constant A
    k.bigram_discount,  # discounted bigram probabilities
    k.bow,              # back-off weights
    k.test,             # score and print the test sentences
):
    step()