欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch中的 scatter_()函数使用和详解

程序员文章站 2024-03-24 23:35:52
...

scatter(dim, index, src)的三个参数为:

(1)dim:沿着哪个维度进行索引

(2)index: 用来scatter的元素索引

(3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量

官方给的例子为三维情况下的例子:

y = y.scatter(dim,index,src)

#则结果为:
y[ index[i][j][k]  ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ]  = src[i][j][k] # if dim == 2

如果是二维的例子,则应该对应下面的情况:

y = y.scatter(dim,index,src)

#则:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ]  = src[i][j] #if dim==1 

我们举一个实际的例子:

import torch

x = torch.randn(2,4)
print(x)
y = torch.zeros(3,4)
y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)
print(y)


#结果为:
tensor([[-0.9669, -0.4518,  1.7987,  0.1546],
        [-0.1122, -0.7998,  0.6075,  1.0192]])
tensor([[-0.1122,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.4518,  0.6075,  1.0192],
        [-0.9669, -0.7998,  1.7987,  0.1546]])


'''
scatter后:
y[ index[0][0] ] [0] = src[0][0] -> y[2][0]=-0.9669

y[ index[1][3] ] [3] = src[1][3] -> y[1][3]=1.10192

'''

#如果src为标量,则代表着将对应位置的数值改为src这个标量

那么这个函数有什么作用呢?其实可以利用这个功能将pytorch 中mini batch中的返回的label(特指[ 1,0,4,9 ],即size为[4]这样的label)转为one-hot类型的label,举例子如下:

import torch

mini_batch = 4
out_planes = 6
out_put = torch.rand(mini_batch, out_planes)
softmax = torch.nn.Softmax(dim=1)
out_put = softmax(out_put)

print(out_put)
label = torch.tensor([1,3,3,5])
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
print(one_hot_label)

上述的这个例子假设是一个分类问题,我设置out_planes=6,是假设总共有6类,mini_batch是我们送入的网络的每个mini_batch的样本数量,这里我们不设置网络,直接假设网络的输出为一个随机的张量 ,通常我们要对这个输出进行softmax归一化,此时就代表着其属于每个类别的概率了。说到这里都不是重点,就是为了方便理解如何使用scatter,将size为[mini_batch]的张量,转为size为[mini_batch, out_palnes]的张量,并且这个生成的张量的每个行向量都是one-hot类型的了。通过看下面的输出结果就完全能够理解了,不理解,给我留言,我给你解释清楚。

tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
        [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
        [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
        [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
tensor([1, 3, 3, 5])
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])

 

相关标签: pytorch

上一篇: LSTM

下一篇: pytorch nn.LSTM()参数详解