欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch改变Tensor维度

程序员文章站 2022-06-11 22:09:19
...

在pytorch中,有比较多的函数可以对Tensor的维度进行改变,下面笔者就简单列出一些。

1.torch.squeeze()/torch.unsqueeze()
squeeze函数是对张量的维度进行压缩,去掉维数为1的维度;相反unsqueeze函数是对张量进行维度扩张。

多说无益,赶紧上马:

import torch

#这里随机产生一个多维张量shape = ([3, 1, 2, 1, 4, 1])
a = torch.randn(3,1,2,1,4,1)
b = a.squeeze()
print(b.size())            # output:torch.Size([3, 2, 4])

#对特定的维度进行压缩
c = a.squeeze(1)
print(c.size())          # output: torch.Size([3, 2, 1, 4, 1])

#如果要对两个维度进行压缩,则一定要一个一个进行压缩,不能一次同时压缩
#比如对a的1、3维度压缩,如下。这里不能a.squeeze(1, 3),程序会报错!
d = a.squeeze(1).squeeze(2)  #是在c的第2维度上进行的压缩
print(d.size())         #output:torch.Size([3, 2, 4, 1])


#同样,unsqueeze也类似
# b :torch.Size([3, 2, 4])
e = b.unsqueeze(1)
print(e.size())            #output:torch.Size([3, 1, 2, 4])

f = b.unsqueeze(1).unsqueeze(dim = 3)
print(f.size())            # output : torch.Size([3, 1, 2, 1, 4])

2.torch.view()

把原先Tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的Tensor。

多说无益,赶紧上马:

import torch
a = torch.randn(2,3,4)
print(a)   # output: tensor([[[-0.8822,  0.6797,  0.5335,  0.1103],
                              [ 0.6791, -2.1690, -0.6625, -0.2989],
                              [-1.3134, -1.1234, -0.7303,  2.1314]],

                             [[-0.8697,  0.8352,  0.9058, -1.2924],
                              [ 1.3043, -0.8773,  0.5054,  0.4219],
                              [-1.0243, -2.5556, -0.6324, -1.6356]]])
b = a.view(1, 24)
print(b)    # output: tensor([[-0.8822,  0.6797,  0.5335,  0.1103,  0.6791, -2.1690, -0.6625, -0.2989,
                          #   -1.3134, -1.1234, -0.7303,  2.1314, -0.8697,  0.8352,  0.9058, -1.2924,
                         #   1.3043, -0.8773,  0.5054,  0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(b.size())    # output : torch.Size([1, 24])

c = a.view(3, 8)  # output : tensor([[-0.8822,  0.6797,  0.5335,  0.1103,  0.6791, -2.1690, -0.6625, -0.2989],
                                 #   [-1.3134, -1.1234, -0.7303,  2.1314, -0.8697,  0.8352,  0.9058, -1.2924],
                                 # [ 1.3043, -0.8773,  0.5054,  0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(c.size())   # output: torch.Size([3, 8])

d = a.view(3, 2, 4)
print(d.size())   # output : torch.Size([3, 2, 4])

e = a.view(3, 2, 2, 2) # output: tensor([[[[-0.8822,  0.6797],
                                           [ 0.5335,  0.1103]],

                                          [[ 0.6791, -2.1690],
                                           [-0.6625, -0.2989]]],


                                         [[[-1.3134, -1.1234],
                                           [-0.7303,  2.1314]],

                                          [[-0.8697,  0.8352],
                                           [ 0.9058, -1.2924]]],


                                         [[[ 1.3043, -0.8773],
                                           [ 0.5054,  0.4219]],

                                          [[-1.0243, -2.5556],
                                           [-0.6324, -1.6356]]]])
print(e.size())  # output: torch.Size([3, 2, 2, 2])

3.permute
将Tensor维度进行换位。

import torch
a = torch.randn(2,3,4)
print(a)   # output: tensor([[[-0.8822,  0.6797,  0.5335,  0.1103],
                              [ 0.6791, -2.1690, -0.6625, -0.2989],
                              [-1.3134, -1.1234, -0.7303,  2.1314]],

                             [[-0.8697,  0.8352,  0.9058, -1.2924],
                              [ 1.3043, -0.8773,  0.5054,  0.4219],
                              [-1.0243, -2.5556, -0.6324, -1.6356]]])
>>> b = a.permute(0,2,1)
>>> print(b)
tensor([[[-0.8822,  0.6791, -1.3134],
         [ 0.6797, -2.1690, -1.1234],
         [ 0.5335, -0.6625, -0.7303],
         [ 0.1103, -0.2989,  2.1314]],

        [[-0.8697,  1.3043, -1.0243],
         [ 0.8352, -0.8773, -2.5556],
         [ 0.9058,  0.5054, -0.6324],
         [-1.2924,  0.4219, -1.6356]]])
>>> print(b.size())
torch.Size([2, 4, 3])

4.stack/cat

torch.stack(sequence, dim=0, out=None),
torch.cat(sequence, dim=0, out=None),

sequence表示Tensor列表,dim表示拼接的维度,stack是建立一个新的维度,然后再在该纬度上进行拼接;而cat是在已有的维度上拼接。

不理解?直接上马:

>>> import torch
>>> t1 = torch.tensor([1,1,1])
>>> t2 = torch.tensor([2,2,2])
>>> t3 = torch.tensor([3,3,3])
>>> torch.cat((t1,t2,t3),dim=0)
tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])

>>> torch.stack((t1,t2,t3), dim=0)
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

>>> torch.cat((t1.unsqueeze(0), t2.unsqueeze(0),t3.unsqueeze(0)),dim=0)
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

>>> torch.stack((t1,t2,t3),dim=1)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

>>> torch.cat((t1.unsqueeze(1), t2.unsqueeze(1), t3.unsqueeze(1)), dim=1)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

通过上面的示例可以看出,先使用unsqueeze对Tensor进行维度扩张,然后再cat便可以得到与stack一样的结果。

相关标签: pytorch python