pytorch中的LSTM
程序员文章站
2022-03-16 17:11:52
...
RNN和RNNCell层的区别在于前者能处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。RNN层可以通过调用RNNCell来实现
# -*- coding: utf-8 -*-
#@Time :2019/7/1 22:41
#@Author :XiaoMa
import torch as t
from torch.autograd import Variable as V
from torch import nn
t.manual_seed(1000)#设置随机数种子,保证每次运行得到相同的结果
input=V(t.randn(2,3,4)) #batch_size=3,序列长度为2,序列中每个元素占4维
#LSTM输入向量为4维,3个隐藏元,1层
lstm=nn.LSTM(4,3,1)
#初始状态:1层,batch_size=3,3个隐藏元
h0=V(t.randn(1,3,3))
c0=V(t.randn(1,3,3))
out,hn=lstm(input,(h0,c0))
print(out)
输出结果:
tensor([[[-0.3610, -0.1643, 0.1631],
[-0.0613, -0.4937, -0.1642],
[ 0.5080, -0.4175, 0.2502]],
[[-0.0703, -0.0393, -0.0429],
[ 0.2085, -0.3005, -0.2686],
[ 0.1482, -0.4728, 0.1425]]], grad_fn=<CatBackward>)
示例2:
# -*- coding: utf-8 -*-
#@Time :2019/7/2 14:04
#@Author :XiaoMa
import torch as t
from torch.autograd import Variable as V
from torch import nn
t.manual_seed(1000)#设置随机数种子,保证每次运行得到相同的结果
input=V(t.randn(2,3,4)) #batch_size=3,序列长度为2,序列中每个元素占4维
input=V(t.randn(2,3,4))
#一个LSTMCell对应的层数只能是一层
lstm=nn.LSTMCell(4,3)
hx=V(t.randn(3,3,))
cx=V(t.randn(3,3))
out=[]
for i_ in input:
hx,cx=lstm(i_,(hx,cx))
out.append(hx)
print(t.stack(out))
#pytorch中的embedding层
embedding=nn.Embedding(4,5) #有4个词,每个词有5维向量表示
#可以用预训练好的词向量初始化embedding
embedding.weight.data=t.arange(0,20).view(4,5)
input=V(t.arange(3,0,-1).long())
output=embedding(input)
print('output=',output)
输出结果:
[ 0.2780, 0.0367, -0.0181],
[-0.0443, 0.2794, -0.3031]],
[[ 0.1455, 0.0193, -0.0891],
[ 0.3409, 0.2397, -0.2647],
[ 0.1984, 0.0731, -0.2611]]], grad_fn=<StackBackward>)
output= tensor([[15, 16, 17, 18, 19],
[10, 11, 12, 13, 14],
[ 5, 6, 7, 8, 9]], grad_fn=<EmbeddingBackward>)
下一篇: Pytorch 中的 LSTM
推荐阅读
-
EditPlus中的正则表达式中英文使用详解(附常用实例)
-
怎么让网页自动翻译? 在IE浏览器中实现网页自动翻译的方法
-
PowerDesigner 建立与SQLSERVER 2005数据库的连接以便生成数据库和从数据库生成到PD中
-
VS2012中通过IIS发布站点的步骤分享
-
windows系统中OneDrive无法打开登陆不了的解决办法
-
RubyMine编辑器中安装CoffeeScript和CoffeeScriptRedux的方法
-
PowerDesigner 建立与数据库的连接以便生成数据库和从数据库生成到PD中(Oracle 10G版)
-
MyEclipse中配置struts.xml自动提示的方法
-
ProE投影曲线中投影链怎么创建? ProE创建投影链的教程
-
Visual Studio中11个强大的调试技巧和方法