欢迎您访问程序员文章站!本站旨在为大家提供和分享程序员计算机编程知识。
您现在的位置是: 首页

Katz平滑的实现

程序员文章站 2024-03-24 11:16:34
...

Katz平滑的讲解可参见:https://zhuanlan.zhihu.com/p/100256789。本文用代码实现了该文章中的示例。关于Katz平滑本身,有时间再把自己的理解写下来;阅读下面的代码可能有助于理解该文章的内容。

import tensorflow as tf
import collections
class Katz:
    """Bigram language model with Katz back-off smoothing on a toy corpus.

    The model is built in stages; each stage reads tables produced by the
    earlier ones, so they are meant to be called in this order::

        example -> unigram -> unigram_ml -> bigram -> bigram_ml
                -> bigram_nr -> bigram_A -> bigram_discount -> bow -> test

    Every stage rebuilds the tables it owns from scratch, so re-running a
    stage (or the whole pipeline) does not accumulate stale values.
    Bigrams are stored as ``"w1 w2"`` strings (single space separator).
    """

    def __init__(self):
        # Raw counts.
        self.words_uni = collections.defaultdict(int)   # word -> count
        self.words_bi = collections.defaultdict(int)    # "w1 w2" -> count
        # Maximum-likelihood estimates.
        self.ml_uni = collections.defaultdict(float)    # word -> P(word)
        self.r_bi = collections.defaultdict(float)      # "w1 w2" -> P(w2|w1)
        # Count-of-counts: r -> number of distinct bigrams seen r times.
        self.nr_bi = collections.defaultdict(int)
        # Discounting tables, all keyed by "w1 w2".
        self.rhat_bi = collections.defaultdict(float)   # adjusted count r*
        self.dr_bi = collections.defaultdict(float)     # discount ratio d_r
        self.ml_bi_discount = collections.defaultdict(float)
        # Per-history sums over the bigrams seen after that history word.
        self.head_bi = collections.defaultdict(float)   # sum of discounted P
        self.tail_bi = collections.defaultdict(float)   # sum of P(w2)
        # Back-off weight per history word.
        self.bow_dict = collections.defaultdict(float)
        # Bigrams with count < T are discounted; A is set by bigram_A().
        self.T = 3
        self.A = 0
        # Populated by example() / bigram(); defined here so every stage can
        # run on a fresh instance without an AttributeError.
        self.total_bi_sentence = 0
        self.data_sep = []
        self.test_data = []

    def replace_end(self, input, pattern, rewrite):
        """Return ``input`` with regex ``pattern`` replaced by ``rewrite``.

        Thin wrapper over ``tf.strings.regex_replace``.
        """
        return tf.strings.regex_replace(input, pattern, rewrite)

    def split_ngram(self, input, width):
        """Split ``input`` on whitespace and return its ``width``-grams.

        The result comes from ``tf.strings.ngrams``; each n-gram is the
        tokens joined by that op's default separator.
        """
        words = tf.strings.split(input)
        return tf.strings.ngrams(words, width)

    def sen_bi(self, sen):
        """Return the unsmoothed bigram probability of sentence ``sen``.

        This is the product of ``count(w1 w2) / count(w1)`` over all
        bigrams of ``sen``. A bigram whose head word was never counted
        makes the sentence probability ``0.0`` (previously this raised
        ``ZeroDivisionError``).
        """
        prob = 1.
        for gram in self.split_ngram(sen, 2):
            bigram = gram.numpy().decode()
            head = bigram.partition(" ")[0]
            # .get() keeps the lookup from inserting zero-count entries.
            head_count = self.words_uni.get(head, 0)
            if not head_count:
                return 0.0
            prob = prob * self.words_bi.get(bigram, 0) / head_count
        return prob

    def add_head_tail(self, s):
        """Return the sentences in ``s`` wrapped in ``<s>`` / ``</s>``."""
        return ["<s> " + sentence + " </s>" for sentence in s]

    def example(self):
        """Load the built-in toy training and test sentences."""
        data = [
            "dogs chase cats",
            "dogs bark",
            "cats meow",
            "dogs chase birds",
            "cats chase birds",
            "dogs chase the cats",
            "the birds chirp",
        ]
        test_data = [
            "cats meow",
            "dogs chase the birds",
            "birds chirp",
            "Wang dogs Zhi Guo",
        ]
        self.data_sep = self.add_head_tail(data)
        self.test_data = self.add_head_tail(test_data)

    def unigram(self):
        """Count every token of the training sentences into ``words_uni``."""
        self.words_uni = collections.defaultdict(int)
        # Split sentence by sentence, the same way bigram() does; tokens
        # returned by the split are already single words.
        for sentence in self.data_sep:
            for token in tf.strings.split(sentence):
                self.words_uni[token.numpy().decode()] += 1

    def bigram(self):
        """Count every bigram of the training sentences into ``words_bi``.

        Also records the number of training sentences in
        ``total_bi_sentence``.
        """
        self.total_bi_sentence = 0
        self.words_bi = collections.defaultdict(int)
        for sentence in self.data_sep:
            self.total_bi_sentence += 1
            for gram in self.split_ngram(sentence, 2):
                self.words_bi[gram.numpy().decode()] += 1

    def log10(self, data):
        """Return the base-10 logarithm of ``data`` computed with TensorFlow."""
        return tf.math.log(data) / tf.math.log(10.)

    def unigram_ml(self):
        """Fill ``ml_uni`` with ``count(word) / total``.

        ``total`` is the number of counted tokens excluding ``<s>``.
        """
        total_unigram = (sum(self.words_uni.values())
                         - self.words_uni.get('<s>', 0))
        self.ml_uni = collections.defaultdict(float)
        for word, count in self.words_uni.items():
            self.ml_uni[word] = count / total_unigram

    def bigram_ml(self):
        """Fill ``r_bi`` with ``count(w1 w2) / count(w1)`` per bigram."""
        self.r_bi = collections.defaultdict(float)
        for bigram, count in self.words_bi.items():
            head = bigram.split()[0]
            self.r_bi[bigram] = count / self.words_uni[head]

    def bigram_nr(self):
        """Fill ``nr_bi``: for each count r, how many bigrams occur r times."""
        # Rebuilt from scratch so repeated calls do not double the tallies.
        self.nr_bi = collections.defaultdict(int)
        for count in self.words_bi.values():
            self.nr_bi[count] += 1

    def bigram_A(self):
        """Set ``A = (T + 1) * n_{T+1} / n_1`` from the count-of-counts.

        When no bigram occurs exactly once (``n_1 == 0``) the ratio is
        undefined; ``A`` is then set to ``0.0`` instead of raising
        ``ZeroDivisionError``.
        """
        singletons = self.nr_bi.get(1, 0)
        if not singletons:
            self.A = 0.0
            return
        self.A = (self.T + 1.) * self.nr_bi.get(self.T + 1, 0) / singletons

    def bigram_discount(self):
        """Compute discounted bigram probabilities and per-history sums.

        For a bigram with count ``r < T``::

            r*  = (r + 1) * n_{r+1} / n_r
            d_r = (r* / r - A) / (1 - A)
            P_discounted = d_r * P_ml

        Bigrams with ``r >= T`` keep their maximum-likelihood estimate.
        For every history word ``w1`` the method also accumulates the sum
        of discounted ``P(w2|w1)`` (``head_bi``) and the sum of unigram
        ``P(w2)`` (``tail_bi``) over the bigrams seen after ``w1``.
        """
        # NOTE(review): Katz's formulation is usually stated with r <= k;
        # this code uses the strict r < T together with A built from T + 1.
        # Kept as is - confirm against the referenced article.
        self.rhat_bi = collections.defaultdict(float)
        self.dr_bi = collections.defaultdict(float)
        self.ml_bi_discount = collections.defaultdict(float)
        self.head_bi = collections.defaultdict(float)
        self.tail_bi = collections.defaultdict(float)

        for bigram, count in self.words_bi.items():
            tokens = bigram.split()
            head, tail = tokens[0], tokens[1]
            ml_prob = self.r_bi.get(bigram, 0.0)
            if count < self.T:
                # .get() on n_{r+1} avoids inserting zero rows into nr_bi.
                rhat = ((count + 1) * self.nr_bi.get(count + 1, 0)
                        / self.nr_bi.get(count, 0))
                discount = (rhat / count - self.A) / (1 - self.A)
                self.rhat_bi[bigram] = rhat
                self.dr_bi[bigram] = discount
                self.ml_bi_discount[bigram] = ml_prob * discount
            else:
                self.ml_bi_discount[bigram] = ml_prob
            self.head_bi[head] += self.ml_bi_discount[bigram]
            self.tail_bi[head] += self.ml_uni.get(tail, 0.0)

    def bow(self):
        """Fill ``bow_dict`` with the back-off weight of each history word.

        ``bow(w) = (1 - head_bi[w]) / (1 - tail_bi[w])`` for every counted
        word except ``</s>``; ``<unk>`` gets weight ``0.0``.
        """
        self.bow_dict = collections.defaultdict(float)
        for word in self.words_uni:
            if word == "</s>":
                continue
            seen_bigram_mass = self.head_bi.get(word, 0.0)
            seen_unigram_mass = self.tail_bi.get(word, 0.0)
            self.bow_dict[word] = ((1 - seen_bigram_mass)
                                   / (1 - seen_unigram_mass))
        self.bow_dict["<unk>"] = 0.0

    def replace_nuk(self, input):
        """Return ``input`` with every uncounted word replaced by ``<unk>``.

        A word is kept only if ``words_uni`` holds a non-zero count for it.
        """
        return " ".join(
            word if self.words_uni.get(word) else "<unk>"
            for word in input.split()
        )

    def _sentence_log10(self, sentence):
        """Return the accumulated log10 score of one marked-up sentence.

        Per bigram ``w1 w2`` (after ``<unk>`` substitution):

        * ``<s> <unk>`` and ``<unk> <unk>`` contribute nothing;
        * ``<unk> </s>`` contributes ``log10 P(</s>)``;
        * a bigram with a non-zero discounted probability contributes its
          log10;
        * otherwise: nothing if ``w1`` is ``<unk>``; ``log10 P(w1)`` if
          ``w2`` is ``<unk>``; else ``log10 bow(w1) + log10 P(w2)``, each
          term added only when its value is non-zero.

        The result is a TensorFlow scalar once any term was added, and the
        Python float ``0.0`` if none was.
        """
        words = self.replace_nuk(sentence)
        p = 0.
        for gram in self.split_ngram(words, 2):
            bigram = gram.numpy().decode()
            if bigram in ("<s> <unk>", "<unk> <unk>"):
                continue
            if bigram == "<unk> </s>":
                p = p + self.log10(self.ml_uni.get("</s>", 0.0))
                continue

            discounted = self.ml_bi_discount.get(bigram)
            if discounted:
                p += self.log10(discounted)
                continue

            tokens = bigram.split()
            head, tail = tokens[0], tokens[1]
            if head == "<unk>":
                continue
            if tail == "<unk>":
                p = p + self.log10(self.ml_uni.get(head, 0.0))
                continue
            bow_value = self.bow_dict.get(head)
            if bow_value:
                p += self.log10(bow_value)
            uni_value = self.ml_uni.get(tail)
            if uni_value:
                p += self.log10(uni_value)
        return p

    def test(self):
        """Print ``sentence : log10 score`` for every test sentence."""
        for item in self.test_data:
            p = self._sentence_log10(item)
            # p is still a plain float when no term was added; only tensors
            # have .numpy().
            score = p.numpy() if hasattr(p, "numpy") else p
            print(item, ":", score)

# Build the model on the built-in example and score the test sentences.
# The stages depend on one another's tables, so the order below matters.
k = Katz()
for _stage in (
    k.example,
    k.unigram,
    k.unigram_ml,
    k.bigram,
    k.bigram_ml,
    k.bigram_nr,
    k.bigram_A,
    k.bigram_discount,
    k.bow,
    k.test,
):
    _stage()