欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch中的scatter_函数

程序员文章站 2024-03-24 23:27:34
...

(1). scatter_函数详细描述如下:

scatter_(input,dim,index,value) 
将value对应的值按照index确定的索引写入input张量中,其中索引是根据给定的dim(维度)来确定的。
"""
Args:
input:要进行scatter_填充的tensor
dim:在input张量进行scatter_填充的维度
index:input对应dim的填充索引,要小于对应填充维度的长度,且index维度要与input张量维度一致
value:填充值
"""

(2). 代码实现

import torch
label = torch.zeros(2, 4)
print("label:",label)
label.scatter_(dim=1,index=torch.LongTensor([[2],[3]]),value=1)
print("new_label: ",label)

显示结果:

label: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])
new_label:  tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.]])
相关标签: scatter_函数