Pytorch中的scatter_函数
程序员文章站
2024-03-24 23:27:34
...
(1). scatter_函数详细描述如下:
scatter_(input,dim,index,value)
将value对应的值按照index确定的索引写入input张量中,其中索引是根据给定的dim(维度)来确定的。
"""
Args:
input:要进行scatter_填充的tensor
dim:在input张量进行scatter_填充的维度
index:input对应dim的填充索引,要小于对应填充维度的长度,且index维度要与input张量维度一致
value:填充值
"""
(2). 代码实现
import torch
label = torch.zeros(2, 4)
print("label:",label)
label.scatter_(dim=1,index=torch.LongTensor([[2],[3]]),value=1)
print("new_label: ",label)
显示结果:
label: tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])
new_label: tensor([[0., 0., 1., 0.],
[0., 0., 0., 1.]])
推荐阅读
-
Pytorch中的scatter_函数
-
Pytorch---之scatter_ 理解轴的含义
-
【Pytorch】scatter_ 理解轴的含义
-
pytorch 中的torch.nn.LSTM函数
-
[概念]医学图像分割中常用的Loss function(损失函数) + 从loss处理图像分割中类别极度不均衡
-
自定义JSP中的Taglib标签之四自定义标签中的Function函数 博客分类: Java
-
Python中的bytearray()和bytes()函数
-
c语言函数声明中,static inline和extern inline的区别
-
关于MFC中AfxGetApp函数是怎么得到全局对象的指针的简要分析
-
ServerSocket构造函数中backlog参数的含义,可以接受客户端的数量