欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch学习:loss为什么要加item()

程序员文章站 2022-06-15 15:51:36
...

作者:陈诚
链接:https://www.zhihu.com/question/67209417/answer/344752405
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

PyTorch 0.4.0版本去掉了Variable,将Variable和Tensor融合起来,可以视Variable为requires_grad=True的Tensor。其动态原理还是不变。在获取数据的时候也变得更优雅:使用loss += loss.detach()来获取不需要梯度回传的部分。或者使用loss.item()直接获得所对应的python数据类型。============================================================

以下为原回答:算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。应该不止我经历过这个吧…久久不用又会不小心掉到这个坑里去…for data, label in trainloader:

    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里

运行着就发现显存炸了观察了一下发现随着每个batch显存消耗在不断增大…参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大

那么消耗的显存也就越来越大

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算… 包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~题外话想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。

大概是这样

z = x + y
z.grad_fn
out:
<AddBackward1 at 0x107286240>

相关标签: 技术栈