欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

解决pytorch中的kl divergence计算问题

程序员文章站 2022-06-24 08:12:59
偶然从pytorch讨论论坛中看到的一个问题,kl divergence different results from tf,kl divergence 在tensorflow中和pytorch中计算...

偶然从pytorch讨论论坛中看到的一个问题,kl divergence different results from tf,kl divergence 在tensorflow中和pytorch中计算结果不同,平时没有注意到,记录下

一篇关于kl散度、js散度以及交叉熵对比的文章

kl divergence 介绍

kl散度( kullback–leibler divergence),又称相对熵,是描述两个概率分布 p 和 q 差异的一种方法。计算公式:

解决pytorch中的kl divergence计算问题

可以发现,p 和 q 中元素的个数不用相等,只需要两个分布中的离散元素一致。

举个简单例子:

两个离散分布分布分别为 p 和 q

p 的分布为:{1,1,2,2,3}

q 的分布为:{1,1,1,1,1,2,3,3,3,3}

我们发现,虽然两个分布中元素个数不相同,p 的元素个数为 5,q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。

当 x = 1时,在 p 分布中,“1” 这个元素的个数为 2,故 p(x = 1) = 2/5 = 0.4,在 q 分布中,“1” 这个元素的个数为 5,故 q(x = 1) = 5/10 = 0.5

同理,

当 x = 2 时,p(x = 2) = 2/5 = 0.4 ,q(x = 2) = 1/10 = 0.1

当 x = 3 时,p(x = 3) = 1/5 = 0.2 ,q(x = 3) = 4/10 = 0.4

把上述概率带入公式:

解决pytorch中的kl divergence计算问题

至此,就计算完成了两个离散变量分布的kl散度。

pytorch 中的 kl_div 函数

pytorch中有用于计算kl散度的函数 kl_div

解决pytorch中的kl divergence计算问题

计算 d (p||q)

1、不用这个函数的计算结果为:

解决pytorch中的kl divergence计算问题

与手算结果相同

2、使用函数:

(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)

解决pytorch中的kl divergence计算问题

注意:

1、函数中的 p q 位置相反(也就是想要计算d(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log

2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和

好别扭的用法,不知道为啥官方把它设计成这样

补充:pytorch 的kl divergence的实现

看代码吧~

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。