A PyTorch self-attention implementation that maps an input of shape [batch, seq_len1, hidden_dim] to an output of shape [batch, seq_len2, hidden_dim]
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    Maps inputs of shape [batch, seq_len1, hidden_dim] to outputs of shape
    [batch, seq_len2, hidden_dim]; labels_num plays the role of seq_len2.
    """
    def __init__(self, labels_num, hidden_size):
        super(Attention, self).__init__()
        # One attention query per output position (label).
        self.attention = nn.Linear(hidden_size, labels_num, bias=False)
        nn.init.xavier_uniform_(self.attention.weight)

    def forward(self, inputs, masks):
        # masks: [batch, seq_len1], 1 for valid tokens, 0 for padding.
        masks = torch.unsqueeze(masks, 1)  # [batch, 1, seq_len1]
        # Per-label scores over the sequence: [batch, labels_num, seq_len1].
        attention = self.attention(inputs).transpose(1, 2)
        # Mask out padded positions before the softmax.
        attention = attention.masked_fill(masks == 0, float('-inf'))
        attention = F.softmax(attention, dim=-1)
        # Weighted sum over the sequence: [batch, labels_num, hidden_size].
        return attention @ inputs
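A minimal usage sketch follows. The concrete sizes and the mask convention (1 for real tokens, 0 for padding) are illustrative assumptions, not part of the original post:

# Hypothetical usage sketch with made-up shapes.
batch, seq_len1, hidden_dim, labels_num = 2, 10, 64, 5

attn = Attention(labels_num=labels_num, hidden_size=hidden_dim)
inputs = torch.randn(batch, seq_len1, hidden_dim)
masks = torch.ones(batch, seq_len1, dtype=torch.bool)
masks[1, 7:] = False  # assume the second sequence has only 7 real tokens

out = attn(inputs, masks)
print(out.shape)  # torch.Size([2, 5, 64]) -> [batch, labels_num (seq_len2), hidden_dim]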