欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch中scatter()和scatter_()函数的作用

程序员文章站 2024-03-24 23:31:34
...

本文讲的是我对PyTorch中scatter()函数的理解。

原创,转载请标明来源。

 

一言以蔽之:修改tensor中的指定位置的值。

函数

scatter(dim, index, src) 

  • dim: 索引的维度。按照i, j, k, ...的哪个方向进行索引
  • index: 索引。可以是一个tensor,存储需要改的元素的位置的tensor
  • src: 用src中的值来修改。可以是tensor;可以是一个数字,用同样的数字写入tensor

scatter() 和 scatter_() 函数功能相同:只不过带下划线的函数,通常是直接修改原来的tensor

原理

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

函数的具体实现,如上述代码框所示:使用src中的值,修改self中位置为index[i][j][k]的值。

举例

# 这是src
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

# index是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
# self就是下面的torch.zeros(3, 5)
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

# 这是结果
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

在这个例子中,dim=0,是按照i的方向修改torch.zeros(3, 5)的。可以看出,index实际上表达了一种映射关系:同样都是在第j列上,将src在该列的值 根据index在该列的指示 映射到self的这一列上。

比如:src在第0列的0.01940和0.2078,被放到self的第0列上,但不是完全一样的放过来,而是经过index变化了上下位置。其他列同理。

在简单RNN中的应用

该应用代码如下 [1]:

def one_hot(x, n_class, dtype=torch.float32): 
    # X shape: (batch), output shape: (batch, n_class)
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

x = torch.tensor([0, 2])
one_hot(x, vocab_size)

 运行结果:

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])

在本例中,该函数的任务是将输入的文本使用one_hot编码。其中,x是一个vector,代表一个二字词语,其中的0和2代表汉字(在程序上文定义的字典中)所对应的数字。vocab_size是字典大小,即在该程序中所考虑的汉字总个数。n_class是one_hot编码中所考虑的类别数,在本例中等于vocab_size。

res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) :生成了两行,vocab_size列的零矩阵

res.scatter_(1, x.view(-1, 1), 1) :在res中,将1,按照dim=1(即不改行改列)的方向,根据[[0],[2]]所指示的位置,放入res中。(比如,x中的0,代表要放入第0列;而0本身处于第0行,所以是第0行中的第0列。)

 

参考文献

[1] 循环神经网络的从零开始实现. 原书作者:阿斯顿·张、李沐、扎卡里 C. 立顿、亚历山大 J. 斯莫拉以及其他社区贡献者. 原书名称:动手学深度学习Pytorch版

[2] PyTorch官方文档

相关标签: Python PyTorch