欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch的骚操作

程序员文章站 2022-03-23 13:28:24
...

1. Pytorch的detach()、detach_()

detach():
返回一个新的 从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor

import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached)    # v3 中tensor 的值也会改变

detach_():
将 Variable 从创建它的 graph 中分离,把它作为叶子节点
(1)将Variable的grad_fn 设置成None,在BP的时候,到这个Variable就找不到它的grad_fn,所以就不会再往后BP了
(2)将requires_grad设置为False。需要计算梯度的话手动设置requies_grad=True

用处:

假设有2个网络A,B,两个关系是这样的y=A(x),z =B(y)现在想用z.backward()来为B网络的参数来求梯度,但是又不想求A网络的梯度:

第一种方法:
y = A(x)
z = B(y.detach())
z.backward()
第二种方法:
y = A(x)
y.detach_()
z = B(y)
z.backward()

ps:个人感觉像Python的深copy、浅copy,顺着写一下Python的深浅copy

Python的复制、深copy、浅copy

(1)python的直接赋值默认是浅copy,只是传递对象的引用,原始列表改变,子对象也会改变
(2)深copy,包含对象里面的子对象的copy,原对象的改变不会造成深copy里任何子元素的改变

相关标签: Pytorch