PyTorch中的scatter()与scatter_()函数
仔细看了PyTorch的文档才搞懂这两个函数
PyTorch: torch.Tensor.scatter
另一个文档: pytorch_scatter
scatter()
这个是scatter_()的out-of-place版本,即函数修改的不是原tensor
在vscode里面看这个函数有两种:
scatter(self: Tensor, dim: _int, index: Tensor, src: Tensor) -> Tensor
# param dim:_int 是让输入第一个参数(?
scatter(self: Tensor, dim: _int, index: Tensor, value: Number) -> Tensor
# param dim:_int
两个的区别在于最后一个参数,可以用Tensor作为src进行填充,也可以指定某个数值作为填充
scatter_()
一句话总结:在一个tensor的基础上,在dim
维上,根据index
选择src
的一些数填到原始的那个tensor里。
对于scatter,向原始tensor填数得到另外一个tensor,原tensor不变;对于scatter_(),就是在原来tensor里填,覆盖。
那么self
这个tensor如何更改呢?
对于一个三维的tensorself
,更改如下:(参见PyTorch文档)
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
如果最后一个参数是value: Number
,就直接在相应位置都填那个数就行。
两种调用方式
tensor_t.scatter_(dim, index, value|src)
其中的tensor_t
是一个Tensor类型的tensor,比如用torch.ones_like(xxx)
或是torch.zeros(...)
创建的Tensor,修改了该tensor。
result = torch.scatter(tensor_t, dim, index, value|src)
其中的tensor_t
是一个Tensor类型的tensor,没有修改该tensor,得到的tensor放到result里面。
一个例子
mask = torch.scatter(
torch.ones_like(A),
1,
torch.arange(A.shape[0]).view(-1, 1),
0
)
A是一个shape[0]==shape[1]的Tensor,得到的mask是一个和A大小一样大的、对角线为0且其它位置全为1的Tensor,并且A不被改变。
上一篇: 使用pytorch的LSTM实现MNIST数据集分类任务
下一篇: 客户端下载文件核心代码