pytorch中的 scatter_()函数使用和详解
程序员文章站
2024-03-24 23:35:52
...
scatter(dim, index, src)的三个参数为:
(1)dim:沿着哪个维度进行索引
(2)index: 用来scatter的元素索引
(3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量
官方给的例子为三维情况下的例子:
y = y.scatter(dim,index,src)
#则结果为:
y[ index[i][j][k] ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ] = src[i][j][k] # if dim == 2
如果是二维的例子,则应该对应下面的情况:
y = y.scatter(dim,index,src)
#则:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ] = src[i][j] #if dim==1
我们举一个实际的例子:
import torch
x = torch.randn(2,4)
print(x)
y = torch.zeros(3,4)
y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)
print(y)
#结果为:
tensor([[-0.9669, -0.4518, 1.7987, 0.1546],
[-0.1122, -0.7998, 0.6075, 1.0192]])
tensor([[-0.1122, 0.0000, 0.0000, 0.0000],
[ 0.0000, -0.4518, 0.6075, 1.0192],
[-0.9669, -0.7998, 1.7987, 0.1546]])
'''
scatter后:
y[ index[0][0] ] [0] = src[0][0] -> y[2][0]=-0.9669
y[ index[1][3] ] [3] = src[1][3] -> y[1][3]=1.10192
'''
#如果src为标量,则代表着将对应位置的数值改为src这个标量
那么这个函数有什么作用呢?其实可以利用这个功能将pytorch 中mini batch中的返回的label(特指[ 1,0,4,9 ],即size为[4]这样的label)转为one-hot类型的label,举例子如下:
import torch
mini_batch = 4
out_planes = 6
out_put = torch.rand(mini_batch, out_planes)
softmax = torch.nn.Softmax(dim=1)
out_put = softmax(out_put)
print(out_put)
label = torch.tensor([1,3,3,5])
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
print(one_hot_label)
上述的这个例子假设是一个分类问题,我设置out_planes=6,是假设总共有6类,mini_batch是我们送入的网络的每个mini_batch的样本数量,这里我们不设置网络,直接假设网络的输出为一个随机的张量 ,通常我们要对这个输出进行softmax归一化,此时就代表着其属于每个类别的概率了。说到这里都不是重点,就是为了方便理解如何使用scatter,将size为[mini_batch]的张量,转为size为[mini_batch, out_palnes]的张量,并且这个生成的张量的每个行向量都是one-hot类型的了。通过看下面的输出结果就完全能够理解了,不理解,给我留言,我给你解释清楚。
tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
[0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
[0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
[0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
tensor([1, 3, 3, 5])
tensor([[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 1.]])
上一篇: LSTM