对LSTM中间变量形状shape的理解, 附keras中LSTM的各个变量的shape理解
假设输入的shape是[bs, length, d], bs是批数量, length是预定义的最大序列长度, d是序列中每个step的维度(对于图像序列,可以理解为每一帧的特征向量维度).
下面说对于bs中1个样本的情况, 也就是shape为[1, length, d]
LSTM(或者RNN)有多个cell, 1个cell对应1个step(1个时刻的状态), 这些cell之间的网络层是共享的, 即对于1个LSTM层, 所有的参数数量等于1个cell的参数数量(下图的LSTM由3个cell)
下图是1个样本在1个cell中(step=t)的工作原理和中间变量的shape:
图中的六边形表示神经网络层(通常是全连接层), 表示矩阵乘积操作, 里面是这个网络层的权重W(shape:[d+m, m]), 可以分为Wa和Ua, Wa(shape: [d, m])是与Xt相乘的, Ua(shape: [m, m])是与ht-1相乘的, 红色是权重的shape.
4个六边形对应3个门(输入门涉及了两个六边形, 由it, ct共同决定).
m是LSTM的unit个数, 就是每个网络层的神经元个数(全连接层的输出向量维度)
具体的各个操作如下:
更多的关于LSTM的原理可以参考这里和这里和这里.
(个人原创,转载请注明出处https://blog.csdn.net/ying86615791/article/details/103085269,谢谢!)
每个cell都由1个Xt,1个Ct-1和1个ht-1作为输入,
生成1个ht和1个Ct.
因为有length个step, 就有length个cell,
那么1个样本(X: [length, d])最终可以产生的各个变量的shape如下
h: [length, m], 隐藏状态
c: [length, m], 细胞状态
如果网络只有1个输出,
通常取最后1个cell(即最后1个step, 即最后1个时刻)输出的h_last[1, m]作为最终LSTM的输出, 因为该此时的h已经融合了之前的信息了.
上面是1个样本的情况, 对于bs个样本, 该LSTM的各个变量的shape就是
h: [bs, length, m], 隐藏状态
c: [bs, length, m], 细胞状态
最后LSTM的输出就是
h_last: [bs, 1, m]==>[bs, m]
下面从从keras中1个LSTM使用例子来理解这几个变量的shape
import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Input
bs = 5
length = 60
d = 1
m = 30
x = np.random.rand(bs, length, d)
input_ = Input(shape=(length, d))
lstm, hidden, cell = LSTM(units=m, return_state=True, return_sequences=True)(input_)
model = Model(inputs=input_, outputs=(lstm, hidden, cell))
print('input shape ',x.shape)
output, hidden, cell = model(x)
print('output shape ',output.shape)
wx, ux, b = model.layers[1].trainable_weights
print('wx shape ',wx.shape)
print('ux shape ',ux.shape)
print('b shape ',b.shape)
length=60表示序列长度为60个step,过程中会有60个cell
运行后, 各个变量的shape如下
这里的wx, ux, b是把4个全连接层的权重拼在一起的结果, 参考
也就是, 对于其中1层的话,
wx: [1, 30]
ux: [30, 30]
b: [30]
这里的hidden是最后1个step(也就是最后1个cell)的h, cell也是最后1个step的c
output shape的意思是包含了60个cell的h, 每个h的维度是m=30, output由所有cell输出的h组成
如果只取最后时刻的h作为输出的话, 就是
output[:, -1, :]==>shape: [5, 30]
注意,其实这个时候output[:,-1, :]的值就是就是hidden了, 验证如下:
通过上面可以知道所有cell的h都在output里面, 那么keras中如何拿到所有cell的细胞状态c呢?看这里好像还不能实现?
下一篇: 十分钟学会 GIT 命令(补充)