Pytorch:问题记录
程序员文章站
2022-03-27 10:21:40
...
1:Pytorch和Numpy中默认数据类型的区别
来源:https://blog.csdn.net/yyb19951015/article/details/84781434
问题:
RuntimeError: Expected object of type torch.cuda.DoubleTensor but found type torch.cuda.FloatTensor for argument #3 'other'
这种错误是由于数据类型不匹配造成的。这种不匹配可能来自Pytorch各个层之间,
也可能来自于使用Dataset和Dataloader导入来自Numpy的数据。后者有时更难以发现。
原因:
在Numpy中,小数的默认数据类型是np.float,但np.float与np.float64等价;
在Pytorch中,默认数据类型是torch.float,但float与torch.float32等价。
如果不加转换地使用torch.from_numpy,numpy中的数组将会被转换成pytorch中的torch.double类型。
而网络的其他部分如果有torch.float32类型,这就造成了数据类型的不匹配。
解决:
将Numpy的输入数据类型改为np.float32类型即可。
在输入数据加后.astype(np.float32)就可保证两边的数据类型统一。
2:python int转换为numpy.int64
来源:https://www.cnpython.com/qa/38838
分析:
Python的数据类型和numpy中的默认数据类型长度是不同的
解决:
3:Pytorch中Tensor和tensor的区别
来源:https://blog.csdn.net/tfcy694/article/details/85338745
Tensor
torch.Tensor()是python类,更明确地说,是默认张量类型torch.FloatTensor()的别名,
torch.Tensor([1,2])会调用Tensor类的构造函数__init__,生成单精度浮点类型的张量。
>>> a=torch.Tensor([1,2])
>>> a.type()
'torch.FloatTensor'
tensor
torch.tensor()仅仅是python函数:https://pytorch.org/docs/stable/torch.html#torch.tensor ,函数原型是:
torch.tensor(data, dtype=None, device=None, requires_grad=False)
其中data可以是:list, tuple, NumPy ndarray, scalar和其他类型。
torch.tensor会从data中的数据部分做拷贝(而不是直接引用),
根据原始数据类型生成相应的torch.LongTensor、torch.FloatTensor和torch.DoubleTensor。
>>> a=torch.tensor([1,2])
>>> a.type()
'torch.LongTensor'
>>> a=torch.tensor([1.,2.])
>>> a.type()
'torch.FloatTensor'
>>> a=np.zeros(2,dtype=np.float64)
>>> a=torch.tensor(a)
>>> a.type()
'torch.DoubleTensor'
这里再说一下torch.empty(),根据 https://pytorch.org/docs/stable/torch.html?highlight=empty#torch.empty ,我们可以生成指定类型、指定设备以及其他参数的张量,由于torch.Tensor()只能指定数据类型为torch.float,所以torch.Tensor()可以看做torch.empty()的一个特殊情况。
最后放一个小彩蛋
>>> a=torch.tensor(1)
>>> a
tensor(1)
>>> a.type()
'torch.LongTensor'
>>> a=torch.Tensor(1)
>>> a
tensor([0.])
>>> a.type()
'torch.FloatTensor'
我把a=torch.Tensor(1)改成a=torch.Tensor([1]),a就是一个数值为1的tensor了,而不是0
标量1是作为size传入的,向量1是作为value传入的
4:Pytorch中item和data的区别
来源:https://blog.csdn.net/weixin_38316806/article/details/104971419?
.data返回的是一个tensor
.item()返回的是一个具体的数值,注意只能是一个值,适合返回loss,acc
实验:
import torch
a = torch.ones([1,3])
print(a)
print(a.data)
print(a.data[0,1])
print(a.data[0,1].item())
# print(a.item()) 运行该行代码会报错
结果:
下一篇: SSM框架整合
推荐阅读
-
百度影音怎么删除播放痕迹?百度影音清空播放记录的方法图解
-
Mysql5.7中使用group concat函数数据被截断的问题完美解决方法
-
Windows 64 位 mysql 5.7以上版本包解压中没有data目录和my-default.ini及服务无法启动的快速解决办法(问题小结)
-
个人所得税app常见的五大问题及解决方法介绍
-
解决mysql ERROR 1045 (28000)-- Access denied for user问题
-
selenium处理元素定位点击无效问题
-
vs2015/vs2013中mvc5 viewbag总是出现问题该怎么办?
-
企业做SEO优化前需要考虑哪些问题?
-
企业官网SEO优化被忽略的问题 你可能想错了
-
解决Python plt.savefig 保存图片时一片空白的问题