欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch中的scatter()与scatter_()函数

程序员文章站 2024-03-24 23:40:10
...

仔细看了PyTorch的文档才搞懂这两个函数
PyTorch: torch.Tensor.scatter
另一个文档: pytorch_scatter

scatter()

这个是scatter_()的out-of-place版本,即函数修改的不是原tensor
在vscode里面看这个函数有两种:

	scatter(self: Tensor, dim: _int, index: Tensor, src: Tensor) -> Tensor
	# param dim:_int 是让输入第一个参数(?
	scatter(self: Tensor, dim: _int, index: Tensor, value: Number) -> Tensor
	# param dim:_int

两个的区别在于最后一个参数,可以用Tensor作为src进行填充,也可以指定某个数值作为填充

scatter_()

一句话总结:在一个tensor的基础上,在dim维上,根据index选择src的一些数填到原始的那个tensor里。
对于scatter,向原始tensor填数得到另外一个tensor,原tensor不变;对于scatter_(),就是在原来tensor里填,覆盖。
那么self这个tensor如何更改呢?
对于一个三维的tensorself,更改如下:(参见PyTorch文档)

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

如果最后一个参数是value: Number,就直接在相应位置都填那个数就行。

两种调用方式

	tensor_t.scatter_(dim, index, value|src)

其中的tensor_t是一个Tensor类型的tensor,比如用torch.ones_like(xxx)或是torch.zeros(...)创建的Tensor,修改了该tensor。

	result = torch.scatter(tensor_t, dim, index, value|src)

其中的tensor_t是一个Tensor类型的tensor,没有修改该tensor,得到的tensor放到result里面。

一个例子

	mask = torch.scatter(
		torch.ones_like(A),
        1,
        torch.arange(A.shape[0]).view(-1, 1),
        0
    )

A是一个shape[0]==shape[1]的Tensor,得到的mask是一个和A大小一样大的、对角线为0且其它位置全为1的Tensor,并且A不被改变。