欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch KL散度学习

程序员文章站 2024-02-11 16:36:58
...

pytorch官方文档中给出了说明pytorch KL散度学习
下面是在学习过程中需要注意的:

  1. KL散度,也叫做相对熵,计算公式如下:

    KL(pq)=P(x)log(P(X)Q(x))KL(p||q) =\sum{P(x)log(\frac{P(X)}{Q(x)})}
    其中P(x)P(x)是真实的分布,是目标;Q(x)Q(x)是拟合分布,是想要改变的分布。KL散度值越小,分布越接近。

  2. 性质

  • KL散度值 \geq 0,当P(x)=Q(x)P(x) = Q(x) 时等号成立。
  • KL散度是非对称的,即KL(pq)KL(qp)KL(p||q)\neq KL(q||p)
  1. 输入参数
  • 输入xx (自己生成的标签)需要经过log_softmax层,把概率分布变换到loglog
  • 输入yy (想要拟合到的目标分布)需要经过softmax层计算
  1. 实现代码
import torch.nn as nn

x = F.log_softmax(x)
y = F.softmax(y, dim=1)

criterion = nn.KLDivLoss()
klloss = criterion(x, y)
相关标签: ML