欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch中.view()与.reshape()方法的对比(还有.resize_()方法的一些说明)

程序员文章站 2022-03-21 19:46:54
...

torch.Tensor.reshape() VS torch.Tensor.view()

  • 相同点:从功能上来看,它们的作用是相同的,都是将原张量元素(按顺序)重组为新的shape。
  • 区别在于:
    • .view()方法只能改变连续的(contiguous)张量,否则需要先调用.contiguous()方法,而.reshape()方法不受此限制;
    • .view()方法返回的张量与原张量共享基础数据(存储器,注意不是共享内存地址,详见代码 ),而.reshape()方法返回的是原张量的copy还是view(即是否跟原张量共享存储),事先是不知道的,如果可以返回view,那么.reshape()方法返回的就是元张量的view,否则返回的就是copy。

–> 因此,为避免语义冲突:

  1. 如果需要原张量的拷贝(copy),就使用.clone()方法;
  2. 而如果需要原张量的视图(view),就使用.view()方法;
  3. 如果想要原张量的视图(view),但是原张量不连续(contiguous),不过原张量拥有兼容的步长(strides),此时可以考虑使用.reshape()函数。
a = torch.randint(0, 10, (3, 4))
"""
Out:
tensor([[3, 7, 1, 3],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
"""

b = a.view(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c = a.reshape(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

print(id(a)==id(b), id(a)==id(c), id(b)==id(c))
"""
Out:
False False False
"""

a[0]=0
print(a, b, c)
"""
Out:
tensor([[0, 0, 0, 0],
        [6, 4, 1, 3],
        [8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
        [1, 3, 8, 8, 5, 7]])
"""

c[0]=1
print(a, b, c)
"""
Out:
tensor([[1, 1, 1, 1],
        [1, 1, 1, 3],
        [8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
        [1, 3, 8, 8, 5, 7]])
"""

----PS----

**torch.Tensor.resize_()**方法的功能跟.reshape()/.view()方法的功能一样,也是将原张量元素(按顺序)重组为新的shape。

当resize前后的shape兼容时,返回原张量的视图(view);当目标大小(resize后的总元素数)大于当前大小(resize前的总元素数)时,基础存储器的大小将改变(即增大),以适应新的元素数,任何新的内存(新元素值)都是未初始化的;当目标大小(resize后的总元素数)小于当前大小(resize前的总元素数)时,基础存储器的大小保持不变,返回目标大小的元素重组后的张量,未使用的元素仍然保存在存储器中,如果再次resize回原来的大小,这些元素将会被重新使用。

(这里说的shape兼容的意思是:resize前后的shape包含的总元素数是一致的,即resize前后的shape的所有维度的乘积是相同的。如resize前,shape为(1, 2 ,3),那resize之后的张量的总元素数需要是1*2*3,故目标shape可以是(2, 3), 可以是(3, 2, 1),可以是(2, 1, 3)等尺寸。)

–> 文字说明有点干燥,看点例子感受一下:

a = torch.arange(24).view(4, 6)
"""
Out:
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])
"""

a.resize_(6, 4)
"""
Out:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])
"""

a.resize_(3, 3)
"""
Out:
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
"""

a.resize_(7, 4)
"""
Out:
tensor([[              0,               1,               2,               3],
        [              4,               5,               6,               7],
        [              8,               9,              10,              11],
        [             12,              13,              14,              15],
        [             16,              17,              18,              19],
        [             20,              21,              22,              23],
        [140720147688480, 140720141167152,               1,               0]])
"""

ps(官方解释,不是很能理解): 这是一个底层方法。存储被重新解释为c连续的,忽略当前的步长(除非目标大小等于当前大小,在这种情况下张量保持不变)

更多时候应该使用.view() / .reshape() / .set_()方法来替代此方法


参考文献:
What’s the difference between reshape and view in pytorch?