PyTorch中scatter()和scatter_()函数的作用
本文讲的是我对PyTorch中scatter()函数的理解。
原创,转载请标明来源。
一言以蔽之:修改tensor中的指定位置的值。
函数
scatter(dim, index, src)
- dim: 索引的维度。按照i, j, k, ...的哪个方向进行索引
- index: 索引。可以是一个tensor,存储需要改的元素的位置的tensor
- src: 用src中的值来修改。可以是tensor;可以是一个数字,用同样的数字写入tensor
scatter() 和 scatter_() 函数功能相同:只不过带下划线的函数,通常是直接修改原来的tensor
原理
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
函数的具体实现,如上述代码框所示:使用src中的值,修改self中位置为index[i][j][k]的值。
举例
# 这是src
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
# index是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
# self就是下面的torch.zeros(3, 5)
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
# 这是结果
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
在这个例子中,dim=0,是按照i的方向修改torch.zeros(3, 5)的。可以看出,index实际上表达了一种映射关系:同样都是在第j列上,将src在该列的值 根据index在该列的指示 映射到self的这一列上。
比如:src在第0列的0.01940和0.2078,被放到self的第0列上,但不是完全一样的放过来,而是经过index变化了上下位置。其他列同理。
在简单RNN中的应用
该应用代码如下 [1]:
def one_hot(x, n_class, dtype=torch.float32):
# X shape: (batch), output shape: (batch, n_class)
x = x.long()
res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
res.scatter_(1, x.view(-1, 1), 1)
return res
x = torch.tensor([0, 2])
one_hot(x, vocab_size)
运行结果:
tensor([[1., 0., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.]])
在本例中,该函数的任务是将输入的文本使用one_hot编码。其中,x是一个vector,代表一个二字词语,其中的0和2代表汉字(在程序上文定义的字典中)所对应的数字。vocab_size是字典大小,即在该程序中所考虑的汉字总个数。n_class是one_hot编码中所考虑的类别数,在本例中等于vocab_size。
res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) :生成了两行,vocab_size列的零矩阵
res.scatter_(1, x.view(-1, 1), 1) :在res中,将1,按照dim=1(即不改行改列)的方向,根据[[0],[2]]所指示的位置,放入res中。(比如,x中的0,代表要放入第0列;而0本身处于第0行,所以是第0行中的第0列。)
参考文献
[1] 循环神经网络的从零开始实现. 原书作者:阿斯顿·张、李沐、扎卡里 C. 立顿、亚历山大 J. 斯莫拉以及其他社区贡献者. 原书名称:动手学深度学习Pytorch版
[2] PyTorch官方文档
推荐阅读
-
Pytorch中的scatter_函数
-
PyTorch中scatter()和scatter_()函数的作用
-
【转】Shell中脚本变量和函数变量的作用域 博客分类: linux命令unix shell作用域shelllocal函数变量
-
【转】Shell中脚本变量和函数变量的作用域 博客分类: linux命令unix shell作用域shelllocal函数变量
-
Excel中SEARCH和FIND函数的作用及其区别介绍
-
入口函数的作用,以及原生js和jQuery库中的入口函数的不同
-
多角度让你彻底明白yield语法糖的用法和原理及在C#函数式编程中的作用
-
Pytorch中的permute函数和transpose,contiguous,view函数的关联
-
深入理解php中构造函数和析构函数的作用
-
php中@符号的功用和php函数前的&符号的作用