Pytorch的骚操作
程序员文章站
2022-03-23 13:28:24
...
1. Pytorch的detach()、detach_()
detach():
返回一个新的 从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached) # v3 中tensor 的值也会改变
detach_():
将 Variable 从创建它的 graph 中分离,把它作为叶子节点
(1)将Variable的grad_fn 设置成None,在BP的时候,到这个Variable就找不到它的grad_fn,所以就不会再往后BP了
(2)将requires_grad设置为False。需要计算梯度的话手动设置requies_grad=True
用处:
假设有2个网络A,B,两个关系是这样的y=A(x),z =B(y)现在想用z.backward()来为B网络的参数来求梯度,但是又不想求A网络的梯度:
第一种方法:
y = A(x)
z = B(y.detach())
z.backward()
第二种方法:
y = A(x)
y.detach_()
z = B(y)
z.backward()
ps:个人感觉像Python的深copy、浅copy,顺着写一下Python的深浅copy
Python的复制、深copy、浅copy
(1)python的直接赋值默认是浅copy,只是传递对象的引用,原始列表改变,子对象也会改变
(2)深copy,包含对象里面的子对象的copy,原对象的改变不会造成深copy里任何子元素的改变
上一篇: C++ set用法总结
下一篇: PHP 字符串转数组与数组转字符串