欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch基础到提高(2)-Tensor(2)

程序员文章站 2022-06-12 10:05:59
...

torch.Tensor 是torch.FloatTensor的别名。
tensor可用Python list或sequence 使用torch.tensor()进行构造。

import torch
x=torch.tensor([[10,20],[2,4]])
print(x)
tensor([[10, 20],
        [ 2,  4]])
import torch
a=[[10.2,20.6],[2,4]]
x=torch.DoubleTensor(a)
print(x)
y=torch.IntTensor(a)
print(y)
tensor([[10.2000, 20.6000],
        [ 2.0000,  4.0000]], dtype=torch.float64)
tensor([[10, 20],
        [ 2,  4]], dtype=torch.int32)
<ipython-input-14-f9db31b67daa>:5: DeprecationWarning: an integer is required (got type float).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
  y=torch.IntTensor(a)

torch.tensor()始终复制数据。如果您有一个 Tensor数据,并且只想更改它的requires_grad标志,可使用requires_grad_()或 detach() 防止拷贝。
如果您有一个numpy数组并希望避免复制 ,可使用torch.as_tensor()。

import torch
import numpy as np 
a = np.arange(8)
b = a.reshape(4,2)
print (b)
y=torch.torch.as_tensor(b)
print(y)
y[1][1]=55
print(y)
print(b)
[[0 1]
 [2 3]
 [4 5]
 [6 7]]
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
tensor([[ 0,  1],
        [ 2, 55],
        [ 4,  5],
        [ 6,  7]])
[[ 0  1]
 [ 2 55]
 [ 4  5]
 [ 6  7]]
import torch
y=torch.zeros([2, 4], dtype=torch.int32)
print(y)
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.int32)