pytorch KL散度学习
程序员文章站
2024-02-11 16:36:58
...
pytorch官方文档中给出了说明
下面是在学习过程中需要注意的:
-
KL散度,也叫做相对熵,计算公式如下:
其中是真实的分布,是目标;是拟合分布,是想要改变的分布。KL散度值越小,分布越接近。 -
性质
- KL散度值 0,当 时等号成立。
- KL散度是非对称的,即
- 输入参数
- 输入 (自己生成的标签)需要经过log_softmax层,把概率分布变换到上
- 输入 (想要拟合到的目标分布)需要经过softmax层计算
- 实现代码
import torch.nn as nn
x = F.log_softmax(x)
y = F.softmax(y, dim=1)
criterion = nn.KLDivLoss()
klloss = criterion(x, y)
下一篇: 决策树 鸢尾花分类 数据挖掘Python