Study notes on pytorch入门与实践 (PyTorch: Introduction and Practice), Chapter 9: implementing a CharRNN in PyTorch

Writing Chinese poetry with an RNN

Dataset:

Roughly 50k original Tang poems collected from a Tang-poetry corpus on GitHub, stored both as .json files and as a sqlite database.

The required preprocessing steps are:

convert the traditional characters to simplified characters;

truncate and pad all poems to a fixed length so they are easier to process (see the sketch below).
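
A minimal sketch of these two steps (illustrative only, since the published tang.npz already has them applied; the opencc package and its OpenCC('t2s') converter are an assumption here, not something the original specifies):

from opencc import OpenCC            # assumption: the opencc-python package is installed

cc = OpenCC('t2s')                   # traditional -> simplified converter

def normalize(poem, maxlen=125, pad='</s>'):
    poem = cc.convert(poem)                            # 1. convert traditional to simplified
    poem = poem[:maxlen]                               # 2. truncate overly long poems ...
    return [pad] * (maxlen - len(poem)) + list(poem)   # ... and pre-pad short ones

print(''.join(normalize(u'床前明月光,疑是地上霜。', maxlen=16)))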

A preprocessed numpy archive, tang.npz, is provided; it contains three objects:

  • data: a numpy array of shape (57580, 125);
  • word2ix: maps each character to its index, e.g. “春” -> 1000;
  • ix2word: maps each index back to its character, e.g. 1000 -> “春”.

Loading the data

import numpy as np

datas = np.load('tang.npz')
data = datas['data']
ix2word = datas['ix2word'].item()

poem = data[0]
poem_txt = [ix2word[ii] for ii in poem]
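
Because the rows are pre-padded, poem_txt starts with a run of '</s>' padding tokens. A small continuation of the snippet above (note that newer numpy versions additionally require allow_pickle=True in np.load to restore the word2ix/ix2word objects):

word2ix = datas['word2ix'].item()

# drop the leading '</s>' padding before printing
print(''.join(c for c in poem_txt if c != '</s>'))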

Dependencies

git+https://github.com/pytorch/tnt.git@master
fire
ipdb
torchvision
visdom
tqdm
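
Saved as a requirements.txt file, these can be installed in one step (a CUDA-enabled PyTorch build is additionally required when use_gpu is set):

pip install -r requirements.txt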

The code is organized into the following files:

main.py: program configuration, training, and generation;

model.py: model definition;

utils.py: a thin wrapper around the visdom visualization tools;

data.py: data preprocessing.

main.py

#coding:utf8
import sys,os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb

class Config(object):
    data_path = 'data/' # directory holding the raw poetry json files
    pickle_path = 'tang.npz' # preprocessed binary file
    author = None # only learn poems by this author
    constrain = None # sentence length constraint
    category = 'poet.tang' # category: Tang poetry, or Song poetry (poet.song)
    lr = 1e-3
    weight_decay = 1e-4
    use_gpu = True
    epoch = 20
    batch_size = 128
    maxlen = 125 # characters beyond this length are dropped; shorter poems are pre-padded
    plot_every = 20 # visualize once every 20 batches
    # use_env = True # whether to use visdom
    env='poetry' # visdom env
    max_gen_len = 200 # maximum length of a generated poem
    debug_file='/tmp/debugp'
    model_path=None # path to a pretrained model
    prefix_words = '细雨鱼儿出,微风燕子斜。' # not part of the poem itself; only used to set the mood of the generated poem
    start_words='闲云潭影日悠悠' # opening words of the poem
    acrostic = False # whether to generate an acrostic poem
    model_prefix = 'checkpoints/tang' # prefix for saved model checkpoints

opt = Config()

def generate(model,start_words,ix2word,word2ix,prefix_words=None):
    '''
    Given a few opening characters, continue them into a complete poem.
    start_words: e.g. u'春江潮水连海平'
    '''
    results = list(start_words)
    start_word_len = len(start_words)
    # manually set the first input token to <START>
    input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
    if opt.use_gpu:input=input.cuda()
    hidden = None
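    # prefix_words is only used to "warm up" the hidden state so that the generated
    # poem picks up its mood and style; the prefix characters themselves are not emitted.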

    if prefix_words:
        for word in prefix_words:
            output,hidden = model(input,hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1,1)
    
    for i in range(opt.max_gen_len):
        output,hidden = model(input,hidden)
  
        if i<start_word_len:
            w = results[i]
            input = Variable(input.data.new([word2ix[w]])).view(1,1)      
        else:
            top_index  = output.data[0].topk(1)[1][0]
            w = ix2word[top_index]  
            results.append(w)
            input = Variable(input.data.new([top_index])).view(1,1)
        if w=='<EOP>':
            del results[-1]
            break     
    return results

def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
    '''
    Generate an acrostic poem: each line starts with one character of start_words.
    start_words: e.g. u'深度学习', which may generate:
    深木通中岳,青苔半日脂。
    度山分地险,逆浪到南巴。
    学道兵犹毒,当时燕不移。
    习根通古岸,开镜出清羸。
    '''
    results = []
    start_word_len = len(start_words)
    input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
    if opt.use_gpu:input=input.cuda()
    hidden = None
    
    index=0 # counts how many of the acrostic characters have been used
    # previous character
    pre_word='<START>'

    if prefix_words:
        for word in prefix_words:
            output,hidden = model(input,hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1,1)

    for i in range(opt.max_gen_len):
        output,hidden = model(input,hidden)
        top_index  = output.data[0].topk(1)[1][0]
        w = ix2word[top_index]

        if (pre_word  in {u'。',u'!','<START>'} ):
            # at the start of a new line (after '。', '!' or <START>), feed the next acrostic character

            if index==start_word_len:
                # stop once the poem contains all of the acrostic characters
                break
            else:  
                # feed the acrostic character into the model
                w = start_words[index]
                index+=1
                input = Variable(input.data.new([word2ix[w]])).view(1,1)    
        else:
            # otherwise, feed the previously predicted character as the next input
            input = Variable(input.data.new([word2ix[w]])).view(1,1)
        results.append(w)
        pre_word = w
    return results


def train(**kwargs):

    for k,v in kwargs.items():
        setattr(opt,k,v)

    vis = Visualizer(env=opt.env)

    # load the data
    data,word2ix,ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data,
                    batch_size=opt.batch_size,
                    shuffle=True,
                    num_workers=1)

    # define the model
    model = PoetryModel(len(word2ix), 128, 256)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()
    
    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))

    if opt.use_gpu:
        model.cuda()
        criterion.cuda()
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii,data_ in tqdm.tqdm(enumerate(dataloader)):    

            # training step
            data_ = data_.long().transpose(1,0).contiguous()
            if opt.use_gpu: data_ = data_.cuda()      
            optimizer.zero_grad()
            input_,target = Variable(data_[:-1,:]),Variable(data_[1:,:])
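            # input_ and target are the same sequence shifted by one step: for a row
            # "<START> 床 前 明 月 光 <EOP>", the input is "<START> 床 前 明 月 光" and
            # the target is "床 前 明 月 光 <EOP>", so each position learns to predict
            # the next character.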
            output,_  = model(input_)
            loss = criterion(output,target.view(-1))
            loss.backward()
            optimizer.step()
        
            loss_meter.add(loss.data[0])

            # visualization
            if (1+ii)%opt.plot_every==0:
            
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss',loss_meter.value()[0])
                
                # original poems in this batch
                poetrys=[ [ix2word[_word] for _word in data_[:,_iii]] 
                                    for _iii in range(data_.size(1))][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]),win=u'origin_poem')
                
                gen_poetries = []
                # generate eight poems, one starting with each of these characters
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry =  ''.join(generate(model,word,ix2word,word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]),win=u'gen_poem')  
        
        t.save(model.state_dict(),'%s_%s.pth' %(opt.model_prefix,epoch))

def gen(**kwargs):
    '''
    Command-line interface for generating a poem.
    '''

    for k,v in kwargs.items():
        setattr(opt,k,v)

    data,word2ix,ix2word = get_data(opt)
    model = PoetryModel(len(word2ix), 128, 256);
    map_location = lambda s,l:s
    state_dict = t.load(opt.model_path,map_location=map_location)
    model.load_state_dict(state_dict)
    
    if opt.use_gpu:
        model.cuda()
    if sys.version_info.major == 3:
        start_words = opt.start_words.encode('ascii', 'surrogateescape').decode('utf8')
        prefix_words = opt.prefix_words.encode('ascii', 'surrogateescape').decode('utf8') if opt.prefix_words else None
    else:
        start_words = opt.start_words.decode('utf8')
        prefix_words = opt.prefix_words.decode('utf8') if opt.prefix_words else None
     
     
    start_words= start_words.replace(',',u',')\
                                     .replace('.',u'。')\
                                     .replace('?',u'?')

    gen_poetry = gen_acrostic if opt.acrostic else generate
    result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
    print(''.join(result))

if __name__ == '__main__':
    import fire
    fire.Fire()
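
fire.Fire() exposes train() and gen() as sub-commands whose keyword arguments become command-line flags. Assuming fire's default flag mapping (the exact spellings here are illustrative, not taken from the original), training could be launched as python main.py train --use-gpu=True --epoch=20, and generation as python main.py gen --model-path='checkpoints/tang_19.pth' --start-words='深度学习' --acrostic=True.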

model.py

#coding:utf8
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim,num_layers=2)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input,hidden=None):
        seq_len,batch_size = input.size()
        if hidden is None:
            #  h_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda()
            #  c_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda()
            h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            h_0,c_0 = Variable(h_0),Variable(c_0)
        else:
            h_0,c_0 = hidden
        # size: (seq_len,batch_size,embedding_dim)
        embeds = self.embeddings(input)
        # output size: (seq_len,batch_size,hidden_dim)
        output, hidden = self.lstm(embeds, (h_0,c_0))

        # size: (seq_len*batch_size,vocab_size)
        output = self.linear1(output.view(seq_len*batch_size, -1))
        return output,hidden
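
A quick shape check for PoetryModel (a minimal sketch, not from the original; the vocabulary size 8293 is just an illustrative value):

import torch
from torch.autograd import Variable
from model import PoetryModel

model = PoetryModel(vocab_size=8293, embedding_dim=128, hidden_dim=256)
input = Variable(torch.zeros(125, 4).long())   # (seq_len, batch_size) of character indices
output, (h_n, c_n) = model(input)
print(output.size())   # (seq_len * batch_size, vocab_size) -> (500, 8293)
print(h_n.size())      # (num_layers, batch_size, hidden_dim) -> (2, 4, 256)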

data.py

#coding:utf-8
import sys
import os
import json
import re
import numpy as np

def _parseRawData(author = None, constrain = None,src = './chinese-poetry/json/simplified',category="poet.tang"):
    '''
    code from https://github.com/justdark/pytorch-poetry-gen/blob/master/dataHandler.py
    Parse the raw json files and return the poem texts.
    @param author: author name; if given, only keep poems by this author
    @param constrain: sentence length constraint
    @param src: directory containing the json files
    @param category: "poet.tang" for Tang poetry or "poet.song" for Song poetry

    Returns data: a list of poem strings, e.g.
        ['床前明月光,疑是地上霜,举头望明月,低头思故乡。',
         '一去二三里,烟村四五家,亭台六七座,八九十支花。',
        .........
        ]
    '''

    def sentenceParse(para):
        # para looks like "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,
        # 生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)
        # 好是去塵俗,煙花長一欄。"
        # strip editorial annotations in full-width/half-width parentheses, braces and
        # title marks, then drop stray digits and dashes
        result, number = re.subn(u"(.*)", "", para)
        result, number = re.subn(u"\(.*\)", "", result)
        result, number = re.subn(u"{.*}", "", result)
        result, number = re.subn(u"《.*》", "", result)
        result, number = re.subn(u"[\]\[]", "", result)
        r = ""
        for s in result:
            if s not in set('0123456789-'):
                r += s
        r, number = re.subn(u"。。", u"。", r)
        return r
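    # On the example above, sentenceParse strips the "-181-" prefix and the editorial
    # note in full-width parentheses, leaving only the poem text:
    # "村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,生摘琵琶酸。好是去塵俗,煙花長一欄。"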

    def handleJson(file):
        # print file
        rst = []
        data = json.loads(open(file).read())
        for poetry in data:
            pdata = ""
            if (author!=None and poetry.get("author")!=author):
                continue
            p = poetry.get("paragraphs")
            flag = False
            for s in p:
                sp = re.split(u"[,!。]", s)
                for tr in sp:
                    if constrain != None and len(tr) != constrain and len(tr)!=0:
                        flag = True
                        break
                if flag:
                    break
            if flag:
                continue
            for sentence in poetry.get("paragraphs"):
                pdata += sentence
            pdata = sentenceParse(pdata)
            if pdata!="":
                rst.append(pdata)
        return rst
    
    data = []
    for filename in os.listdir(src):
        if filename.startswith(category):
            data.extend(handleJson(os.path.join(src, filename)))
    return data


def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
  """
  code from keras
  Pads each sequence to the same length (length of the longest sequence).
  If maxlen is provided, any sequence longer
  than maxlen is truncated to maxlen.
  Truncation happens off either the beginning (default) or
  the end of the sequence.
  Supports post-padding and pre-padding (default).
  Arguments:
      sequences: list of lists where each element is a sequence
      maxlen: int, maximum length
      dtype: type to cast the resulting sequence.
      padding: 'pre' or 'post', pad either before or after each sequence.
      truncating: 'pre' or 'post', remove values from sequences larger than
          maxlen either in the beginning or in the end of the sequence
      value: float, value to pad the sequences to the desired value.
  Returns:
      x: numpy array with dimensions (number_of_sequences, maxlen)
  Raises:
      ValueError: in case of invalid values for `truncating` or `padding`,
          or in case of invalid shape for a `sequences` entry.
  """
  if not hasattr(sequences, '__len__'):
    raise ValueError('`sequences` must be iterable.')
  lengths = []
  for x in sequences:
    if not hasattr(x, '__len__'):
      raise ValueError('`sequences` must be a list of iterables. '
                       'Found non-iterable: ' + str(x))
    lengths.append(len(x))

  num_samples = len(sequences)
  if maxlen is None:
    maxlen = np.max(lengths)

  # take the sample shape from the first non empty sequence
  # checking for consistency in the main loop below.
  sample_shape = tuple()
  for s in sequences:
    if len(s) > 0:  # pylint: disable=g-explicit-length-test
      sample_shape = np.asarray(s).shape[1:]
      break

  x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
  for idx, s in enumerate(sequences):
    if not len(s):  # pylint: disable=g-explicit-length-test
      continue  # empty list/array was found
    if truncating == 'pre':
      trunc = s[-maxlen:]  # pylint: disable=invalid-unary-operand-type
    elif truncating == 'post':
      trunc = s[:maxlen]
    else:
      raise ValueError('Truncating type "%s" not understood' % truncating)

    # check `trunc` has expected shape
    trunc = np.asarray(trunc, dtype=dtype)
    if trunc.shape[1:] != sample_shape:
      raise ValueError(
          'Shape of sample %s of sequence at position %s is different from '
          'expected shape %s'
          % (trunc.shape[1:], idx, sample_shape))

    if padding == 'post':
      x[idx, :len(trunc)] = trunc
    elif padding == 'pre':
      x[idx, -len(trunc):] = trunc
    else:
      raise ValueError('Padding type "%s" not understood' % padding)
  return x
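
# Example of how get_data() below uses pad_sequences (illustrative values):
#   pad_sequences([[1, 2, 3], [4, 5]], maxlen=4, padding='pre', truncating='post', value=0)
#   -> array([[0, 1, 2, 3],
#             [0, 0, 4, 5]], dtype=int32)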


def get_data(opt):
    '''
    @param opt: the Config object holding the options
    @return word2ix: dict mapping each character to its index, e.g. u'月' -> 100
    @return ix2word: dict mapping each index back to its character, e.g. 100 -> u'月'
    @return data: numpy array; each row holds the character indices of one poem
    '''
    if os.path.exists(opt.pickle_path):
        data = np.load(opt.pickle_path)
        data,word2ix,ix2word  = data['data'],data['word2ix'].item(),data['ix2word'].item()
        return data,word2ix,ix2word

    # no preprocessed binary file found, so process the raw json files
    data = _parseRawData(opt.author,opt.constrain,opt.data_path,opt.category)
    words = {_word for _sentence in data for  _word in _sentence}
    word2ix = {_word:_ix for _ix,_word in enumerate(words)}
    word2ix['<EOP>'] = len(word2ix) # end-of-poem token
    word2ix['<START>'] = len(word2ix) # start-of-poem token
    word2ix['</s>'] = len(word2ix) # padding token
    ix2word = {_ix:_word for _word,_ix in list(word2ix.items())}
    
    # add a start token and an end token to every poem
    for i in range(len(data)):
        data[i] = ["<START>"]+list(data[i]) + ["<EOP>"]
    
    # convert each poem from characters to indices,
    # e.g. [春,江,花,月,夜] becomes [1,2,3,4,5]
    new_data = [ [word2ix[_word] for _word in _sentence]
                                  for _sentence in data]

    # poems shorter than opt.maxlen are pre-padded; longer ones are truncated at the end
    pad_data = pad_sequences(new_data,\
                             maxlen=opt.maxlen,\
                             padding='pre',
                             truncating='post',
                             value=len(word2ix)-1)

    # save the result as a binary file
    np.savez_compressed(opt.pickle_path,
                            data=pad_data,
                            word2ix=word2ix,
                            ix2word=ix2word)                        
    return pad_data,word2ix,ix2word

 
