
A PyTorch implementation of self-attention that takes inputs of shape [batch, seq_len1, hidden_dim] and produces outputs of shape [batch, seq_len2, hidden_dim]

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    inputs has shape [batch, seq_len1, hidden_dim];
    labels_num plays the role of seq_len2.
    """
    def __init__(self, labels_num, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, labels_num, bias=False)
        nn.init.xavier_uniform_(self.attention.weight)

    def forward(self, inputs, masks):
        # masks: [batch, seq_len1], 1 for valid positions, 0 for padding
        masks = torch.unsqueeze(masks, 1)  # [batch, 1, seq_len1]
        # Score each position against every label; masked_fill expects a
        # boolean mask, so compare against 0 rather than computing 1.0 - masks.
        attention = self.attention(inputs).transpose(1, 2)  # [batch, labels_num, seq_len1]
        attention = attention.masked_fill(masks == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        return attention @ inputs  # [batch, labels_num, hidden_size]
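
For reference, a minimal usage sketch is shown below. The batch size, sequence length, hidden size, and label count are illustrative values chosen for this example, not taken from the original article:

import torch

# Illustrative shapes (assumed, not from the original article).
batch, seq_len1, hidden_dim, labels_num = 2, 5, 8, 3

model = Attention(labels_num=labels_num, hidden_size=hidden_dim)

inputs = torch.randn(batch, seq_len1, hidden_dim)
# Mark the last two positions of the second sequence as padding.
masks = torch.ones(batch, seq_len1, dtype=torch.long)
masks[1, 3:] = 0

out = model(inputs, masks)
print(out.shape)  # torch.Size([2, 3, 8]), i.e. [batch, seq_len2, hidden_dim]

Each of the labels_num output rows is a softmax-weighted average of the input positions, with padded positions excluded by the mask, so the sequence length is effectively remapped from seq_len1 to seq_len2.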