Pytorch改变Tensor维度
程序员文章站
2022-06-11 22:09:19
...
在pytorch中,有比较多的函数可以对Tensor的维度进行改变,下面笔者就简单列出一些。
1.torch.squeeze()/torch.unsqueeze()
squeeze函数是对张量的维度进行压缩,去掉维数为1的维度;相反unsqueeze函数是对张量进行维度扩张。
多说无益,赶紧上马:
import torch
#这里随机产生一个多维张量shape = ([3, 1, 2, 1, 4, 1])
a = torch.randn(3,1,2,1,4,1)
b = a.squeeze()
print(b.size()) # output:torch.Size([3, 2, 4])
#对特定的维度进行压缩
c = a.squeeze(1)
print(c.size()) # output: torch.Size([3, 2, 1, 4, 1])
#如果要对两个维度进行压缩,则一定要一个一个进行压缩,不能一次同时压缩
#比如对a的1、3维度压缩,如下。这里不能a.squeeze(1, 3),程序会报错!
d = a.squeeze(1).squeeze(2) #是在c的第2维度上进行的压缩
print(d.size()) #output:torch.Size([3, 2, 4, 1])
#同样,unsqueeze也类似
# b :torch.Size([3, 2, 4])
e = b.unsqueeze(1)
print(e.size()) #output:torch.Size([3, 1, 2, 4])
f = b.unsqueeze(1).unsqueeze(dim = 3)
print(f.size()) # output : torch.Size([3, 1, 2, 1, 4])
2.torch.view()
把原先Tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的Tensor。
多说无益,赶紧上马:
import torch
a = torch.randn(2,3,4)
print(a) # output: tensor([[[-0.8822, 0.6797, 0.5335, 0.1103],
[ 0.6791, -2.1690, -0.6625, -0.2989],
[-1.3134, -1.1234, -0.7303, 2.1314]],
[[-0.8697, 0.8352, 0.9058, -1.2924],
[ 1.3043, -0.8773, 0.5054, 0.4219],
[-1.0243, -2.5556, -0.6324, -1.6356]]])
b = a.view(1, 24)
print(b) # output: tensor([[-0.8822, 0.6797, 0.5335, 0.1103, 0.6791, -2.1690, -0.6625, -0.2989,
# -1.3134, -1.1234, -0.7303, 2.1314, -0.8697, 0.8352, 0.9058, -1.2924,
# 1.3043, -0.8773, 0.5054, 0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(b.size()) # output : torch.Size([1, 24])
c = a.view(3, 8) # output : tensor([[-0.8822, 0.6797, 0.5335, 0.1103, 0.6791, -2.1690, -0.6625, -0.2989],
# [-1.3134, -1.1234, -0.7303, 2.1314, -0.8697, 0.8352, 0.9058, -1.2924],
# [ 1.3043, -0.8773, 0.5054, 0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(c.size()) # output: torch.Size([3, 8])
d = a.view(3, 2, 4)
print(d.size()) # output : torch.Size([3, 2, 4])
e = a.view(3, 2, 2, 2) # output: tensor([[[[-0.8822, 0.6797],
[ 0.5335, 0.1103]],
[[ 0.6791, -2.1690],
[-0.6625, -0.2989]]],
[[[-1.3134, -1.1234],
[-0.7303, 2.1314]],
[[-0.8697, 0.8352],
[ 0.9058, -1.2924]]],
[[[ 1.3043, -0.8773],
[ 0.5054, 0.4219]],
[[-1.0243, -2.5556],
[-0.6324, -1.6356]]]])
print(e.size()) # output: torch.Size([3, 2, 2, 2])
3.permute
将Tensor维度进行换位。
import torch
a = torch.randn(2,3,4)
print(a) # output: tensor([[[-0.8822, 0.6797, 0.5335, 0.1103],
[ 0.6791, -2.1690, -0.6625, -0.2989],
[-1.3134, -1.1234, -0.7303, 2.1314]],
[[-0.8697, 0.8352, 0.9058, -1.2924],
[ 1.3043, -0.8773, 0.5054, 0.4219],
[-1.0243, -2.5556, -0.6324, -1.6356]]])
>>> b = a.permute(0,2,1)
>>> print(b)
tensor([[[-0.8822, 0.6791, -1.3134],
[ 0.6797, -2.1690, -1.1234],
[ 0.5335, -0.6625, -0.7303],
[ 0.1103, -0.2989, 2.1314]],
[[-0.8697, 1.3043, -1.0243],
[ 0.8352, -0.8773, -2.5556],
[ 0.9058, 0.5054, -0.6324],
[-1.2924, 0.4219, -1.6356]]])
>>> print(b.size())
torch.Size([2, 4, 3])
4.stack/cat
torch.stack(sequence, dim=0, out=None),
torch.cat(sequence, dim=0, out=None),
sequence表示Tensor列表,dim表示拼接的维度,stack是建立一个新的维度,然后再在该纬度上进行拼接;而cat是在已有的维度上拼接。
不理解?直接上马:
>>> import torch
>>> t1 = torch.tensor([1,1,1])
>>> t2 = torch.tensor([2,2,2])
>>> t3 = torch.tensor([3,3,3])
>>> torch.cat((t1,t2,t3),dim=0)
tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])
>>> torch.stack((t1,t2,t3), dim=0)
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
>>> torch.cat((t1.unsqueeze(0), t2.unsqueeze(0),t3.unsqueeze(0)),dim=0)
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
>>> torch.stack((t1,t2,t3),dim=1)
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
>>> torch.cat((t1.unsqueeze(1), t2.unsqueeze(1), t3.unsqueeze(1)), dim=1)
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
通过上面的示例可以看出,先使用unsqueeze对Tensor进行维度扩张,然后再cat便可以得到与stack一样的结果。