
Variable-Length Bi-LSTM in PyTorch


In TensorFlow 1.12, a variable-length input sequence to an LSTM can be handled with dynamic_rnn() or bidirectional_dynamic_rnn(). How do we handle the same situation in PyTorch?
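For reference, the TensorFlow 1.x pattern looks roughly like this (a minimal sketch; the placeholder shapes and unit counts are illustrative, not from the original post):

import tensorflow as tf  # TF 1.x

# inputs: [batch, max_time, feature_dim]; seq_len holds each sequence's true length
inputs = tf.placeholder(tf.float32, [None, None, 2])
seq_len = tf.placeholder(tf.int32, [None])
cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=3)
cell_bw = tf.nn.rnn_cell.LSTMCell(num_units=3)
# bidirectional_dynamic_rnn stops stepping each sequence at its true length;
# outputs beyond seq_len are zero
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)
outputs = tf.concat([out_fw, out_bw], axis=-1)  # [batch, max_time, 6]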

PyTorch has corresponding utilities: pack_padded_sequence() and pad_packed_sequence() in torch.nn.utils.rnn. "Pack" means compress; "pad" means fill. PyTorch first packs the input according to seq_len, runs the packed sequence through the LSTM, and then pads the output back out to a rectangular tensor. Both steps therefore need to know where the padding is: pack_padded_sequence() takes the true length of each sequence, and pad_packed_sequence() takes the value to pad with.
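To see what "packing" actually stores, here is a minimal round-trip sketch (the toy values are illustrative, not from the original post):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# two sequences padded to length 3; their true lengths are 3 and 2
padded = torch.tensor([[[1.], [2.], [3.]],
                       [[4.], [5.], [0.]]])
lengths = torch.tensor([3, 2])
packed = pack_padded_sequence(padded, lengths, batch_first=True)
# packed.data keeps only the 5 real timesteps, grouped timestep by timestep;
# packed.batch_sizes == tensor([2, 2, 1]) says how many sequences are still
# active at each step, which is exactly what the LSTM iterates over
restored, restored_lens = pad_packed_sequence(packed, batch_first=True)
# restored equals the original padded tensor; restored_lens == lengths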

# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

X = torch.tensor([[1, 2, 3],
                  [1, 2, 0],
                  [3, 0, 0],
                  [2, 1, 0]])
seq_len = torch.tensor([3, 2, 1, 2])
print('X shape', X.shape)
vocab_size = 4
embedding_dim = 2
hidden_size = 6
batch_size = 4

# tell the embedding layer which index is padding; the vector at padding_idx
# is initialized to zeros and is never updated during training
word_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
word_vectors = word_embedding(X)
print(word_vectors)

# each direction gets hidden_size // 2 units, so the concatenated
# bidirectional output is hidden_size wide
lstm = LSTM(input_size=embedding_dim,
            hidden_size=hidden_size // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True)

# pack the padded batch; enforce_sorted=False lets pack_padded_sequence
# sort by length internally instead of requiring a pre-sorted batch
packed_X = pack_padded_sequence(word_vectors, seq_len, batch_first=True, enforce_sorted=False)

print('--->pack padded sequence:{}'.format(packed_X))

# initial (h0, c0), each of shape (num_layers * num_directions, batch, hidden_size // 2) = (1*2, 4, 3)
h0 = (torch.randn(2, batch_size, hidden_size // 2),
      torch.randn(2, batch_size, hidden_size // 2))

outputs, h = lstm(packed_X, h0)
print('after lstm outputs: {}'.format(outputs))
# unpack: pad the LSTM output back to [batch_size, max_seq_len, hidden_size]
pad_packed_X = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0)
# pad_packed_X is a tuple: (padded outputs, original sequence lengths)
print('---->pad_packed:{}'.format(pad_packed_X))
# [batch_size, max_seq_len, hidden_size]
print('---->pad_outputs:{}'.format(pad_packed_X[0].shape))
print('---->seq_len:{}'.format(pad_packed_X[1].shape))
Running this prints:

X shape torch.Size([4, 3])
tensor([[[ 0.7049, -0.6178],
         [-2.0429,  0.7651],
         [-0.4018,  0.5503]],

        [[ 0.7049, -0.6178],
         [-2.0429,  0.7651],
         [ 0.0000,  0.0000]],

        [[-0.4018,  0.5503],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-2.0429,  0.7651],
         [ 0.7049, -0.6178],
         [ 0.0000,  0.0000]]], grad_fn=<EmbeddingBackward>)
--->pack padded sequence:PackedSequence(data=tensor([[ 0.7049, -0.6178],
        [ 0.7049, -0.6178],
        [-2.0429,  0.7651],
        [-0.4018,  0.5503],
        [-2.0429,  0.7651],
        [-2.0429,  0.7651],
        [ 0.7049, -0.6178],
        [-0.4018,  0.5503]], grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
after lstm outputs: PackedSequence(data=tensor([[ 0.0594, -0.1153,  0.2028,  0.0240, -0.3177,  0.1364],
        [ 0.1639, -0.0723,  0.2713,  0.0129, -0.3103, -0.0151],
        [ 0.0144,  0.3130,  0.0122, -0.1323, -0.0634,  0.0550],
        [ 0.2420, -0.3547,  0.1539,  0.0024, -0.4087,  0.3998],
        [ 0.0393,  0.2744,  0.0270, -0.0392, -0.0846,  0.2874],
        [ 0.0752,  0.3018,  0.0351, -0.0354, -0.2119, -0.5068],
        [ 0.2033,  0.1408,  0.0661, -0.4325, -0.0149,  0.0146],
        [ 0.0668,  0.3305,  0.0304,  0.2233, -0.0265,  0.1531]],
       grad_fn=<CatBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
---->pad_packed:(tensor([[[ 0.0594, -0.1153,  0.2028,  0.0240, -0.3177,  0.1364],
         [ 0.0393,  0.2744,  0.0270, -0.0392, -0.0846,  0.2874],
         [ 0.0668,  0.3305,  0.0304,  0.2233, -0.0265,  0.1531]],

        [[ 0.1639, -0.0723,  0.2713,  0.0129, -0.3103, -0.0151],
         [ 0.0752,  0.3018,  0.0351, -0.0354, -0.2119, -0.5068],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.2420, -0.3547,  0.1539,  0.0024, -0.4087,  0.3998],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0144,  0.3130,  0.0122, -0.1323, -0.0634,  0.0550],
         [ 0.2033,  0.1408,  0.0661, -0.4325, -0.0149,  0.0146],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<IndexSelectBackward>), tensor([3, 2, 1, 2]))
---->pad_outputs:torch.Size([4, 3, 6])
---->seq_len:torch.Size([4])

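One detail worth adding (my note, not in the original post): because the input was packed, the returned h_n already contains each sequence's last valid hidden state, so no manual indexing by seq_len is needed. A sketch reusing the variables from the example above:

# h is the tuple (h_n, c_n); h_n has shape
# (num_layers * num_directions, batch_size, hidden_size // 2)
h_n, c_n = h
# row 0 is the forward direction's state at each sequence's last valid step,
# row 1 is the backward direction's state at step 0; concatenate them to get
# one summary vector per sequence
final = torch.cat([h_n[0], h_n[1]], dim=-1)
print(final.shape)  # torch.Size([4, 6])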

Reference:
https://blog.csdn.net/So_that/article/details/94731614