浅谈PyTorch/Numpy中view和copy/clone的区别
前言
学习完 深浅拷贝的区别 (之前写的文章)后,就继续来看看在PyTorch/Numpy中view与copy有什么区别。由于PyTorch是类似于Numpy的可用于GPU加速的计算库,在很多api或概念上都是基本一致的,因此本文对于view和copy的对比分析对两个库都是适用的。
传送门:图文代码浅谈Python中Shallow Copy(浅拷贝)和Deep Copy(深拷贝)的区别
View vs. Copy
1. 什么是view?
view,顾名思义,就是查看的意思,也就是说 用另外一种方式去查看一个数组/张量的数据。对于一个数组/张量来说,它可能是多维度的,暂且就把这些维度信息看成是数组/张量的结构吧,举个例子,假如有个数组/张量A,它的结构是S1,那view执行的操作就是,用同样的基础数据重构一个结构为S2的数组/张量B,A和B共享这些基础数据,再详细一点的话就是,A = [1, 2, 3, 4], B = [[1, 2], [3, 4]],A和B结构不同,但是基础数据是完全一样的。
view也还有(名词)视图的意思,对于一个物体,它可以有各种各样的视图(像初中学的俯视图,仰视图,侧视图等),这些视图给人直观的感受是不同形状的东西,但本质上,它们属于同一个物体。放在PyTorch/Numpy里,道理是相同的,这里view(视图)最贴近的意思就可能是数组/张量的形状shape了。
==》小结一下 view操作
- 重塑数组/张量的shape,形成一个新的视图;
- viewed对象与原始对象共享基础数据;
- viewed对象中的元素是按原始对象元素的顺序排列的。
==》常见viewed对象产生的情景:
- 调用.view()方法时;
- 数组/张量索引切片时(但不包括花式索引(fancy indexing))。
说到共享数据,那其实就跟浅拷贝(shallow copy)很像了,但它们也不完全相同,毕竟view可以改变数组/张量的形状,不过也可以不负责任地说:“PyTorch/Numpy中的view是一种特殊的浅拷贝(shallow copy)或叫其变种”。
view操作在PyTorch中相当于浅拷贝+reshape/resize操作,在Numpy中则还要手动reshape;
import copy
import torch
import numpy as np
# 创建原始tensor/ndarray对象
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 0, 0]]
t0 = torch.tensor(data) # Out: tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9],
# [0, 0, 0]])
a0 = np.array(data) # Out: array([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9],
# [0, 0, 0]])
# view
t1 = t0.view(3, 4) # Out: tensor[[1, 2, 3, 4],
# [5, 6, 7, 8],
# [9, 0, 0, 0]]
a1 = a0.view().reshape(3, 4) # array[[1, 2, 3, 4],
# [5, 6, 7, 8],
# [9, 0, 0, 0]]
print(t1.shape, a1.shape) # Out: torch.Size([3, 4]) (3, 4)
print(id(t1)==id(t0), id(a1)==id(a0)) # False False
print(id(t1[0])==id(t0[0]), id(a1[0])==id(a0[0])) # True True
print(id(t1[0][0])==id(t0[0][0]), id(a1[0][0])==id(a0[0][0])) # True/False True (注意这里第一个我用了/,因为这里每次运行的结果可能是不一样的,原因不明,可能跟数据存储有关)
# copy.copy
t2 = copy.copy(t0) # Out: tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9],
# [0, 0, 0]])
a2 = copy.copy(a0) # Out: array([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9],
# [0, 0, 0]])
print(id(t2)==id(t0), id(a2)==id(a0)) # False False
print(id(t2[0])==id(t0[0]), id(a2[0])==id(a0[0])) # True True
print(id(t2[0][0])==id(t0[0][0]), id(a2[0][0])==id(a0[0][0])) # True/False True (注意这里第一个我用了/,因为这里每次运行的结果可能是不一样的,原因不明,可能跟数据存储有关)
# 改变原始对象的元素
t0[-1] = 999
a0[-1] = 999
print(t0[-1], t1[-1], t2[-1]) # Out: tensor([999, 999, 999]) tensor([ 9, 999, 999, 999]) tensor([999, 999, 999])
print(a0[-1], a1[-1], a2[-1]) # Out: [999 999 999] [ 9 999 999 999] [0 0 0]
t0[0][0] = 666
a0[0][0] = 666
print(t0[0][0], t1[0][0], t2[0][0]) # Out: tensor(666) tensor(666) tensor(666)
print(a0[0][0], a1[0][0], a2[0][0]) # Out: 666 666 1
t2[1] = 0
print(t0[1], t1[1], t2[1]) # Out: tensor([0, 0, 0]) tensor([5, 6, 7, 8]) tensor([0, 0, 0])
a2[1] = 0
print(a0[1], a1[1], a2[1]) # Out: [4 5 6] [5 6 7 8] [0 0 0]
从上面可以看出一个很奇怪的地方,跟python中的浅拷贝操作不一样的是,对Numpy的ndarray执行浅拷贝操作时,浅拷贝的子对象的id都跟原始对象的子对象的id一致,但是两者其一的改变并没有传递给对方,这里我也不懂是为什么,不传递变化的性质就像是深拷贝的副本一样,而子对象id一致又跟浅拷贝操作相同。
2. PyTorch/Numpy中的copy/clone
这部分比较好理解,一句话可以总结,在PyTorch/Numpy中的copy/clone就是深拷贝的完全副本,与我们平时说的复制/拷贝语义是一样的,所以一旦创建副本,copied object与original object是不会互相影响的。