PyTorch中.view()与.reshape()方法的对比(还有.resize_()方法的一些说明)
torch.Tensor.reshape() VS torch.Tensor.view()
- 相同点:从功能上来看,它们的作用是相同的,都是将原张量元素(按顺序)重组为新的shape。
- 区别在于:
- .view()方法只能改变连续的(contiguous)张量,否则需要先调用.contiguous()方法,而.reshape()方法不受此限制;
- .view()方法返回的张量与原张量共享基础数据(存储器,注意不是共享内存地址,详见代码 ),而.reshape()方法返回的是原张量的copy还是view(即是否跟原张量共享存储),事先是不知道的,如果可以返回view,那么.reshape()方法返回的就是元张量的view,否则返回的就是copy。
–> 因此,为避免语义冲突:
- 如果需要原张量的拷贝(copy),就使用.clone()方法;
- 而如果需要原张量的视图(view),就使用.view()方法;
- 如果想要原张量的视图(view),但是原张量不连续(contiguous),不过原张量拥有兼容的步长(strides),此时可以考虑使用.reshape()函数。
a = torch.randint(0, 10, (3, 4))
"""
Out:
tensor([[3, 7, 1, 3],
[6, 4, 1, 3],
[8, 8, 5, 7]])
"""
b = a.view(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
[1, 3, 8, 8, 5, 7]])
"""
c = a.reshape(2, 6)
"""
Out:
tensor([[3, 7, 1, 3, 6, 4],
[1, 3, 8, 8, 5, 7]])
"""
print(id(a)==id(b), id(a)==id(c), id(b)==id(c))
"""
Out:
False False False
"""
a[0]=0
print(a, b, c)
"""
Out:
tensor([[0, 0, 0, 0],
[6, 4, 1, 3],
[8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
[1, 3, 8, 8, 5, 7]])
tensor([[0, 0, 0, 0, 6, 4],
[1, 3, 8, 8, 5, 7]])
"""
c[0]=1
print(a, b, c)
"""
Out:
tensor([[1, 1, 1, 1],
[1, 1, 1, 3],
[8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
[1, 3, 8, 8, 5, 7]])
tensor([[1, 1, 1, 1, 1, 1],
[1, 3, 8, 8, 5, 7]])
"""
----PS----
**torch.Tensor.resize_()**方法的功能跟.reshape()/.view()方法的功能一样,也是将原张量元素(按顺序)重组为新的shape。
当resize前后的shape兼容时,返回原张量的视图(view);当目标大小(resize后的总元素数)大于当前大小(resize前的总元素数)时,基础存储器的大小将改变(即增大),以适应新的元素数,任何新的内存(新元素值)都是未初始化的;当目标大小(resize后的总元素数)小于当前大小(resize前的总元素数)时,基础存储器的大小保持不变,返回目标大小的元素重组后的张量,未使用的元素仍然保存在存储器中,如果再次resize回原来的大小,这些元素将会被重新使用。
(这里说的shape兼容的意思是:resize前后的shape包含的总元素数是一致的,即resize前后的shape的所有维度的乘积是相同的。如resize前,shape为(1, 2 ,3),那resize之后的张量的总元素数需要是1*2*3,故目标shape可以是(2, 3), 可以是(3, 2, 1),可以是(2, 1, 3)等尺寸。)
–> 文字说明有点干燥,看点例子感受一下:
a = torch.arange(24).view(4, 6)
"""
Out:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
"""
a.resize_(6, 4)
"""
Out:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
"""
a.resize_(3, 3)
"""
Out:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
"""
a.resize_(7, 4)
"""
Out:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[ 12, 13, 14, 15],
[ 16, 17, 18, 19],
[ 20, 21, 22, 23],
[140720147688480, 140720141167152, 1, 0]])
"""
ps(官方解释,不是很能理解): 这是一个底层方法。存储被重新解释为c连续的,忽略当前的步长(除非目标大小等于当前大小,在这种情况下张量保持不变)
更多时候应该使用.view() / .reshape() / .set_()方法来替代此方法
参考文献:
What’s the difference between reshape and view in pytorch?
上一篇: pytorch(三) 常用类型