pytorch中scatter()和scatter_()的作用和区别
程序员文章站
2024-03-24 23:36:10
...
scatter和scatter_函数原型如下
Tensor.scatter_(dim, index, src, reduce=None)->Tensor
scatter(input, dim, index, src)->Tensor
函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。
- dim(int) - 操作的维度
- index(LongTensor) - 填充依据的索引,
- src(Tensor of float) - 操作的src数据
- reduce(str, optional) - reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换
dim(int)
在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。
对于一个三维张量,self更新结果如下
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
使用示例
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
scatter的应用, one-hot编码
def one_hot(x, n_class, dtype=torch.float32):
# X shape: (batch), output shape: (batch, n_class)
x=x.long()
res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量
res.scatter_(1, x.view(-1,1), 1)
# scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中
return res
x=torch.tensor([5,7,0])
one_hot(x, 10)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
上一篇: Android BLE蓝牙4.0开发详解
下一篇: pytorch——LSTM
推荐阅读
-
pytorch中的 scatter_()函数使用和详解
-
Pytorch中的scatter_函数
-
pytorch中scatter()和scatter_()的作用和区别
-
PyTorch中scatter()和scatter_()函数的作用
-
mysql中engine=innodb和engine=myisam的区别 mysql问题
-
shell 脚本中while循环和for循环的区别
-
Java中replace和replaceAll的区别 博客分类: java java正则表达式
-
Java中replace和replaceAll的区别 博客分类: java java正则表达式
-
java中的引用类型和值类型的区别
-
c语言函数声明中,static inline和extern inline的区别