tensorflow教程:tf.dynamic_rnn
引言
TensorFlow很容易上手,但是TensorFlow的很多trick却是提升TensorFlow心法的法门,之前说过TensorFlow的read心法,现在想说一说TensorFlow在RNN上的心法,简直好用到哭 【以下实验均是基于TensorFlow1.0】
简要介绍tensorflow的RNN
其实在前面多篇都已经提到了TensorFlow的RNN,也在我之前的文章TensorFlow实现文本分类文章中用到了BasicLSTM的方法,通常的,使用RNN的时候,我们需要指定num_step,也就是TensorFlow的roll step步数,但是对于变长的文本来说,指定num_step就不可避免的需要进行padding操作,在之前的文章TensorFlow高阶读写教程也使用了dynamic_padding方法实现自动padding,但是这还不够,因为在跑一遍RNN/LSTM之后,还是需要对padding部分的内容进行删除,我称之为“反padding”,无可避免的,我们就需要指定mask矩阵了,这就有点不优雅,但是TensorFlow提供了一个很优雅的解决方法,让mask去见马克思去了,那就是dynamic_rnn
tf.dynamic_rnn
tensorflow 的dynamic_rnn方法,我们用一个小例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,last_states,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,last_states是由(c,h)组成的tuple,均为[batch,128]。
到这里并没有什么不同,但是dynamic有个参数:sequence_length,这个参数用来指定每个example的长度,比如上面的例子中,我们令 sequence_length为[20,13],表示第一个example有效长度为20,第二个example有效长度为13,当我们传入这个参数的时候,对于第二个example,TensorFlow对于13以后的padding就不计算了,其last_states将重复第13步的last_states直至第20步,而outputs中超过13步的结果将会被置零。
dynamic_rnn例子
#coding=utf-8
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)
# 第二个example长度为6
X[1,6:] = 0
X_lengths = [10, 6]
cell = tf.contrib.rnn.BasicLSTMCell(num_units=64, state_is_tuple=True)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
result = tf.contrib.learn.run_n(
{"outputs": outputs, "last_states": last_states},
n=1,
feed_dict=None)
print result[0]
assert result[0]["outputs"].shape == (2, 10, 64)
# 第二个example中的outputs超过6步(7-10步)的值应该为0
assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()
我们看输出:
{'outputs': array([[[ 0.02343191, 0.05894056, 0.01552576, ..., -0.06954119,
-0.02693178, -0.02773715],
[-0.01897412, 0.00430241, 0.05111675, ..., -0.12161507,
0.00998021, -0.0282588 ],
[-0.01222279, -0.00742003, 0.1395104 , ..., 0.06212089,
0.05438172, -0.10756982],
...,
[ 0.04471944, 0.03058323, -0.08105398, ..., -0.08458089,
-0.00789265, 0.00711049],
[ 0.07910491, -0.0015225 , -0.08136954, ..., -0.03702021,
-0.02530194, 0.07729477],
[ 0.06114135, 0.0263763 , 0.0153004 , ..., -0.07590827,
-0.00899063, -0.031571 ]],
[[ 0.04057412, 0.0379415 , 0.01818413, ..., 0.00513165,
0.09185232, -0.16915748],
[ 0.08922272, 0.04556143, -0.06847201, ..., -0.03329186,
0.07859877, -0.22903247],
[ 0.04083256, -0.0191676 , -0.00690892, ..., -0.00552511,
0.07809589, -0.16655875],
...,
[ 0. , 0. , 0. , ..., 0. ,
0. , 0. ],
[ 0. , 0. , 0. , ..., 0. ,
0. , 0. ],
[ 0. , 0. , 0. , ..., 0. ,
0. , 0. ]]]), 'last_states': LSTMStateTuple(c=array([[ 1.17486513e-01, 4.53374791e-02, 3.27930624e-02,
1.88688948e-01, -9.18940578e-02, 1.10607361e-01,
7.69938294e-02, 1.02080487e-01, 2.35188842e-01,
-6.99273490e-02, 1.98158514e-01, -2.66004847e-02,
-2.00984914e-01, -1.22899439e-01, -9.09574947e-03,
1.25963024e-01, 8.78420353e-02, -4.48895848e-02,
1.41703260e-02, 7.78878760e-03, -3.56721497e-02,
-1.02126920e-01, -9.31018826e-02, -1.18749056e-01,
-2.15687558e-02, -6.48136325e-02, -6.67117612e-02,
2.06457878e-01, 1.05809077e-01, 3.25519072e-02,
6.68543364e-02, -1.25674027e-01, 1.65443839e-01,
-8.19379933e-02, -2.68197695e-02, -1.26924280e-01,
9.66936841e-02, 2.45289838e-02, -3.15856903e-02,
-9.30471642e-02, 2.28047923e-02, 1.64577723e-01,
-2.13811172e-02, 2.31624708e-01, -5.05328136e-02,
-2.15352598e-01, 1.17756556e-01, 1.24231633e-01,
2.17948294e-01, -1.88141852e-01, 5.56704829e-02,
1.85995614e-04, -1.63170139e-02, 4.14733115e-02,
-1.42410828e-01, -2.10698220e-02, 1.13032204e-01,
1.16487820e-01, 1.14937607e-01, 1.15206014e-01,
9.07994735e-02, -1.47575747e-01, -1.67919061e-02,
-5.57344372e-02],
[ -1.87032883e-01, -4.50730933e-02, 1.65264860e-01,
-1.57064693e-01, -1.02704183e-01, -1.42700035e-01,
-1.82858618e-01, -5.69656656e-02, -3.19701571e-01,
-9.45731981e-04, -8.96991629e-02, 6.37877888e-02,
-7.24395155e-02, 2.24324167e-01, -2.26432828e-01,
-2.12203247e-02, -9.89278157e-02, -1.79787292e-01,
1.17519710e-01, -2.43337123e-01, 6.08713955e-02,
3.71411367e-01, 3.96845821e-02, -1.34371544e-01,
-1.54702491e-01, -1.80343050e-02, 7.06988306e-02,
-1.58112671e-01, -1.74782878e-01, 1.24460790e-01,
-2.01408352e-02, -2.19578859e-01, -1.09101701e-01,
-3.36411660e-02, -4.12966791e-02, -2.62211522e-01,
6.09266090e-02, 5.15926436e-02, 1.31553677e-01,
3.85248320e-02, 6.82502698e-02, 3.20785503e-01,
6.02489641e-02, 1.03486249e-02, -1.98853998e-01,
2.42482932e-01, -3.03208095e-03, 3.26806427e-02,
1.43904791e-01, 4.83002308e-02, 1.06806422e-01,
2.19021559e-01, -1.04280654e-01, 7.02105858e-02,
-1.08238911e-01, 5.31858915e-02, -1.30427149e-01,
-3.14307444e-02, 2.60903800e-02, -3.49547176e-03,
3.15445855e-02, 1.26248331e-01, 2.98049766e-01,
-1.35553357e-01]]), h=array([[ 6.11413522e-02, 2.63763025e-02, 1.53004046e-02,
1.00835659e-01, -4.07618767e-02, 6.39206416e-02,
4.17340362e-02, 5.10448527e-02, 9.37222463e-02,
-3.43376107e-02, 1.00684542e-01, -1.28972917e-02,
-1.20061738e-01, -6.48411970e-02, -4.66407837e-03,
6.29309198e-02, 4.64027731e-02, -1.80123985e-02,
7.18521681e-03, 4.55297690e-03, -1.95851481e-02,
-4.94828658e-02, -4.56579935e-02, -5.68909598e-02,
-1.03985798e-02, -2.80805943e-02, -3.67050137e-02,
1.11822759e-01, 4.82685695e-02, 1.51483196e-02,
3.61371426e-02, -4.92942874e-02, 8.74024618e-02,
-3.75624886e-02, -1.54172618e-02, -6.26848414e-02,
3.92306304e-02, 1.08791341e-02, -1.76010076e-02,
-4.68257540e-02, 1.11274774e-02, 7.26592349e-02,
-1.10059670e-02, 1.25391653e-01, -2.45894375e-02,
-1.10484543e-01, 5.64758454e-02, 6.85158790e-02,
1.05166465e-01, -9.38722289e-02, 2.87157035e-02,
9.68917170e-05, -7.59567519e-03, 2.00130197e-02,
-5.71313903e-02, -1.06302802e-02, 6.53980752e-02,
5.53559936e-02, 5.63571469e-02, 5.87699760e-02,
4.93030711e-02, -7.59082740e-02, -8.99063316e-03,
-3.15710039e-02],
[ -8.75580540e-02, -2.40814362e-02, 7.62920499e-02,
-7.99111282e-02, -5.25187098e-02, -6.82907819e-02,
-9.22920867e-02, -2.82334342e-02, -1.35842188e-01,
-4.41795008e-04, -4.67307509e-02, 3.26420635e-02,
-3.43710296e-02, 1.08600958e-01, -1.19684674e-01,
-1.15702585e-02, -5.29742132e-02, -8.58632779e-02,
5.49293634e-02, -1.28582904e-01, 3.30139501e-02,
1.91180419e-01, 2.06462597e-02, -6.48707477e-02,
-8.20119830e-02, -8.35309469e-03, 3.54353392e-02,
-7.91071596e-02, -8.36684223e-02, 6.17335216e-02,
-1.01217617e-02, -1.00540861e-01, -5.48336196e-02,
-1.71105389e-02, -2.12356078e-02, -1.14496268e-01,
2.93849624e-02, 2.36536930e-02, 6.08473933e-02,
1.81132892e-02, 3.16145248e-02, 1.56376674e-01,
3.24342202e-02, 5.35344708e-03, -9.31969777e-02,
1.23855219e-01, -1.54691975e-03, 1.70947532e-02,
7.22062554e-02, 2.54588642e-02, 5.57794494e-02,
9.75779489e-02, -4.55104484e-02, 3.46636330e-02,
-5.55832345e-02, 2.72228363e-02, -7.08426689e-02,
-1.49771182e-02, 1.34402453e-02, -1.72122309e-03,
1.56672952e-02, 6.92526562e-02, 1.50181313e-01,
-7.16690686e-02]]))}
用rnn处理变长文本时,使用dynamic_rnn可以跳过padding部分的计算,减少计算量。假设有两个文本,一个长度为10,另一个长度为5,那么需要对第二文本使用0-padding方法填充,得到的shape为(2, 10, dim),其中dim是词向量维度。使用dynamic_rnn的代码如下:
outputs, last_states = tf.nn.dynamic_rnn( cell=cell, dtype=tf.float32, sequence_length=x_lengths, inputs=x)
其中cell是RNN节点,比如tf.contrib.rnn.BasicLSTMCel,x是0-padding以后的数据,x_lengths是每个文本的长度。计算第二个文本的时候,只计算前面5个值,后面的就直接跳过了,对应的output直接设为0,cell的状态保持第5步的值。
dynamic_rnn返回两个变量,第一个是每个step的输出值,第二个是最终的状态。那么问题来了,对于第二个文本,我想取的肯定是第5个output,最后一个output是无效的0对我来说没有意义。目前我知道的有3种做法。
第一种是从别人代码里面看到,链接在此。作者自己写了个index的operation,代码比较绕。
第二种是构建一个mask,长度对应的那位为1,其余的为0,比如第二个文本对应的mask为[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],然后将这个mask与outputs按时间维度进行sum,这样得到的刚好是第5个输出的值。
第三种做法最简单,这得从rnn的定义说起,rnn的输出其实就是状态中的h,因此last_states 中的h状态就是我们需要的output。也就是我们把last_states.h当作rnn的最终输出就行了。
作者:王买买提
链接:https://www.zhihu.com/question/52200883/answer/153694449
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
在每一个train step,传入model的是一个batch的数据(这一个batch的数据forward得到predictions,计算loss,backpropagation更新参数),这一个batch内的数据一定是padding成相同长度的。那么,如果可以只在一个batch内部进行padding,例如一个batch中数据长度均在6-10这个范围内,就可以让这个batch中所有数据pad到固定长度10,而整个dataset上的数据最大长度很可能是100,这样就不需要让这些数据也pad到100那么长,白白浪费空间。所以dynamic_rnn实现的功能就是可以让不同迭代传入的batch可以是长度不同数据,但同一次迭代一个batch内部的所有数据长度仍然是固定的。例如,第一时刻传入的数据shape=[batch_size, 10],第二时刻传入的数据shape=[batch_size, 12],第三时刻传入的数据shape=[batch_size, 8]等等。但是rnn不能这样,它要求每一时刻传入的batch数据的[batch_size, max_seq],在每次迭代过程中都保持不变。这样不就必须要求全部数据都要pad到统一的max_seq长度了吗?是的,但也有个折中办法。——将数据集的sequence length做个初步统计,看会落在哪几个区间段内。然后根据区间段将数据进行归类,也就是所谓的放在不同buckets中。最后用rnn为每一个buckets都创建一个sub graph。训练的时候,根据当前batch data所归属的bucket id,找到它对应的sub graph,进行参数更新(虽然是不同的sub graph,但参数是共享的。至少tensorflow中是这么实现的~(≧▽≦)/~)具体可参看:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py#L1143另外,两者的输入形式确实不同,但你提到的“将填充的部分输出为0”,给rnn传入sequence_length这个参数后,也是可以的。
注:内容均转至网上,如有侵权请及时联系。
上一篇: this有关