
TensorFlow: characteristics and usage of the LSTMCell family


First of all, thanks to this blog post https://www.cnblogs.com/hrlnw/p/10748990.html for the inspiration.

The original post used TF 1.10; I used 1.15, and in my tests the version made no difference.

tf.nn.rnn_cell, tf.compat.v1.nn.rnn_cell, and tf.contrib.rnn are equivalent to one another, and the RNN packages split into two parts:

1. tf.contrib.rnn         2. tf.contrib.cudnn_rnn
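
A quick sanity check of the aliasing (a minimal sketch; on my TF 1.15 setup these names resolve to the same classes, though contrib re-exports can differ across versions):

import tensorflow as tf

# In TF 1.x the v1 name and the plain name point to the same class.
print(tf.nn.rnn_cell.LSTMCell is tf.compat.v1.nn.rnn_cell.LSTMCell)
print(tf.nn.rnn_cell.LSTMCell is tf.contrib.rnn.LSTMCell)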

I. tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=hidden_num, initializer=weight_initializer)

= tf.nn.rnn_cell.LSTMCell() = tf.nn.rnn_cell.BasicLSTMCell()

This class states plainly in its docstring: "Note that this cell is not optimized for performance. Please use tf.contrib.cudnn_rnn.CudnnLSTM for better performance on GPU, or tf.contrib.rnn.LSTMBlockCell and tf.contrib.rnn.LSTMBlockFusedCell for better performance on CPU."

tf.nn.rnn_cell.BasicLSTMCell should be your last resort. For the less common RNN cell types (variants of tf.contrib.rnn.BasicLSTMCell such as tf.contrib.rnn.NASCell, tf.contrib.rnn.PhasedLSTMCell, tf.contrib.rnn.UGRNNCell, tf.contrib.rnn.GLSTMCell, tf.contrib.rnn.Conv1DLSTMCell, tf.contrib.rnn.Conv2DLSTMCell, tf.contrib.rnn.LayerNormBasicLSTMCell, etc.), be aware that, like tf.contrib.rnn.BasicLSTMCell, they run as plain ops in the graph, with low performance and high memory usage. Weigh whether the trade-off is worth it before using one of them. For example, although layer normalization speeds up convergence, cuDNN can be up to 20x faster when layer normalization is not used.
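
For reference, the slow baseline pattern this section describes looks like the following (a minimal sketch, assuming TF 1.x; batch-major input of shape [batch, time, dim]):

import tensorflow as tf

batch_size, time_step, hidden_num = 1, 70, 512
inputs = tf.placeholder(tf.float32, [batch_size, time_step, hidden_num])

# Every gate of this cell is built from separate graph ops, hence the slowness.
cell = tf.nn.rnn_cell.LSTMCell(num_units=hidden_num)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)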

II. tf.contrib.rnn.LSTMBlockCell(num_units=hidden_num) inherits from LayerRNNCell and suits the scenario of running one RNN cell per time step.

If you are only using a single RNN cell rather than a whole RNN layer, tf.nn.dynamic_rnn should be the first choice. Choosing dynamic_rnn over static_rnn makes no difference to performance, but dynamic_rnn has these benefits:

1. If the inputs are large, tf.nn.static_rnn inflates the graph and can increase compile time.

2. tf.nn.dynamic_rnn handles long sequences well, and it can swap memory from GPU to CPU.

Where possible, several tf.nn.dynamic_rnn instances can be run in parallel inside a tf.while_loop, but this is rarely done for RNNs, since they are inherently sequential. A usage sketch follows.
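
Because LSTMBlockCell implements the same RNNCell interface, it drops straight into the dynamic_rnn pattern above (again a sketch under the same TF 1.x assumption):

import tensorflow as tf

batch_size, time_step, hidden_num = 1, 70, 512
inputs = tf.placeholder(tf.float32, [batch_size, time_step, hidden_num])

# Same graph shape as with LSTMCell, but each step runs as one fused kernel.
cell = tf.contrib.rnn.LSTMBlockCell(num_units=hidden_num)
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs,
    sequence_length=[time_step] * batch_size,  # lets dynamic_rnn skip past padding
    dtype=tf.float32)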

III. tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units=hidden_num) inherits from LSTMBlockCell and likewise suits running one RNN cell per time step.

1. If the network is guaranteed to run only on NVIDIA GPUs, consider tf.contrib.cudnn_rnn. It is usually an order of magnitude faster than tf.contrib.rnn.BasicLSTMCell and tf.contrib.rnn.LSTMBlockCell, and uses three to four times less memory than tf.contrib.rnn.BasicLSTMCell.
2. If the network needs layer normalization, tf.contrib.cudnn_rnn should not be used. (See the sketch after this list.)
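
A minimal sketch of the cuDNN path (assuming TF 1.x with GPU support; CudnnLSTM is a whole layer that takes time-major input, and CudnnCompatibleLSTMCell exists mainly so that weights trained with CudnnLSTM can later be run without cuDNN):

import tensorflow as tf

time_step, batch_size, hidden_num = 70, 1, 512
# CudnnLSTM wants time-major input: [time, batch, dim].
inputs = tf.placeholder(tf.float32, [time_step, batch_size, hidden_num])

lstm = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1, num_units=hidden_num)
outputs, _ = lstm(inputs)  # outputs: [time, batch, hidden_num]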

IV. tf.contrib.rnn.LSTMBlockFusedCell(num_units=hidden_num) inherits from LSTMBlockWrapper and is equivalent to a whole RNN layer: it cannot be wrapped with tf.nn.dynamic_rnn, tf.contrib.rnn.FusedRNNCellAdaptor, or the like, and must be instantiated and called directly. On CPU-only machines, on GPU machines where tf.contrib.cudnn_rnn is unavailable, or on mobile devices, tf.contrib.rnn.LSTMBlockFusedCell is the one to use. It is the speed king, but it cannot be stepped one time step at a time.
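
Calling it directly looks like this (a sketch, assuming TF 1.x; the fused cell consumes the whole time-major sequence in one call and returns all step outputs plus the final state):

import tensorflow as tf

time_step, batch_size, hidden_num = 70, 1, 512
# The fused cell also wants time-major input: [time, batch, dim].
inputs = tf.placeholder(tf.float32, [time_step, batch_size, hidden_num])

fused = tf.contrib.rnn.LSTMBlockFusedCell(num_units=hidden_num)
outputs, (final_c, final_h) = fused(inputs, dtype=tf.float32)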

V. A hand-written LSTM cell, which appears in the sample below. Since it has had no professional optimization, its speed is honestly just average, but it still beats the likes of tf.nn.rnn_cell.LSTMCell() / tf.nn.rnn_cell.BasicLSTMCell() by a wide margin.

The code follows (it is a bit messy; bear with it):

import numpy as np
import tensorflow as tf
import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

batch_size = 1
time_step = 70
hidden_num = 512
weight_initializer = tf.truncated_normal_initializer(stddev=0.01)
# Shared weights for the hand-written cell: input->gates and hidden->gates,
# each yielding the four gate pre-activations concatenated along axis 1.
w_h = tf.compat.v1.get_variable('w_h', [hidden_num, hidden_num * 4], initializer=tf.initializers.orthogonal())
w_in = tf.compat.v1.get_variable('w_in', [hidden_num, hidden_num * 4], initializer=tf.initializers.orthogonal())
def lstm_cell(inputs, hidden_state, cell_state, w_in=w_in, w_h=w_h):
    # Pre-activations of all four gates at once (note: no bias term here).
    gates = tf.add(tf.matmul(inputs, w_in), tf.matmul(hidden_state, w_h))
    i, f, o, g = tf.split(gates, num_or_size_splits=4, axis=1)  # each (batch_size, hidden_num)
    i = tf.sigmoid(i)
    f = tf.sigmoid(f)  # tf's cells add a forget_bias inside this sigmoid
    o = tf.sigmoid(o)
    g = tf.tanh(g)
    c = tf.add(tf.multiply(f, cell_state), tf.multiply(i, g))  # new cell state
    h = tf.multiply(o, tf.tanh(c))  # new hidden state / output

    return h, c


rnn_cell1 = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=hidden_num, initializer=weight_initializer)
rnn_cell2 = tf.contrib.rnn.LSTMBlockCell(num_units=hidden_num)
rnn_cell3 = tf.contrib.rnn.LSTMBlockFusedCell(num_units=hidden_num)
rnn_cell4 = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units = hidden_num)


np_input_data = np.random.randn(batch_size, time_step, hidden_num).astype(np.float32)
np_hidden_state = np.random.randn(batch_size, hidden_num).astype(np.float32)
np_cell_state = np.random.randn(batch_size, hidden_num).astype(np.float32)
np_input_len = [time_step]*batch_size
input_data = tf.placeholder(dtype=tf.float32, shape=[batch_size, time_step, hidden_num], name='input_data')
hidden_data = tf.placeholder(dtype=tf.float32, shape=[batch_size, hidden_num], name='hidden')
cell_data = tf.placeholder(dtype=tf.float32, shape=[batch_size, hidden_num], name='cell')
trans_data = tf.transpose(input_data, [1, 0, 2])
state = (cell_data, hidden_data)  # (c, h), the order LSTMBlockCell expects
rnn_cell_list = [rnn_cell1, rnn_cell2, rnn_cell3, rnn_cell4]


# Benchmark the hand-written cell: unroll over time with tf.while_loop.
# The inputs are baked in as constants here, so no feed_dict is needed.
output_array = tf.TensorArray(dtype=tf.float32, size=time_step)
def wl_t_func(i, trans_data, output_array, h, c):
    h, c = lstm_cell(trans_data[i,:,:], h, c)
    output_array = output_array.write(i, h)
    return i+1, trans_data, output_array, h, c
_, _, output_array, _, _ = tf.while_loop(cond=lambda i, *_: i<time_step, body=wl_t_func, loop_vars=(tf.constant(0, tf.int32), tf.transpose(tf.convert_to_tensor(np_input_data),[1, 0, 2]), output_array,
                                                                                                    tf.convert_to_tensor(np_hidden_state), tf.convert_to_tensor(np_cell_state)))
output_array = output_array.stack()

# for t in range(time_step):
#     h, c = lstm_cell(trans_data[t,:,:], hidden_data, cell_data)
#     hidden_data, cell_data = h, c
#     outputs.append(h)
# outputs = tf.stack(outputs, axis=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    start = time.time()
    # result = sess.run([outputs], feed_dict={input_data: np_input_data, hidden_data:np_hidden_state, cell_data:np_cell_state})[0]

    result = sess.run(output_array)
    end = time.time()
print('lstm_cell', '*', end - start, '*', result.shape)


# Benchmark the library cells; trans_data is time-major [time, batch, dim].
for i in range(4):
    outputs = [trans_data]
    rnn_cell = rnn_cell_list[i]
    if i == 0:
        # Wrap the plain LSTMCell into a fused-style layer (statically unrolled).
        fw_rnn = tf.contrib.rnn.FusedRNNCellAdaptor(rnn_cell, use_dynamic_rnn=False)
        outputs1, state1 = fw_rnn(outputs[-1], sequence_length=np_input_len, dtype=tf.float32)  # time_len, batch, output_size
        outputs.append(outputs1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start = time.time()
            result = sess.run([outputs[-1]], feed_dict={input_data: np_input_data})[0]
            end = time.time()
        print(i, '*', end - start, '*', result.shape)

    elif i == 2:
        # LSTMBlockFusedCell is already a layer: call it directly on the whole sequence.
        fw_rnn = rnn_cell
        outputs_t = fw_rnn(outputs[-1], dtype=tf.float32)

        outputs.append(outputs_t)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start = time.time()
            result = sess.run([outputs[-1]], feed_dict={input_data: np_input_data})[0][0]
            end = time.time()
        print(i, '*', end - start, '*', result.shape)

    else:
        # LSTMBlockCell / CudnnCompatibleLSTMCell are per-step cells:
        # drive them once with tf.while_loop and once with a Python for loop.
        fw_rnn = rnn_cell

        output_array = tf.TensorArray(dtype=tf.float32, size=time_step)

        def wl_t_fw_rnn(i, trans_data, output_array, state):
            h, new_state = fw_rnn(trans_data[i, :, :], state)
            output_array = output_array.write(i, h)
            return i + 1, trans_data, output_array, new_state

        init_state = tf.nn.rnn_cell.LSTMStateTuple(c=tf.convert_to_tensor(np_cell_state), h=tf.convert_to_tensor(np_hidden_state))
        _, _, output_array, _ = tf.while_loop(cond=lambda i, *_: i < time_step, body=wl_t_fw_rnn, loop_vars=(
        tf.constant(0, tf.int32), tf.transpose(tf.convert_to_tensor(np_input_data), [1, 0, 2]), output_array, init_state))
        output_array = output_array.stack()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start = time.time()
            result = sess.run(output_array)
            end = time.time()
        print(i, '*while_loop*', end - start, '*', result.shape)

        outputs = []
        # Reset the state so each cell's for-loop run starts from the
        # placeholders instead of chaining onto a previous iteration's state.
        state = tf.nn.rnn_cell.LSTMStateTuple(c=cell_data, h=hidden_data)
        for t in range(time_step):
            h, new_state = fw_rnn(trans_data[t,:,:], state)
            state = new_state
            outputs.append(h)
        outputs = tf.stack(outputs, axis=0)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start = time.time()
            result = sess.run([outputs], feed_dict={input_data: np_input_data, cell_data:np_cell_state, hidden_data:np_hidden_state})[0]
            end = time.time()
        print(i, '*for*', end - start, '*', result.shape)


The measured run times are as follows (seconds; batch_size=1, time_step=70, hidden_num=512):

cell type                                                   time     time (for loop)
hand-written lstm_cell (while_loop)                         0.227    -
tf.compat.v1.nn.rnn_cell.LSTMCell (FusedRNNCellAdaptor)     0.772    -
tf.contrib.rnn.LSTMBlockCell (while_loop)                   0.075    0.156
tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell (while_loop)   0.086    0.247
tf.contrib.rnn.LSTMBlockFusedCell (direct call)             0.066    -
