PyTorch中的叶节点、中间节点、梯度计算等知识点总结
程序员文章站
2024-01-29 21:09:52
...
总结:
按照惯例,所有属性requires_grad=False的张量是叶子节点(即:叶子张量、
叶子节点张量).
对于属性requires_grad=True的张量可能是叶子节点张量也可能不是叶
子节点张量而是中间节点(中间节点张量). 如果该张量的属性requires_grad=True,
而且是用于直接创建的,也即它的属性grad_fn=None,那么它就是叶子节点.
如果该张量的属性requires_grad=True,但是它不是用户直接创建的,而是由其他张量
经过某些运算操作产生的,那么它就不是叶子张量,而是中间节点张量,并且它的属性
grad_fn不是None,比如:grad_fn=<MeanBackward0>,这表示该张量是通过torch.mean()
运算操作产生的,是中间结果,所以是中间节点张量,所以不是叶子节点张量.
判断一个张量是不是叶子节点,可以通过它的属性is_leaf来查看.
一个张量的属性requires_grad用来指示在反向传播时,是否需要为这个张量计算梯度.
如果这个张量的属性requires_grad=False,那么就不需要为这个张量计算梯度,也就
不需要为这个张量进行优化学习.
在PyTorch的运算操作中,如果参加这个运算操作的所有输入张量的属性requires_grad都
是False的话,那么这个运算操作产生的结果,即输出张量的属性requires_grad也是False,
否则是True. 即输入的张量只要有一个需要求梯度(属性requires_grad=True),那么得到的
结果张量也是需要求梯度的(属性requires_grad=True).只有当所有的输入张量都不需要求
梯度时,得到的结果张量才会不需要求梯度.
对于属性requires_grad=True的张量,在反向传播时,会为该张量计算梯度. 但是pytorch的
自动梯度机制不会为中间结果保存梯度,即只会为叶子节点计算的梯度保存起来,保存到该
叶子节点张量的属性grad中,不会在中间节点张量的属性grad中保存这个张量的梯度,这是
出于对效率的考虑,中间节点张量的属性grad是None.如果用户需要为中间节点保存梯度的
话,可以让这个中间节点调用方法retain_grad(),这样梯度就会保存在这个中间节点的grad属性中.