欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch中scatter_ 的使用详细解读

程序员文章站 2022-03-16 17:18:22
...

先看一个例子:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 1, 1, 1]]), 2)
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 2., 2., 2.],
        [0., 0., 0., 0., 0.]])

首先是定义了一个3行5列的数组,_scatter中第一个参数0.表示沿着第0轴, 后面第二个参数是坐标,第三个是对应坐标的值,整个意思就是给torch.zeros(3, 5)对应元素赋值,怎么理解呢:

我们看torch.tensor([[0, 1, 1, 1, 1]]), 这个是一个1行5列的数组,【0, 0】的值是0, 【0,1】的值是1,等等,这个地方,数组的值就是torch.zeros(3, 5)的第几行,因为这个地方是沿着第0轴,对应的是行, 那么torch.tensor([[0, 1, 1, 1, 1]]) 对应torch.zeros(3, 5)的坐标就是 【0, 0 】, 【1, 1】, 【1,3】, 【1, 4】,torch.tensor([[0, 1, 1, 1, 1]])这个数组的列就是对应torch.zeros(3, 5)的列,值代表的是他的行,因为他是沿着第0轴,所以如果就是行,

所以因为沿着第0轴,所以torch.tensor([[0, 1, 1, 1, 1]])的列和 torch.zeros(3, 5)列必须一样,都是5, 如果这个地方torch.tensor([[0, 1, 1, 1, 1]]) 列是4,会报错,同时因为torch.tensor([[0, 1, 1, 1, 1]])的值表示行,所以他的值不能大于等于3,不然会报错,因为 torch.zeros(3, 5)索引越界

再来看个例子:

torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), 2)

tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 0., 2., 0.]])

因为沿着第1轴,所以值表示第1轴的索引,所以对应索引是【0, 0】, 【1,1】,【2,3】对应的元素是2,

在看下面:

torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), torch.from_numpy(np.array([[0], [1], [3]], np.float32)))
tensor([[0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 3., 0.]])

第三参数也可以是一个数组,也就是对应的值可以赋值不一样的,但是对应位置是第二个数组控制的。

所以,我们可以拓展到多维度的,

torch.zeros(2, 3, 5).scatter_(0, torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]]), 2)
tensor([[[2., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]]])

沿着第0轴,所以第二个参数的值就是第0轴的坐标,比如第二个参数的第一个值,表示[0,0,0], 第二个是【1, 0, 1】等等

torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]])的shape是[1, 3,5]这个数组的第0轴,表示的是我们可以多赋值,就是也可以他的shape也可以是[2, 3,5], ,[3, 3,5]等,可以理解为是为了方便赋值更多,所以这个数组的第0个维度不做对应索引元素的值,就像上面数组维度为2一样,

所以,给定任意的scatter_函数,都是可以这样理解的。