GAT中的如何实现稀疏注意力
程序员文章站
2022-05-25 11:52:44
...
之前一直没看GAT的代码(https://github.com/PetarV-/GAT),不知道稀疏矩阵下如何实现注意力的,今天看到,恍然大悟,记录于此
首先,由于稀疏矩阵参与运算时其中的参数不能自动更新(pytorchz中暂时没有其反向传播函数),所以GAT自己写了稀疏矩阵(计算完注意力后的邻接矩阵)与稠密矩阵(特征)的乘法
class SpecialSpmmFunction(torch.autograd.Function):
"""Special function for only sparse region backpropataion layer."""
@staticmethod
def forward(ctx, indices, values, shape, b):
assert indices.requires_grad == False
a = torch.sparse_coo_tensor(indices, values, shape)
ctx.save_for_backward(a, b)
ctx.N = shape[0]
return torch.matmul(a, b)
@staticmethod
def backward(ctx, grad_output):
a, b = ctx.saved_tensors
grad_values = grad_b = None
if ctx.needs_input_grad[1]:
grad_a_dense = grad_output.matmul(b.t())
edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
grad_values = grad_a_dense.view(-1)[edge_idx]
if ctx.needs_input_grad[3]:
grad_b = a.t().matmul(grad_output)
return None, grad_values, None, grad_b
class SpecialSpmm(nn.Module):
def forward(self, indices, values, shape, b):
return SpecialSpmmFunction.apply(indices, values, shape, b)
在需要就算注意力的时候,就可以通过以下方式
class SpGraphAttentionLayer(nn.Module):
def __init__(self, in_dim, out_dim, dropout, alpha, concat=True):
super(SpGraphAttentionLayer, self).__init__()
self.alpha = alpha
self.concat = concat
self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim)),requires_grad=True)
nn.init.xavier_normal_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.zeros(size=(1, 2*out_dim)),requires_grad=True)
nn.init.xavier_normal_(self.a.data, gain=1.414)
self.dropout = nn.Dropout(dropout)
self.leakyrelu = nn.LeakyReLU(self.alpha)
self.special_spmm = SpecialSpmm()
def forward(self, input, adj):
dv = 'cuda' if input.is_cuda else 'cpu'
N = input.size()[0]#节点数量2708
edge_index = adj.nonzero().t()
#先对所有特征进行一次线性变换
input=torch.mm(input.self.W)
# 连接节点与其所有邻居的表示
hidden = torch.cat((input[edge_index[0, :], :], input[edge_index[1, :], :]), dim=1).t()
#通过向量a得到分数
edge_value = torch.exp(-self.leakyrelu(self.a.mm(hidden).squeeze()))#注意力中的分子
e_rowsum = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), torch.ones(size=(N,1), device=dv))
edge_value= self.dropout(edge_value)
#这也是一个技巧,乘完特征再去除,跟算完注意力再去乘特征是一个道理
h_prime = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), input)
h_prime = h_prime.div(e_rowsum)
if self.concat:
return F.elu(h_prime)
else:
return h_prime