Processing Long Sentences with an RNN
Process the text and build X and Y
import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
from tensorflow.contrib import layers, seq2seq
tf.set_random_seed(777)
sentence = ("if you want to build a ship, don't drum up people together to "
"collect wood and don't assign them tasks and work, but rather "
"teach them to long for the endless immentsity of the sea.")
char_set = list(set(sentence))                      # unique characters in the sentence
char_dict = {w: i for i, w in enumerate(char_set)}  # map each character to an integer id
hidden_size = len(char_set)    # size of each LSTM cell's output vector
num_classes = len(char_set)    # number of output classes (one per character)
sequence_length = 20           # length of each training window, in characters
learning_rate = 0.01           # learning rate
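# For illustration only (hypothetical toy example, not from the original post):
# if the sentence were "abcab", set(sentence) would contain {'a', 'b', 'c'} in an
# arbitrary order, so char_set might be ['c', 'a', 'b'] and char_dict
# {'c': 0, 'a': 1, 'b': 2}; hidden_size and num_classes would then both be 3.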
'''
Slide a window of 20 characters over the sentence to build the training pairs.
'''
dataX = []
dataY = []
for i in range(0, len(sentence) - sequence_length):
    # take the input window and the target window shifted right by one character
    x_str = sentence[i : i + sequence_length]
    y_str = sentence[i + 1 : i + sequence_length + 1]
    # look up the integer id of each character
    x = [char_dict[c] for c in x_str]
    y = [char_dict[c] for c in y_str]
    # store the index vectors
    dataX.append(x)
    dataY.append(y)
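To make the windowing concrete, here is the first (x, y) pair; this check can be run right after the loop, and the printed integer ids depend on the arbitrary ordering of char_set:
# Inspect the first training pair
print(sentence[0:sequence_length])      # "if you want to build"
print(sentence[1:sequence_length + 1])  # "f you want to build "
print(dataX[0])  # integer ids of the input window
print(dataY[0])  # the same window shifted one character to the right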
Placeholders; the actual data is fed in when the session runs.
X = tf.placeholder(tf.int32, [None, sequence_length])
Y = tf.placeholder(tf.int32, [None, sequence_length])
One-hot encoding
X_one_hot = tf.one_hot(X, num_classes)  # [batch_size, sequence_length, num_classes]
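For intuition, here is what tf.one_hot produces on a toy index tensor (a standalone sketch with a made-up depth of 4; the model itself uses num_classes):
# Standalone illustration of tf.one_hot (toy depth of 4, not the model's num_classes)
with tf.Session() as tmp_sess:
    print(tmp_sess.run(tf.one_hot([[0, 2]], depth=4)))
# [[[1. 0. 0. 0.]
#   [0. 0. 1. 0.]]]
# In the model, X_one_hot ends up with shape [batch_size, sequence_length, num_classes]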
Building the RNN
# LSTM cell: each cell's output vector has size hidden_size
def lstm_cell():
    cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
    return cell

# stack two LSTM cells into a multi-layer RNN
multi_cells = rnn.MultiRNNCell([lstm_cell() for _ in range(2)], state_is_tuple=True)
# outputs: [batch_size, sequence_length, hidden_size]
outputs, _ = tf.nn.dynamic_rnn(multi_cells, X_one_hot, dtype=tf.float32)
Add a fully connected layer on top of the RNN outputs to add depth and improve accuracy.
X_for_fc = tf.reshape(outputs, [-1, hidden_size])
outputs = layers.fully_connected(X_for_fc, num_classes, activation_fn=None)
# reshape back to [batch_size, sequence_length, num_classes] for the sequence loss
batch_size = len(dataX)
outputs = tf.reshape(outputs, [batch_size, sequence_length, num_classes])
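To keep the tensor shapes straight, here is a short recap that follows directly from the code above (nothing new is computed except the final print):
# Shape recap
# dynamic_rnn outputs     : [batch_size, sequence_length, hidden_size]
# X_for_fc                : [batch_size * sequence_length, hidden_size]
# after fully_connected   : [batch_size * sequence_length, num_classes]
# after the final reshape : [batch_size, sequence_length, num_classes]
print(outputs.get_shape())  # prints the concrete (batch_size, sequence_length, num_classes)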
Adding the fully connected layer improves the results considerably. What follows is the standard loss and optimization setup.
weights = tf.ones([batch_size, sequence_length])
sequence_loss = seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights)
mean_loss = tf.reduce_mean(sequence_loss)
train_op = tf.train.AdamOptimizer(learning_rate).minimize(mean_loss)
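# For reference: with its default arguments, seq2seq.sequence_loss is the per-position
# sparse softmax cross-entropy, scaled by the weights and averaged over every
# (batch, timestep) position (up to a tiny epsilon in the denominator). The two lines
# below are a hypothetical hand-written equivalent and are not used for training;
# crossent and manual_loss are illustrative names only.
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y, logits=outputs)
manual_loss = tf.reduce_sum(crossent * weights) / tf.reduce_sum(weights)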
# open a session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# training loop
for i in range(500):
    _, loss, results = sess.run(
        [train_op, mean_loss, outputs], feed_dict={X: dataX, Y: dataY}
    )
    for j, result in enumerate(results):
        index = np.argmax(result, axis=1)
        print(i, j, ''.join([char_set[t] for t in index]), loss)
Finally, print the prediction. Adjacent windows overlap, so the middle outputs repeat; stitch them back into a sentence by keeping the whole first window and then only the last character of each subsequent window (taking characters from the front of each window would work just as well).
results = sess.run(outputs, feed_dict={X : dataX})
for j, result in enumerate(results):
    index = np.argmax(result, axis=1)
    if j == 0:
        print(''.join([char_set[t] for t in index]), end='')
    else:
        print(char_set[index[-1]], end='')
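An equivalent way to stitch the prediction into a single string and compare it against the source sentence (a small variation on the loop above; predicted is an illustrative variable name, not part of the original code):
# Stitch the prediction into one string
predicted = ''
for j, result in enumerate(results):
    index = np.argmax(result, axis=1)
    if j == 0:
        predicted += ''.join(char_set[t] for t in index)
    else:
        predicted += char_set[index[-1]]
print(predicted)
# If the model has fully memorized the sentence, predicted equals sentence[1:]
print(predicted == sentence[1:])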
The prediction turns out well and largely matches the original sentence.