欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch里面的scatter_

程序员文章站 2024-03-24 23:31:46
...
x=torch.zeros(3,3)
x.scatter_(1,torch.tensor([0,1,2],dtype=torch.int64).view(-1,1),34)
print(x)

输出
tensor([[34.,  0.,  0.],
        [ 0., 34.,  0.],
        [ 0.,  0., 34.]])

功能:在指定维度的指定位置填充整数
x.scatter_(dim,tensor,int)
dim填充的维度
tensor: int64类型,与x同维度
int:填充的数字
再来一个示例

x=torch.zeros(3,3)
x.scatter_(1,torch.tensor([0,1,2,0,1,2],dtype=torch.int64).view(3,-1),34)
print(x)

输出
tensor([[34., 34.,  0.],
        [34.,  0., 34.],
        [ 0., 34., 34.]])
``
还有这个

x=torch.zeros(5,5)
x.scatter_(1,torch.tensor([0,1,2,0,1,2],dtype=torch.int64).view(3,-1),34)
print(torch.tensor([0,1,2,0,1,2],dtype=torch.int64).view(3,-1))
print(x)

输出:
tensor([[0, 1],
[2, 0],
[1, 2]])
tensor([[34., 34., 0., 0., 0.],
[34., 0., 34., 0., 0.],
[ 0., 34., 34., 0., 0.],
[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.]])

记忆小技巧:填1维看0维,填0维看1维,毕竟01还是要搭配着来比较好

相关标签: 机器学习