pytorch入门与实践学习笔记:chapter9 pytorch 实现CharRNN
程序员文章站
2024-03-25 11:15:52
...
使用RNN写汉语诗
数据库:
GitHub上收集的5万首左右唐诗原文(唐诗数据库),保存格式是.json和sqlite数据库
需要做的改动是:
将繁体转换成简体;
将所有的数据进行截断和补齐,以易于处理。
提供了一个numpy的压缩包tang.npz,里面包含三个对象:
- data:(57580,125)的numpy数组
- word2ix:每个字和它对应的序号,如“春”——>1000;
- ix2word:每个序号和它对应的字,如1000——>“春”。
数据加载
import numpy as np

# tang.npz stores word2ix / ix2word as pickled dict objects, so
# allow_pickle=True is required on numpy >= 1.16.3.
datas = np.load('tang.npz', allow_pickle=True)
data = datas['data']                  # (num_poems, 125) matrix of character indices
# ix2word lives in the npz archive (datas), not in the `data` array.
ix2word = datas['ix2word'].item()     # index -> character dict
poem = data[0]
poem_txt = [ix2word[ii] for ii in poem]
程序依赖
git+https://github.com/pytorch/tnt.git@master
fire
ipdb
torchvision
visdom
tqdm
主要文件组成:
main.py:包含程序配置、训练和生成;
model.py:模型定义;
utils.py:可视化工具封装;
data.py:数据预处理。
main.py
#coding:utf8
import sys,os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
class Config(object):
    '''Default paths and hyper-parameters; train()/gen() override attributes from their keyword arguments.'''

    # ---- data ----
    data_path = 'data/'                 # directory containing the raw poem JSON files
    pickle_path = 'tang.npz'            # cached, preprocessed dataset (numpy .npz)
    author = None                       # if set, keep only poems by this author
    constrain = None                    # if set, keep only poems whose lines have exactly this many characters
    category = 'poet.tang'              # JSON file-name prefix: 'poet.tang' (Tang) or 'poet.song' (Song)
    maxlen = 125                        # longer poems are cut at the end, shorter ones padded at the front

    # ---- optimisation ----
    lr = 1e-3
    weight_decay = 1e-4                 # NOTE(review): defined but not passed to the optimizer in train()
    use_gpu = True
    epoch = 20
    batch_size = 128

    # ---- monitoring ----
    plot_every = 20                     # visualise every 20 batches
    env = 'poetry'                      # visdom environment name
    debug_file = '/tmp/debugp'          # if this file exists, training enters ipdb at the next plot step

    # ---- generation / checkpoints ----
    max_gen_len = 200                   # maximum number of generation steps
    model_path = None                   # path of pre-trained weights to load
    prefix_words = '细雨鱼儿出,微风燕子斜。'  # not part of the output; primes the model to set the mood
    start_words = '闲云潭影日悠悠'          # opening characters of the poem
    acrostic = False                    # True -> acrostic poem (start_words become the line heads)
    model_prefix = 'checkpoints/tang'   # checkpoints are saved as <model_prefix>_<epoch>.pth


opt = Config()
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
    '''
    Continue a poem from the given opening characters.

    The model is first fed <START> and then, if given, every character of
    `prefix_words` (a mood hint that is not part of the output). The
    characters of `start_words` are then fed one by one; after that the most
    likely next character (argmax) is chosen greedily until <EOP> is produced
    or opt.max_gen_len steps have run.

    start_words: e.g. u'春江潮水连海平'
    prefix_words: optional priming text, excluded from the result
    returns: list of characters, beginning with the characters of start_words
    '''
    results = list(start_words)
    start_word_len = len(start_words)
    # The first input is always <START>.
    input = Variable(t.Tensor([word2ix['<START>']]).view(1, 1).long())
    if opt.use_gpu:
        input = input.cuda()
    hidden = None

    if prefix_words:
        # Only the hidden state matters while priming; outputs are discarded.
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)
        if i < start_word_len:
            # Still inside the user-supplied opening: force its next character.
            w = results[i]
            input = Variable(input.data.new([word2ix[w]])).view(1, 1)
        else:
            # int(...) yields a plain Python int. On PyTorch >= 0.4 this
            # indexing returns a 0-dim tensor, which is not a valid key
            # into ix2word.
            top_index = int(output.data[0].topk(1)[1][0])
            w = ix2word[top_index]
            results.append(w)
            input = Variable(input.data.new([top_index])).view(1, 1)
        if w == '<EOP>':
            # Drop the end-of-poem marker from the returned text.
            del results[-1]
            break
    return results
