欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

RNN, LSTM和GRU [Pytorch]

程序员文章站 2024-03-24 23:48:52
...

SpaceX 猎鹰重型 现代工程的杰作(高燃混剪)

 

 

import numpy as np
import torch
from torch import nn


#RNN
rnn = nn.RNN(10, 7, 2)          #(each_input_size, hidden_state, num_layers)
inputs = torch.randn(5, 3, 10)  #(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 7)       #(num_layers * num_directions, batch, hidden_size)
output, hn = rnn(inputs, h0)
print(output.size(), hn.size())
print(output)
print()

#LSTM
rnn = nn.LSTM(10, 7, 2)        #(each_input_size, hidden_state, num_layers)
inputs = inputs
h0 = h0
c0 = torch.randn(2, 3, 7)      #(num_layers * num_directions, batch, hidden_size)
output, (hn,cn) = rnn(inputs, (h0,c0))   #seq_len x batch x hidden*bi_directiona
相关标签: LSTM