Pytorch中的permute函数和transpose,contiguous,view函数的关联
一、前言
在进行深度学习的过程中,经常遇到permute
函数,transpose
函数,view
函数,contiguous
函数等,他们起什么作用,之间又有什么联系呢?
二、主要内容
2.1、permute
函数和transpose
函数
Tensor.permute(a,b,c,d, …):可以对任意高维矩阵进行转置。例子见下:
In[1]: torch.randn(2,3,4,5).permute(3,2,0,1).shape
Out[1]:torch.Size([5, 4, 2, 3])
torch.transpose(Tensor, a,b):只能操作2D矩阵的转置,这是相比于permute
的一个不同点;此外,由格式我们可以看出,transpose
函数比permute
函数多了种调用方式,即torch.transpose(Tensor, a,b)。但是,transpose
函数可以通过多次变换达到permute
函数的效果。具体见下:
#两种调用方式:
In[1]: t1 = torch.randint(1,10,(2,3,4,5))
shape1 = torch.transpose(t1,1,0).shape
shape2 = t1.transpose(1,0).shape
shape1,shape2
Out[1]:(torch.Size([3, 2, 4, 5]), torch.Size([3, 2, 4, 5]))
#类似permute的效果
In[2]: shape3 = t1.transpose(3,0).transpose(2,1).transpose(3,2).shape
shape3
Out[2]:torch.Size([5, 4, 2, 3])
2.2 permute
函数和view
函数
两个函数都是改变tensor的维度,但是区别在于__,具体如下:
#初始化
In[1]: a = torch.randint(1,10,(1,2,3))
a_size = a.size()
a,a_size
Out[1]:(tensor([[[7, 4, 5],
[9, 5, 6]]]), torch.Size([1, 2, 3]))
#permute
In[2]:per = a.permute(2,0,1)
per_size = per.size()
per,per_size
Out[2]:(tensor([[[7, 9]],
[[4, 5]],
[[5, 6]]]), torch.Size([3, 1, 2]))
#view
In[3]: view = a.view(3,1,2)
view_size = view.size()
view,view_size
Out[3]:(tensor([[[7, 4]],
[[5, 9]],
[[5, 6]]]), torch.Size([3, 1, 2]))
#diff
#相信细心的小伙伴已经从两个的output看出来区别了
#具体原因就是在调用permute函数后,数据不再连续,即contiguous,可以继续看3.2
2.3、contiguous
函数
contiguous
函数起什么作用呢?
当我们在使用transpose
或者permute
函数之后,tensor数据将会变的不在连续,而此时,如果我们采用view
函数等需要tensor数据联系的函数时,将会抛出以下错误:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible
with input tensor's size and stride (at least one dimension spans
across two contiguous subspaces). Call .contiguous() before
.view(). at ..\aten\src\TH/generic/THTensor.cpp:203
如果这是使用了contiguous
函数,将会解决此错误。
transpose
、permute
操作虽然没有修改底层一维数组,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。view
方法约定了不修改数组本身,只是使用新的形状查看数据,仅在底层数组上使用指定的形状进行变形。示例如下:
#初始化
In[1]:t = torch.arange(12).reshape(3,4)
t,t.stride()
Out[1]:tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]) , (4, 1)
#transpose
In[2]:t2 = t.transpose(0,1)
t2,t2.stride()
Out[2]:tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]]) , (1, 4)
#对比验证
In[3]:t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
Out[3]:True
In[4]:t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
Out[4]:(True, False)
#即t和t2引用同一份底层数据,如下:
#[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
#contiguous
In[5]:t3 = t2.contiguous()
t3
Out[5]:tensor([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]])
In[6]:t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组
Out[6]:False
#可以发现 t与t2 底层数据指针一致,t3 与 t2 底层数据指针不一致,说明确实重新开辟了内存空间。
三、结尾
上述只是简单介绍了下其功能和异同,具体原理没有深挖,对于想进一步了解contiguous
函数的可以移步下方的参考。
上一篇: pytorch view()、transpose()和permute()的区别
下一篇: 【PyTorch】contiguous==>保证Tensor是连续的,通常transpose、permute 操作后执行 view需要此方法
推荐阅读
-
对numpy中的transpose和swapaxes函数详解
-
pytorch view()、transpose()和permute()的区别
-
Pytorch中的permute函数和transpose,contiguous,view函数的关联
-
【PyTorch】contiguous==>保证Tensor是连续的,通常transpose、permute 操作后执行 view需要此方法
-
捋清pytorch的transpose、permute、view、reshape、contiguous
-
对numpy中的transpose和swapaxes函数详解
-
php用于反转/交换数组中的键名和对应关联的键值的函数array_flip()
-
php用于反转/交换数组中的键名和对应关联的键值的函数array_flip()