欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch中的permute函数和transpose,contiguous,view函数的关联

程序员文章站 2022-06-13 15:18:58
...

一、前言

在进行深度学习的过程中,经常遇到permute函数,transpose函数,view函数,contiguous函数等,他们起什么作用,之间又有什么联系呢?

二、主要内容

2.1、permute函数和transpose函数

Tensor.permute(a,b,c,d, …):可以对任意高维矩阵进行转置。例子见下:

In[1]: torch.randn(2,3,4,5).permute(3,2,0,1).shape

Out[1]:torch.Size([5, 4, 2, 3])

torch.transpose(Tensor, a,b):只能操作2D矩阵的转置,这是相比于permute的一个不同点;此外,由格式我们可以看出,transpose函数比permute函数多了种调用方式,即torch.transpose(Tensor, a,b)。但是,transpose函数可以通过多次变换达到permute函数的效果。具体见下:

#两种调用方式:
In[1]: t1 = torch.randint(1,10,(2,3,4,5))
	   shape1 = torch.transpose(t1,1,0).shape
	   shape2 = t1.transpose(1,0).shape
	   shape1,shape2
	   
Out[1]:(torch.Size([3, 2, 4, 5]), torch.Size([3, 2, 4, 5]))

#类似permute的效果
In[2]: shape3 = t1.transpose(3,0).transpose(2,1).transpose(3,2).shape
	   shape3
	   
Out[2]:torch.Size([5, 4, 2, 3])	   

2.2 permute函数和view函数

两个函数都是改变tensor的维度,但是区别在于__,具体如下:

#初始化
In[1]: a = torch.randint(1,10,(1,2,3))
 		a_size = a.size()
 		a,a_size
 		
Out[1]:(tensor([[[7, 4, 5],
         		 [9, 5, 6]]]), torch.Size([1, 2, 3]))
         		 
#permute
In[2]:per = a.permute(2,0,1)
	  per_size = per.size()
	   per,per_size
	   
Out[2]:(tensor([[[7, 9]],

        		[[4, 5]],

       			[[5, 6]]]), torch.Size([3, 1, 2]))

#view
In[3]: view = a.view(3,1,2)
	   view_size = view.size()
	   view,view_size

Out[3]:(tensor([[[7, 4]],

        		[[5, 9]],

        		[[5, 6]]]), torch.Size([3, 1, 2]))

#diff
#相信细心的小伙伴已经从两个的output看出来区别了
#具体原因就是在调用permute函数后,数据不再连续,即contiguous,可以继续看3.2

2.3、contiguous函数

contiguous函数起什么作用呢?
当我们在使用transpose或者permute函数之后,tensor数据将会变的不在连续,而此时,如果我们采用view函数等需要tensor数据联系的函数时,将会抛出以下错误:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible 
with input tensor's size and stride (at least one dimension spans 
across two contiguous subspaces). Call .contiguous() before 
.view(). at ..\aten\src\TH/generic/THTensor.cpp:203

如果这是使用了contiguous函数,将会解决此错误。

transposepermute操作虽然没有修改底层一维数组,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。view 方法约定了不修改数组本身,只是使用新的形状查看数据,仅在底层数组上使用指定的形状进行变形。示例如下:

#初始化
In[1]:t = torch.arange(12).reshape(3,4)
	  t,t.stride()

Out[1]:tensor([[ 0,  1,  2,  3],
        	   [ 4,  5,  6,  7],
        	   [ 8,  9, 10, 11]]) ,  (4, 1)
#transpose
In[2]:t2 = t.transpose(0,1)
      t2,t2.stride()
  
Out[2]:tensor([[ 0,  4,  8],
        	   [ 1,  5,  9],
        	   [ 2,  6, 10],
        	   [ 3,  7, 11]]) ,  (1, 4)
#对比验证
In[3]:t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
Out[3]:True

In[4]:t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
Out[4]:(True, False)

#即t和t2引用同一份底层数据,如下:
#[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]


#contiguous
In[5]:t3 = t2.contiguous()
      t3

Out[5]:tensor([[ 0,  4,  8],
        	   [ 1,  5,  9],
           	   [ 2,  6, 10],
        	   [ 3,  7, 11]])
        	   
In[6]:t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组

Out[6]:False


#可以发现 t与t2 底层数据指针一致,t3 与 t2 底层数据指针不一致,说明确实重新开辟了内存空间。

三、结尾

上述只是简单介绍了下其功能和异同,具体原理没有深挖,对于想进一步了解contiguous函数的可以移步下方的参考。

ps:
参考—PyTorch中的contiguous