def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):
    '''
    Generate an acrostic poem: line k starts with the k-th character of start_words.

    Generation is greedy (argmax). Whenever the previous token ends a
    sentence (。, ！ or !, or the initial <START>), the next head character
    is forced in instead of the model's prediction. Generation stops when a
    sentence ends after every head character has been used, or after
    opt.max_gen_len steps.

    start_words: e.g. u'深度学习', which can produce:
        深木通中岳,青苔半日脂。
        度山分地险,逆浪到南巴。
        学道兵犹毒,当时燕不移。
        习根通古岸,开镜出清羸。
    prefix_words: optional priming text, excluded from the result
    returns: list of generated characters
    '''
    # Both full-width and ASCII exclamation marks end a sentence.
    sentence_ends = {u'。', u'！', u'!', '<START>'}
    results = []
    start_word_len = len(start_words)
    input = Variable(t.Tensor([word2ix['<START>']]).view(1, 1).long())
    if opt.use_gpu:
        input = input.cuda()
    hidden = None
    index = 0              # number of head characters used so far
    pre_word = '<START>'   # previously emitted token

    if prefix_words:
        # Prime the hidden state; outputs are discarded.
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)
        # int(...): on PyTorch >= 0.4 this indexing yields a 0-dim tensor,
        # which cannot be used as a key into ix2word.
        top_index = int(output.data[0].topk(1)[1][0])
        w = ix2word[top_index]
        if pre_word in sentence_ends:
            if index == start_word_len:
                # Every head character has been used and the last line is done.
                break
            # Start a new line with the next head character.
            w = start_words[index]
            index += 1
        input = Variable(input.data.new([word2ix[w]])).view(1, 1)
        results.append(w)
        pre_word = w
    return results
def train(**kwargs):
    '''
    Train PoetryModel on the preprocessed poems and visualise progress.

    Keyword arguments override attributes of the module-level `opt`
    (e.g. epoch=10, use_gpu=False). After every epoch the weights are saved
    to '<opt.model_prefix>_<epoch>.pth'.
    '''
    for k, v in kwargs.items():
        setattr(opt, k, v)

    vis = Visualizer(env=opt.env)

    # One row of character indices per poem, plus the vocabulary maps.
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=1)

    model = PoetryModel(len(word2ix), 128, 256)
    # NOTE(review): opt.weight_decay exists in Config but is not passed here.
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()
    if opt.model_path:
        # Load onto CPU first so GPU-saved checkpoints also load without CUDA.
        state_dict = t.load(opt.model_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
    if opt.use_gpu:
        model.cuda()
        criterion.cuda()

    # t.save does not create missing directories.
    checkpoint_dir = os.path.dirname(opt.model_prefix)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    loss_meter = meter.AverageValueMeter()
    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, data_ in tqdm.tqdm(enumerate(dataloader)):
            # (batch, seq_len) -> (seq_len, batch): the model is time-major.
            data_ = data_.long().transpose(1, 0).contiguous()
            if opt.use_gpu:
                data_ = data_.cuda()
            optimizer.zero_grad()
            # Next-character prediction: the target is the input shifted by one step.
            input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
            output, _ = model(input_)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()
            # .item() instead of .data[0]: the loss is a 0-dim tensor on PyTorch >= 0.4.
            loss_meter.add(loss.item())

            if (1 + ii) % opt.plot_every == 0:
                _plot_training_progress(vis, model, data_, loss_meter, ix2word, word2ix)

        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))


def _plot_training_progress(vis, model, batch, loss_meter, ix2word, word2ix):
    '''Plot the running loss, show up to 16 training poems and 8 poems sampled from the model.'''
    if os.path.exists(opt.debug_file):
        ipdb.set_trace()
    vis.plot('loss', loss_meter.value()[0])

    # batch is (seq_len, batch): each column is one poem. .tolist() gives
    # plain ints, which (unlike 0-dim tensors) work as ix2word keys.
    poetrys = [[ix2word[_word] for _word in batch[:, _iii].cpu().tolist()]
               for _iii in range(min(16, batch.size(1)))]
    vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')

    # Start one poem with each of these 8 characters.
    gen_poetries = [''.join(generate(model, word, ix2word, word2ix))
                    for word in list(u'春江花月夜凉如水')]
    vis.text('</br>'.join(gen_poetries), win=u'gen_poem')
