欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

基于LSTM的文本生成

程序员文章站 2024-03-24 23:57:34
...

      这是鄙人的毕设题目,最近由于总是在迷茫与纠结中度过,考虑是深造还是直接工作,焦虑到快自刎了,昨天猛然醒悟:是该做点事情来填补这种没营养的空虚感了。

      先把寒假前+寒假中所做的事情总结一下。

      首先是RNN与LSTM的区别:

          1.我在接触这两种神经网络的时候首先的印象就是,RNN采用传统的backpropagation+梯度下降法对参数进行学习,第t层的误差函数跟ot直接相关,而ot依赖于前面每一层的xi和si,i小于等于t,因此RNN会出现梯度消失的情况。而LSTM也属于一种改良的RNN,但它不是强行把依赖链截断,而是采用了一种更巧妙的设计来绕开了梯度消失或梯度爆炸的问题

          2.传统RNN每一步的隐藏单元只是执行一个简单的tanh或ReLU操作。 LSTM每个循环的模块内又有4层结构:3个sigmoid层,1个tanh层。

 (我还没搞清楚怎么画图插进来,所以有点干。。。。)

 

-------------------------------------------------------------------------------华丽的分界线---------------------------------------------------------------------------

接下来上代码:

模型的训练:

def generate_sequences_from_texts(texts, indices_list,
                                  textgenrnn, context_labels,
                                  batch_size=128):
    is_words = textgenrnn.config['word_level']
    is_single = textgenrnn.config['single_text']
    max_length = textgenrnn.config['max_length']
    meta_token = textgenrnn.META_TOKEN

    if is_words:
        new_tokenizer = Tokenizer(filters='', char_level=True)
        new_tokenizer.word_index = textgenrnn.vocab
    else:
        new_tokenizer = textgenrnn.tokenizer

    while True:
        np.random.shuffle(indices_list)

        X_batch = []
        Y_batch = []
        context_batch = []
        count_batch = 0

        for row in range(indices_list.shape[0]):
            text_index = indices_list[row, 0]
            end_index = indices_list[row, 1]

            text = texts[text_index]

            if not is_single:
                text = [meta_token] + list(text) + [meta_token]

            if end_index > max_length:
                x = text[end_index - max_length: end_index + 1]
            else:
                x = text[0: end_index + 1]
            y = text[end_index + 1]

            if y in textgenrnn.vocab:
                x = process_sequence([x], textgenrnn, new_tokenizer)
                y = textgenrnn_encode_cat([y], textgenrnn.vocab)

                X_batch.append(x)
                Y_batch.append(y)

                if context_labels is not None:
                    context_batch.append(context_labels[text_index])

                count_batch += 1

                if count_batch % batch_size == 0:
                    X_batch = np.squeeze(np.array(X_batch))
                    Y_batch = np.squeeze(np.array(Y_batch))
                    context_batch = np.squeeze(np.array(context_batch))

                    # print(X_batch.shape)

                    if context_labels is not None:
                        yield ([X_batch, context_batch], [Y_batch, Y_batch])
                    else:
                        yield (X_batch, Y_batch)
                    X_batch = []
                    Y_batch = []
                    context_batch = []
                    count_batch = 0

后续继续,今天开始重新做人,开始营业博客。。。

相关标签: LSTM