欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【PyTorch】contiguous==>保证Tensor是连续的,通常transpose、permute 操作后执行 view需要此方法

程序员文章站 2022-06-13 15:19:22
...

目录

PyTorch中的is_contiguous是什么含义?

行优先

为什么需要 contiguous ?


contiguous 本身是形容词表示连续的关于 contiguous,

PyTorch 提供了

  • is_contiguous:用于判定Tensor是否是 contiguous 的,
  • contiguous(形容词动用)两个方法 :保证Tensor是contiguous的。

PyTorch中的is_contiguous是什么含义?

is_contiguous直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致

Tensor多维数组底层实现是使用一块连续内存的1维数组(行优先顺序存储,下文描述),Tensor在元信息里保存了多维数组的形状,在访问元素时,通过多维度索引转化成1维数组相对于数组起始位置的偏移量即可找到对应的数据。

如果想要变得连续使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的,如果Tensor是连续的,则contiguous无操作。

行优先

行是指多维数组一维展开的方式,对应的是列优先。PyTorch中Tensor底层实现是C,也是使用行优先顺序。举例说明如下:

>>> t = torch.arange(12).reshape(3,4)
>>> t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

二维数组 t 如图1:

【PyTorch】contiguous==>保证Tensor是连续的,通常transpose、permute 操作后执行 view需要此方法

数组 t 在内存中实际以一维数组形式存储,通过 flatten 方法查看 t 的一维展开形式,实际存储形式与一维展开一致,如图2,

>>> t.flatten()
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

【PyTorch】contiguous==>保证Tensor是连续的,通常transpose、permute 操作后执行 view需要此方法

颜色相同的数据表示在同一行,不论是行优先顺序、或是列优先顺序,如果要访问矩阵中的下一个元素都是通过偏移来实现,这个偏移量称为步长(stride)。在行优先的存储方式下,访问行中相邻元素物理结构需要偏移1个位置,在列优先存储方式下偏移3个位置。

为什么需要 contiguous ?

 torch.view等方法操作需要连续的Tensor。

transpose、permute 操作虽然没有修改底层一维数组,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。torch.view 方法约定了不修改数组本身,只是使用新的形状查看数据。如果我们在 transpose、permute 操作后执行 view,Pytorch 会抛出以下错误:

invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension 
spans across two contiguous subspaces). Call .contiguous() before .view(). 
at /Users/soumith/b101_2/2019_02_08/wheel_build_dirs/wheel_3.6/pytorch/aten/src/TH/generic/THTensor.cpp:213

即需要在transpose()、permute() 操作后先执行contiguous()让后再执行 view(),如下:

windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)

原文:https://zhuanlan.zhihu.com/p/64551412

相关标签: 机器学习基础