Pytorch中scatter_ 的使用详细解读
先看一个例子:
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 1, 1, 1]]), 2)
tensor([[2., 0., 0., 0., 0.],
[0., 2., 2., 2., 2.],
[0., 0., 0., 0., 0.]])
首先是定义了一个3行5列的数组,_scatter中第一个参数0.表示沿着第0轴, 后面第二个参数是坐标,第三个是对应坐标的值,整个意思就是给torch.zeros(3, 5)对应元素赋值,怎么理解呢:
我们看torch.tensor([[0, 1, 1, 1, 1]]), 这个是一个1行5列的数组,【0, 0】的值是0, 【0,1】的值是1,等等,这个地方,数组的值就是torch.zeros(3, 5)的第几行,因为这个地方是沿着第0轴,对应的是行, 那么torch.tensor([[0, 1, 1, 1, 1]]) 对应torch.zeros(3, 5)的坐标就是 【0, 0 】, 【1, 1】, 【1,3】, 【1, 4】,torch.tensor([[0, 1, 1, 1, 1]])这个数组的列就是对应torch.zeros(3, 5)的列,值代表的是他的行,因为他是沿着第0轴,所以如果就是行,
所以因为沿着第0轴,所以torch.tensor([[0, 1, 1, 1, 1]])的列和 torch.zeros(3, 5)列必须一样,都是5, 如果这个地方torch.tensor([[0, 1, 1, 1, 1]]) 列是4,会报错,同时因为torch.tensor([[0, 1, 1, 1, 1]])的值表示行,所以他的值不能大于等于3,不然会报错,因为 torch.zeros(3, 5)索引越界
再来看个例子:
torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), 2)
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 0., 2., 0.]])
因为沿着第1轴,所以值表示第1轴的索引,所以对应索引是【0, 0】, 【1,1】,【2,3】对应的元素是2,
在看下面:
torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), torch.from_numpy(np.array([[0], [1], [3]], np.float32)))
tensor([[0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 0., 3., 0.]])
第三参数也可以是一个数组,也就是对应的值可以赋值不一样的,但是对应位置是第二个数组控制的。
所以,我们可以拓展到多维度的,
torch.zeros(2, 3, 5).scatter_(0, torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]]), 2)
tensor([[[2., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
[[0., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]])
沿着第0轴,所以第二个参数的值就是第0轴的坐标,比如第二个参数的第一个值,表示[0,0,0], 第二个是【1, 0, 1】等等
torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]])的shape是[1, 3,5]这个数组的第0轴,表示的是我们可以多赋值,就是也可以他的shape也可以是[2, 3,5], ,[3, 3,5]等,可以理解为是为了方便赋值更多,所以这个数组的第0个维度不做对应索引元素的值,就像上面数组维度为2一样,
所以,给定任意的scatter_函数,都是可以这样理解的。
上一篇: Android蓝牙BLE入门
下一篇: 数据库三大范式通俗理解