欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

GAT中的如何实现稀疏注意力

程序员文章站 2022-05-25 11:52:44
...

之前一直没看GAT的代码(https://github.com/PetarV-/GAT),不知道稀疏矩阵下如何实现注意力的,今天看到,恍然大悟,记录于此

首先,由于稀疏矩阵参与运算时其中的参数不能自动更新(pytorchz中暂时没有其反向传播函数),所以GAT自己写了稀疏矩阵(计算完注意力后的邻接矩阵)与稠密矩阵(特征)的乘法

class SpecialSpmmFunction(torch.autograd.Function):
    """Special function for only sparse region backpropataion layer."""
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)

在需要就算注意力的时候,就可以通过以下方式

class SpGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim)),requires_grad=True)
        nn.init.xavier_normal_(self.W.data, gain=1.414)
                
        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_dim)),requires_grad=True)
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, adj):
        dv = 'cuda' if input.is_cuda else 'cpu'
        N = input.size()[0]#节点数量2708
        edge_index = adj.nonzero().t()
        
         #先对所有特征进行一次线性变换
        input=torch.mm(input.self.W)

        # 连接节点与其所有邻居的表示
        hidden = torch.cat((input[edge_index[0, :], :], input[edge_index[1, :], :]), dim=1).t()
        
        #通过向量a得到分数
        edge_value = torch.exp(-self.leakyrelu(self.a.mm(hidden).squeeze()))#注意力中的分子
        e_rowsum = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), torch.ones(size=(N,1), device=dv))

        edge_value= self.dropout(edge_value)
        
        #这也是一个技巧,乘完特征再去除,跟算完注意力再去乘特征是一个道理
        h_prime = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), input)
        h_prime = h_prime.div(e_rowsum)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

 

相关标签: 应用