BiLSTM + Attention PyTorch Implementation
While working on a model recently I noticed that the BiLSTM + Attention implementations floating around online are all over the place, and many of them are simply wrong, so I implemented my own version on top of PyTorch. It mainly relies on two tricks: letting the LSTM handle variable-length sequences (via packed sequences) and a masked softmax. The code is as follows:
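For context, here is a minimal sketch of the variable-length trick on its own, independent of the model below (toy shapes and random data): pack_padded_sequence hides the padded steps from the LSTM, and pad_packed_sequence restores a regular padded tensor afterwards.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy batch of 2 padded sequences (batch_first=True) with actual lengths 3 and 1.
padded = torch.randn(2, 3, 4)                      # (batch_size, max_len, feature_dim)
lengths = torch.tensor([3, 1])

# Packing drops the padded time steps so the LSTM never sees them.
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
lstm = torch.nn.LSTM(input_size=4, hidden_size=5, batch_first=True)
packed_out, _ = lstm(packed)

# Unpacking restores a (batch_size, max_len, hidden_dim) tensor, zero-filled past each length.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out_lengths)                      # torch.Size([2, 3, 5]) tensor([3, 1])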
1. attention_utils.py
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def create_src_lengths_mask(
    batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):
    """
    Generate boolean mask to prevent attention beyond the end of source
    Inputs:
      batch_size : int
      src_lengths : [batch_size] of sentence lengths
      max_src_len: Optionally override max_src_len for the mask
    Outputs:
      [batch_size, max_src_len]
    """
    if max_src_len is None:
        max_src_len = int(src_lengths.max())
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)
    # returns [batch_size, max_src_len]
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking=True):
    """Apply source length masking then softmax.
    Input and output have shape bsz x src_len"""
    if src_length_masking:
        bsz, max_src_len = scores.size()
        # compute masks
        src_mask = create_src_lengths_mask(bsz, src_lengths)
        # Fill pad positions with -inf
        scores = scores.masked_fill(src_mask == 0, -np.inf)
    # Cast to float and then back again to prevent loss explosion under fp16.
    return F.softmax(scores.float(), dim=-1).type_as(scores)
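A quick sanity check of masked_softmax (appended to attention_utils.py; the scores here are arbitrary): positions beyond each sequence's length must receive exactly zero weight, and each row must still sum to 1.

if __name__ == '__main__':
    scores = torch.tensor([[0.5, 1.0, -0.3],
                           [0.2, 0.7,  0.1]])
    lengths = torch.tensor([3, 1])
    weights = masked_softmax(scores, lengths)
    print(weights)               # row 0: three non-zero weights; row 1: [1., 0., 0.]
    print(weights.sum(dim=1))    # tensor([1., 1.])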
2. layers.py
import torch
from torch import nn
import torch.nn.functional as F

from attention_utils import masked_softmax


# Additive (MLP) attention: s(x) = v.T * tanh(W * x + b)
class MLPAttentionNetwork(nn.Module):

    def __init__(self, hidden_dim, attention_dim, src_length_masking=True):
        super(MLPAttentionNetwork, self).__init__()
        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim
        self.src_length_masking = src_length_masking
        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, x, x_lengths):
        """
        :param x: seq_len * batch_size * hidden_dim
        :param x_lengths: batch_size
        :return: batch_size * seq_len, batch_size * hidden_dim
        """
        seq_len, batch_size, _ = x.size()
        # (seq_len * batch_size, hidden_dim)
        flat_inputs = x.reshape(-1, self.hidden_dim)
        # (seq_len * batch_size, attention_dim); tanh as in the scoring function above
        mlp_x = torch.tanh(self.proj_w(flat_inputs))
        # (batch_size, seq_len)
        att_scores = self.proj_v(mlp_x).view(seq_len, batch_size).t()
        # (seq_len, batch_size)
        normalized_masked_att_scores = masked_softmax(
            att_scores, x_lengths, self.src_length_masking
        ).t()
        # (batch_size, hidden_dim)
        attn_x = (x * normalized_masked_att_scores.unsqueeze(2)).sum(0)
        return normalized_masked_att_scores.t(), attn_x
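The attention layer can be exercised on its own with random hidden states (the shapes here are arbitrary): the returned scores have shape (batch_size, seq_len) and the context vectors (batch_size, hidden_dim).

if __name__ == '__main__':
    attn = MLPAttentionNetwork(hidden_dim=6, attention_dim=4)
    x = torch.randn(5, 2, 6)               # (seq_len, batch_size, hidden_dim)
    x_lengths = torch.tensor([5, 3])       # actual lengths of the two sequences
    scores, context = attn(x, x_lengths)
    print(scores.shape)                    # torch.Size([2, 5])
    print(context.shape)                   # torch.Size([2, 6])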
3. model.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from layers import MLPAttentionNetwork


class BiLSTMAttentionNetwork(nn.Module):

    def __init__(self, num_vocab, embedding_dim, max_len, hidden_dim, num_layers,
                 bidirectional, attention_dim, num_classes):
        super(BiLSTMAttentionNetwork, self).__init__()
        # vocabulary size: number of real tokens + 1 padding placeholder
        self.num_vocab = num_vocab
        # maximum sequence length
        self.max_len = max_len
        # word embedding dimension
        self.embedding_dim = embedding_dim
        # LSTM hidden state dimension
        self.hidden_dim = hidden_dim
        # number of recurrent layers
        self.num_layers = num_layers
        # whether the LSTM is bidirectional (bool)
        self.bidirectional = bidirectional
        # attention layer dimension
        self.attention_dim = attention_dim
        # number of classes
        self.num_classes = num_classes
        # embedding layer
        self.embedding_layer = nn.Embedding(self.num_vocab, self.embedding_dim, padding_idx=0)
        # BiLSTM layer
        self.bilstm_layer = nn.LSTM(self.embedding_dim, self.hidden_dim, self.num_layers,
                                    bidirectional=self.bidirectional, batch_first=True)
        # MLP attention layer (the LSTM output dim is 2 * hidden_dim when bidirectional=True)
        self.mlp_attention_layer = MLPAttentionNetwork(2 * self.hidden_dim, self.attention_dim)
        # fully connected layer
        self.fc_layer = nn.Linear(2 * self.hidden_dim, self.num_classes)
        # softmax classifier
        self.softmax_layer = nn.Softmax(dim=1)

    def forward(self, x, lengths):
        """
        :param x: padded sequences, (batch_size, max_len)
        :param lengths: actual lengths of the sequences, (batch_size,)
        :return: attention scores and class probabilities
        """
        # x: torch.tensor([[1, 2, 3], [6, 0, 0], [4, 5, 0], [3, 7, 1]])
        # lengths: torch.tensor([3, 1, 2, 3]), the actual sequence lengths
        x_input = self.embedding_layer(x)
        x_packed_input = pack_padded_sequence(input=x_input, lengths=lengths,
                                              batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm_layer(x_packed_input)
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True,
                                         total_length=self.max_len, padding_value=0.0)
        # the attention layer expects (seq_len, batch_size, 2 * hidden_dim)
        atten_scores, atten_out = self.mlp_attention_layer(outputs.permute(1, 0, 2), lengths)
        # (batch_size, num_classes) class probabilities
        probs = self.softmax_layer(self.fc_layer(atten_out))
        return atten_scores, probs
The test code is as follows:
if __name__ == '__main__':
    # args: num_vocab=20, embedding_dim=3, max_len=3, hidden_dim=2, num_layers=1
    b_a = BiLSTMAttentionNetwork(20, 3, 3, 2, 1, bidirectional=True, attention_dim=5, num_classes=10)
    x = torch.tensor([[1, 2, 3], [6, 0, 0], [4, 5, 0], [3, 7, 1]])
    lengths = torch.tensor([3, 1, 2, 3])
    atten_scores, probs = b_a(x, lengths)
    print('---------------------> attention distribution')
    print(atten_scores)
    print('---------------------> predicted probabilities')
    print(probs)
The output is as follows:
---------------------> attention distribution
tensor([[0.3346, 0.3320, 0.3335],
        [1.0000, 0.0000, 0.0000],
        [0.4906, 0.5094, 0.0000],
        [0.3380, 0.3289, 0.3330]], grad_fn=<TBackward>)
---------------------> predicted probabilities
tensor([[0.0636, 0.0723, 0.1241, 0.0663, 0.0671, 0.0912, 0.1244, 0.1446, 0.1223,
         0.1240],
        [0.0591, 0.0745, 0.1264, 0.0650, 0.0657, 0.0853, 0.1273, 0.1478, 0.1186,
         0.1303],
        [0.0634, 0.0727, 0.1178, 0.0678, 0.0688, 0.0925, 0.1203, 0.1492, 0.1228,
         0.1249],
        [0.0615, 0.0739, 0.1253, 0.0633, 0.0675, 0.0872, 0.1226, 0.1490, 0.1224,
         0.1274]], grad_fn=<SoftmaxBackward>)
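Note that the rows for the length-1 and length-2 sequences assign exactly zero weight to the padded positions, which confirms the masking works. To take a training step on this toy batch, here is a minimal sketch continuing the test block above (the labels and learning rate are made up; nn.NLLLoss expects log-probabilities, hence the log of the softmax output):

    # hypothetical class labels for the four toy sequences
    labels = torch.tensor([0, 3, 1, 7])
    optimizer = torch.optim.Adam(b_a.parameters(), lr=1e-3)
    criterion = nn.NLLLoss()                              # expects log-probabilities

    _, probs = b_a(x, lengths)
    loss = criterion(torch.log(probs + 1e-9), labels)     # small epsilon avoids log(0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())

For real training it is usually cleaner to drop the softmax layer from the model, return raw logits, and use nn.CrossEntropyLoss directly.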