Expected object of scalar type Long but got scalar type Double for argument #2 'target'
程序员文章站
2022-06-13 08:00:18
...
1.pytorch报错:
loss_class = torch.nn.CrossEntropyLoss()
s_data, s_label = data_source[0].to(DEVICE), data_source[1].to(DEVICE)
class_output, domain_output = model(input_data=s_data.float(), alpha=alpha)
# 报错位置如下:
err_s_label = loss_class(class_output, s_label)
报错内容如下:
Expected object of scalar type Long but got scalar type Double for argument #2 'target'
表示第二个位置的参数要求是Long类型,然而传入的时候是Double类型,因此我们只需:
s_label.long()
即可。
2. 如果会继续出现报错:
RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15
表示在计算loss过程中遇到了多输出预测值。或者标签的维度是不同的(我这边标签的shape是(128, 1)),我们只需要将标签squeeze就行,具体参考torch.squeeze()函数,个人变动方法:
err_s_label = loss_class(class_output, s_label.squeeze(1).long())
3.总结:torch中数据类型的变化:
数据类型 | 数据长度 | 用法 |
int | int32 | torch.int() |
long | int64 | torch.long() |
float | float32 | torch.float() |
double | float64 | torch.double() |
至于需要哪种变化,各位看官,请便。。。
上一篇: 蓝玉死前念了首诗,在场官员为何全被斩杀?
下一篇: Expected object of scalar type Float but got scalar type Long for argument #2 'target'
推荐阅读
-
报错:Expected object of scalar type Float but got scalar type Long for argument #2 'target'
-
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got CUDAType
-
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'
-
RuntimeError: Expected tensor for argument #1 ‘indices‘ to have scalar type Long; but got CUDAFloatT
-
RuntimeError: Expected tensor for argument #1 ‘indices‘ to have scalar type Long; but got torch.IntT
-
Expected tensor for argument #1 ‘indices‘ to have scalar type Long; but got torch.cuda.FloatTensor i
-
Expected object of scalar type Long but got scalar type Int for argument #2 'target'
-
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'
-
Expected object of scalar type Long but got scalar type Double for argument #2 'target'
-
Expected object of scalar type Float but got scalar type Long for argument #2 'target'