def gen(**kwargs):
    '''
    Command-line interface: print a poem generated by a trained model.

    Keyword arguments override attributes of the module-level `opt`. The ones
    read here are model_path (required), use_gpu, start_words, prefix_words
    and acrostic (True -> gen_acrostic, otherwise generate).

    Raises ValueError when no model_path is given.
    '''
    for k, v in kwargs.items():
        setattr(opt, k, v)
    if not opt.model_path:
        raise ValueError('gen() needs a trained model: pass --model_path=<checkpoint.pth>')

    _, word2ix, ix2word = get_data(opt)
    model = PoetryModel(len(word2ix), 128, 256)
    # Keep every storage on CPU so a GPU-trained checkpoint loads anywhere.
    map_location = lambda s, l: s
    state_dict = t.load(opt.model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    if opt.use_gpu:
        model.cuda()

    start_words = _normalize_punctuation(_decode_cli_text(opt.start_words))
    prefix_words = (_normalize_punctuation(_decode_cli_text(opt.prefix_words))
                    if opt.prefix_words else None)

    gen_poetry = gen_acrostic if opt.acrostic else generate
    result = gen_poetry(model, start_words, ix2word, word2ix, prefix_words)
    print(''.join(result))


def _decode_cli_text(s):
    '''Return `s` as text, undoing byte / surrogate-escape artifacts of argv decoding.'''
    if isinstance(s, bytes):
        # Python 2 str (or explicit bytes): decode as UTF-8.
        return s.decode('utf8')
    try:
        # Python 3 under a non-UTF-8 locale keeps raw argv bytes as surrogate
        # escapes; re-encoding recovers the bytes for UTF-8 decoding.
        return s.encode('ascii', 'surrogateescape').decode('utf8')
    except (UnicodeEncodeError, LookupError):
        # Already proper text (e.g. the Config defaults, or a UTF-8 locale).
        return s


def _normalize_punctuation(s):
    '''Map ASCII , . ? to the full-width forms presumably used by the vocabulary.'''
    for half, full in ((u',', u'，'), (u'.', u'。'), (u'?', u'？')):
        s = s.replace(half, full)
    return s
if __name__ == '__main__':
    # Command-line entry point. fire.Fire() with no component presumably
    # exposes this module's top-level names (e.g. train, gen) as subcommands,
    # e.g. `python main.py gen --model_path=...` — see the python-fire docs.
    import fire
    fire.Fire()
model.py
#coding:utf8
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
class PoetryModel(nn.Module):
    '''
    Character-level language model: Embedding -> stacked LSTM -> Linear.

    Args:
        vocab_size: number of tokens (embedding rows and output logits).
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers (default 2, the original fixed value).
    '''
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=num_layers)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        '''
        input: token indices of shape (seq_len, batch_size) (time-major).
        hidden: optional (h_0, c_0), each (num_layers, batch_size, hidden_dim).
            When None, zeros are used.
        returns: (logits of shape (seq_len * batch_size, vocab_size), (h_n, c_n))
        '''
        seq_len, batch_size = input.size()
        if hidden is None:
            # Zero initial state, created on the same device as `input`.
            h_0 = input.data.new(self.num_layers, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(self.num_layers, batch_size, self.hidden_dim).fill_(0).float()
            h_0, c_0 = Variable(h_0), Variable(c_0)
        else:
            h_0, c_0 = hidden
        # (seq_len, batch_size, embedding_dim)
        embeds = self.embeddings(input)
        # output: (seq_len, batch_size, hidden_dim)
        output, hidden = self.lstm(embeds, (h_0, c_0))
        # Flatten time and batch so the result lines up with target.view(-1).
        output = self.linear1(output.view(seq_len * batch_size, -1))
        return output, hidden
data.py
#coding:utf-8
import sys
import os
import json
import re
import numpy as np
def _parseRawData(author = None, constrain = None,src = './chinese-poetry/json/simplified',category="poet.tang"):
'''
code from https://github.com/justdark/pytorch-poetry-gen/blob/master/dataHandler.py
处理json文件,返回诗歌内容
@param: author: 作者名字
@param: constrain: 长度限制
@param: src: json 文件存放路径
@param: category: 类别,有poet.song 和 poet.tang
返回 data:list
['床前明月光,疑是地上霜,举头望明月,低头思故乡。',
'一去二三里,烟村四五家,亭台六七座,八九十支花。',
.........
]
'''
def sentenceParse(para):
# para 形如 "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,
# 生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)
# 好是去塵俗,煙花長一欄。"
result, number = re.subn(u"(.*)", "", para)
result, number = re.subn(u"(.*)", "", para)
result, number = re.subn(u"{.*}", "", result)
result, number = re.subn(u"《.*》", "", result)
result, number = re.subn(u"《.*》", "", result)
result, number = re.subn(u"[\]\[]", "", result)
r = ""
for s in result:
if s not in set('0123456789-'):
r += s;
r, number = re.subn(u"。。", u"。", r)
return r
def handleJson(file):
# print file
rst = []
data = json.loads(open(file).read())
for poetry in data:
pdata = ""
if (author!=None and poetry.get("author")!=author):
continue
p = poetry.get("paragraphs")
flag = False
for s in p:
sp = re.split(u"[,!。]", s)
for tr in sp:
if constrain != None and len(tr) != constrain and len(tr)!=0:
flag = True
break
if flag:
break
if flag:
continue
for sentence in poetry.get("paragraphs"):
pdata += sentence
pdata = sentenceParse(pdata)
if pdata!="":
rst.append(pdata)
return rst
data = []
for filename in os.listdir(src):
if filename.startswith(category):
data.extend(handleJson(src+filename))
return data
def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """Pad and truncate a list of sequences into one rectangular numpy array.

    Adapted from keras. Every sequence is brought to length ``maxlen``
    (default: the length of the longest sequence).

    Arguments:
        sequences: list of sequences; every entry must support ``len``.
        maxlen: int, target length of every row.
        dtype: dtype of the returned array.
        padding: 'pre' or 'post' -- fill with ``value`` before or after
            each sequence.
        truncating: 'pre' or 'post' -- for sequences longer than ``maxlen``,
            drop the excess from the beginning or from the end.
        value: fill value for padded positions.

    Returns:
        numpy array of shape ``(len(sequences), maxlen) + element_shape``.

    Raises:
        ValueError: if ``sequences`` or one of its entries has no length;
            if ``truncating``/``padding`` is not 'pre'/'post' (checked only
            when a non-empty sequence is processed); or if an entry's
            element shape differs from that of the first non-empty entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')

    def _checked_len(seq):
        # Validate that the entry is sized, then report its length.
        if not hasattr(seq, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(seq))
        return len(seq)

    lengths = [_checked_len(seq) for seq in sequences]
    if maxlen is None:
        maxlen = np.max(lengths)

    # Per-element (trailing) shape comes from the first non-empty sequence;
    # every other sequence is checked against it in the loop below.
    sample_shape = next(
        (np.asarray(seq).shape[1:] for seq in sequences if len(seq) > 0),
        tuple())

    grid_shape = (len(sequences), maxlen) + sample_shape
    result = (np.ones(grid_shape) * value).astype(dtype)

    for row, seq in enumerate(sequences):
        if not len(seq):
            continue  # empty entry: the row stays all padding

        if truncating == 'pre':
            kept = seq[-maxlen:]   # keep the last maxlen items
        elif truncating == 'post':
            kept = seq[:maxlen]    # keep the first maxlen items
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        kept = np.asarray(kept, dtype=dtype)
        if kept.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (kept.shape[1:], row, sample_shape))

        if padding == 'post':
            result[row, :len(kept)] = kept
        elif padding == 'pre':
            result[row, -len(kept):] = kept
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return result
def get_data(opt):
    '''
    Load the poem dataset, building and caching it on first use.

    If opt.pickle_path exists it is loaded directly. Otherwise the raw JSON
    poems are parsed with _parseRawData(opt.author, opt.constrain,
    opt.data_path, opt.category). Each poem is then wrapped in
    <START> ... <EOP>, mapped to indices, padded/truncated to opt.maxlen,
    and saved to opt.pickle_path.

    @param opt: Config-like object
    @return data: numpy array, one row of character indices per poem
    @return word2ix: dict, character -> index, e.g. u'月' -> 100
    @return ix2word: dict, index -> character, e.g. 100 -> u'月'
    '''
    if os.path.exists(opt.pickle_path):
        # word2ix / ix2word are stored as pickled 0-d object arrays, so
        # allow_pickle=True is required on numpy >= 1.16.3. Only load caches
        # produced by this function; never point pickle_path at untrusted files.
        with np.load(opt.pickle_path, allow_pickle=True) as cached:
            return cached['data'], cached['word2ix'].item(), cached['ix2word'].item()

    # No cache yet: process the raw JSON files.
    data = _parseRawData(opt.author, opt.constrain, opt.data_path, opt.category)
    # Sorted so the same corpus always yields the same indices
    # (set iteration order is arbitrary).
    words = sorted({_word for _sentence in data for _word in _sentence})
    word2ix = {_word: _ix for _ix, _word in enumerate(words)}
    word2ix['<EOP>'] = len(word2ix)    # end-of-poem marker
    word2ix['<START>'] = len(word2ix)  # start-of-poem marker
    word2ix['</s>'] = len(word2ix)     # padding token
    ix2word = {_ix: _word for _word, _ix in word2ix.items()}

    # Wrap each poem in start/end markers and map characters to indices,
    # e.g. [春, 江, 花, 月, 夜] -> [1, 2, 3, 4, 5].
    new_data = [[word2ix[_word] for _word in ['<START>'] + list(_sentence) + ['<EOP>']]
                for _sentence in data]

    # Poems shorter than opt.maxlen are padded with '</s>' at the front;
    # longer ones lose their tail.
    pad_data = pad_sequences(new_data,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=word2ix['</s>'])

    np.savez_compressed(opt.pickle_path,
                        data=pad_data,
                        word2ix=word2ix,
                        ix2word=ix2word)
    return pad_data, word2ix, ix2word